1
1
import torch
2
+ import torch .backends .cudnn as cudnn
2
3
import tqdm
3
4
from models .MatchingNetwork import MatchingNetwork
4
5
from torch .autograd import Variable
@@ -38,7 +39,11 @@ def build_experiment(self, batch_size, classes_per_set, samples_per_class, chann
38
39
self .current_lr = 1e-03
39
40
self .lr_decay = 1e-6
40
41
self .wd = 1e-4
41
- self .matchingNet .cuda ()
42
+ self .isCudaAvailable = torch .cuda .is_available ()
43
+ if self .isCudaAvailable :
44
+ cudnn .benchmark = True
45
+ torch .cuda .manual_seed_all (0 )
46
+ self .matchingNet .cuda ()
42
47
43
48
def run_training_epoch (self , total_train_batches ):
44
49
"""
@@ -75,7 +80,12 @@ def run_training_epoch(self, total_train_batches):
75
80
x_support_set = x_support_set .view (size [0 ],size [1 ],size [4 ],size [2 ],size [3 ])
76
81
size = x_target .size ()
77
82
x_target = x_target .view (size [0 ], size [3 ], size [1 ], size [2 ])
78
- acc , c_loss_value = self .matchingNet (x_support_set .cuda (), y_support_set_one_hot .cuda (), x_target .cuda (), y_target .cuda ())
83
+ if self .isCudaAvailable :
84
+ acc , c_loss_value = self .matchingNet (x_support_set .cuda (), y_support_set_one_hot .cuda (),
85
+ x_target .cuda (), y_target .cuda ())
86
+ else :
87
+ acc , c_loss_value = self .matchingNet (x_support_set , y_support_set_one_hot ,
88
+ x_target , y_target )
79
89
80
90
# Before the backward pass, use the optimizer object to zero all of the
81
91
# gradients for the variables it will update (which are the learnable weights
@@ -122,10 +132,10 @@ def run_validation_epoch(self, total_val_batches):
122
132
x_support_set , y_support_set , x_target , y_target = \
123
133
self .data .get_batch (str_type = 'val' , rotate_flag = False )
124
134
125
- x_support_set = Variable (torch .from_numpy (x_support_set ), requires_grad = False ).float ()
126
- y_support_set = Variable (torch .from_numpy (y_support_set ), requires_grad = False ).long ()
127
- x_target = Variable (torch .from_numpy (x_target ), requires_grad = False ).float ()
128
- y_target = Variable (torch .from_numpy (y_target ), requires_grad = False ).long ()
135
+ x_support_set = Variable (torch .from_numpy (x_support_set ), volatile = True ).float ()
136
+ y_support_set = Variable (torch .from_numpy (y_support_set ), volatile = True ).long ()
137
+ x_target = Variable (torch .from_numpy (x_target ), volatile = True ).float ()
138
+ y_target = Variable (torch .from_numpy (y_target ), volatile = True ).long ()
129
139
130
140
# y_support_set: Add extra dimension for the one_hot
131
141
y_support_set = torch .unsqueeze (y_support_set , 2 )
@@ -141,8 +151,12 @@ def run_validation_epoch(self, total_val_batches):
141
151
x_support_set = x_support_set .view (size [0 ], size [1 ], size [4 ], size [2 ], size [3 ])
142
152
size = x_target .size ()
143
153
x_target = x_target .view (size [0 ], size [3 ], size [1 ], size [2 ])
144
- acc , c_loss_value = self .matchingNet (x_support_set .cuda (), y_support_set_one_hot .cuda (),
145
- x_target .cuda (), y_target .cuda ())
154
+ if self .isCudaAvailable :
155
+ acc , c_loss_value = self .matchingNet (x_support_set .cuda (), y_support_set_one_hot .cuda (),
156
+ x_target .cuda (), y_target .cuda ())
157
+ else :
158
+ acc , c_loss_value = self .matchingNet (x_support_set , y_support_set_one_hot ,
159
+ x_target , y_target )
146
160
147
161
iter_out = "val_loss: {}, val_accuracy: {}" .format (c_loss_value .data [0 ], acc .data [0 ])
148
162
pbar .set_description (iter_out )
@@ -170,10 +184,10 @@ def run_testing_epoch(self, total_test_batches):
170
184
x_support_set , y_support_set , x_target , y_target = \
171
185
self .data .get_batch (str_type = 'test' , rotate_flag = False )
172
186
173
- x_support_set = Variable (torch .from_numpy (x_support_set ), requires_grad = False ).float ()
174
- y_support_set = Variable (torch .from_numpy (y_support_set ), requires_grad = False ).long ()
175
- x_target = Variable (torch .from_numpy (x_target ), requires_grad = False ).float ()
176
- y_target = Variable (torch .from_numpy (y_target ), requires_grad = False ).long ()
187
+ x_support_set = Variable (torch .from_numpy (x_support_set ), volatile = True ).float ()
188
+ y_support_set = Variable (torch .from_numpy (y_support_set ), volatile = True ).long ()
189
+ x_target = Variable (torch .from_numpy (x_target ), volatile = True ).float ()
190
+ y_target = Variable (torch .from_numpy (y_target ), volatile = True ).long ()
177
191
178
192
# y_support_set: Add extra dimension for the one_hot
179
193
y_support_set = torch .unsqueeze (y_support_set , 2 )
@@ -189,8 +203,12 @@ def run_testing_epoch(self, total_test_batches):
189
203
x_support_set = x_support_set .view (size [0 ], size [1 ], size [4 ], size [2 ], size [3 ])
190
204
size = x_target .size ()
191
205
x_target = x_target .view (size [0 ], size [3 ], size [1 ], size [2 ])
192
- acc , c_loss_value = self .matchingNet (x_support_set .cuda (), y_support_set_one_hot .cuda (),
193
- x_target .cuda (), y_target .cuda ())
206
+ if self .isCudaAvailable :
207
+ acc , c_loss_value = self .matchingNet (x_support_set .cuda (), y_support_set_one_hot .cuda (),
208
+ x_target .cuda (), y_target .cuda ())
209
+ else :
210
+ acc , c_loss_value = self .matchingNet (x_support_set , y_support_set_one_hot ,
211
+ x_target , y_target )
194
212
195
213
iter_out = "test_loss: {}, test_accuracy: {}" .format (c_loss_value .data [0 ], acc .data [0 ])
196
214
pbar .set_description (iter_out )
0 commit comments