Skip to content

Commit 0eb4d4b

Browse files
committed
Added support for 5-shot
1 parent ebb9ba9 commit 0eb4d4b

File tree

4 files changed

+60
-40
lines changed

4 files changed

+60
-40
lines changed

datasets/miniImagenetOneShot.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def __init__(self, dataroot = '/home/aberenguel/Dataset/miniImagenet', type = 't
3232
self.classes_per_set = classes_per_set
3333
self.samples_per_class = samples_per_class
3434
self.n_samples = self.samples_per_class * self.classes_per_set
35-
self.n_samplesNShot = 1 # Samples per meta-test. In this case 1 as is OneShot.
35+
self.n_samplesNShot = 5 # Samples per meta-test. In this case 5, as this is 5-shot.
3636
# Transformations to the image
3737
self.transform = transforms.Compose([filenameToPILImage,
3838
PiLImageResize,

datasets/omniglotNShot.py

+32-20
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,15 @@ def normalization(self):
8484
self.max = np.max(self.x_train)
8585
self.min = np.min(self.x_train)
8686
print("train_shape", self.x_train.shape, "test_shape", self.x_test.shape, "val_shape", self.x_val.shape)
87-
print("before_normalization", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)
87+
#print("before_normalization", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)
8888
self.x_train = (self.x_train - self.mean) / self.std
8989
self.x_val = (self.x_val - self.mean) / self.std
9090
self.x_test = (self.x_test - self.mean) / self.std
91-
self.mean = np.mean(self.x_train)
92-
self.std = np.std(self.x_train)
93-
self.max = np.max(self.x_train)
94-
self.min = np.min(self.x_train)
95-
print("after_normalization", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)
91+
#self.mean = np.mean(self.x_train)
92+
#self.std = np.std(self.x_train)
93+
#self.max = np.max(self.x_train)
94+
#self.min = np.min(self.x_train)
95+
#print("after_normalization", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)
9696

9797
def load_data_cache(self, data_pack):
9898
"""
@@ -105,22 +105,34 @@ def load_data_cache(self, data_pack):
105105
for sample in range(1000):
106106
support_set_x = np.zeros((self.batch_size, n_samples, 28, 28, 1))
107107
support_set_y = np.zeros((self.batch_size, n_samples))
108-
target_x = np.zeros((self.batch_size, 28, 28, 1), dtype=np.int)
109-
target_y = np.zeros((self.batch_size,), dtype=np.int)
108+
target_x = np.zeros((self.batch_size, self.samples_per_class, 28, 28, 1), dtype=np.int)
109+
target_y = np.zeros((self.batch_size, self.samples_per_class), dtype=np.int)
110110
for i in range(self.batch_size):
111-
ind = 0
112111
pinds = np.random.permutation(n_samples)
113112
classes = np.random.choice(data_pack.shape[0], self.classes_per_set, False)
114-
x_hat_class = np.random.randint(self.classes_per_set)
113+
# select 1-shot or 5-shot classes for test with repetition
114+
x_hat_class = np.random.choice(classes, self.samples_per_class, True)
115+
pinds_test = np.random.permutation(self.samples_per_class)
116+
ind = 0
117+
ind_test = 0
115118
for j, cur_class in enumerate(classes): # each class
116-
example_inds = np.random.choice(data_pack.shape[1], self.samples_per_class, False)
117-
for eind in example_inds:
119+
if cur_class in x_hat_class:
120+
# Count number of times this class is inside the meta-test
121+
n_test_samples = np.sum(cur_class == x_hat_class)
122+
example_inds = np.random.choice(data_pack.shape[1], self.samples_per_class + n_test_samples, False)
123+
else:
124+
example_inds = np.random.choice(data_pack.shape[1], self.samples_per_class, False)
125+
126+
# meta-training
127+
for eind in example_inds[:self.samples_per_class]:
118128
support_set_x[i, pinds[ind], :, :, :] = data_pack[cur_class][eind]
119129
support_set_y[i, pinds[ind]] = j
120-
ind += 1
121-
if j == x_hat_class:
122-
target_x[i, :, :, :] = data_pack[cur_class][np.random.choice(data_pack.shape[1])]
123-
target_y[i] = j
130+
ind = ind + 1
131+
# meta-test
132+
for eind in example_inds[self.samples_per_class:]:
133+
target_x[i, pinds_test[ind_test], :, :, :] = data_pack[cur_class][eind]
134+
target_y[i, pinds_test[ind_test]] = j
135+
ind_test = ind_test + 1
124136

125137
data_cache.append([support_set_x, support_set_y, target_x, target_y])
126138
return data_cache
@@ -149,11 +161,11 @@ def get_batch(self,str_type, rotate_flag = False):
149161
if rotate_flag:
150162
k = int(np.random.uniform(low=0, high=4))
151163
# Iterate over the sequence. Extract batches.
152-
for i in np.arange(x_support_set.shape[1]):
153-
x_support_set[:,i,:,:,:] = self.__rotate_batch(x_support_set[:,i,:,:,:],k)
164+
for i in np.arange(x_support_set.shape[0]):
165+
x_support_set[i,:,:,:,:] = self.__rotate_batch(x_support_set[i,:,:,:,:],k)
154166
# Rotate all the batch of the target images
155-
x_target = self.__rotate_batch(x_target,k)
156-
167+
for i in np.arange(x_target.shape[0]):
168+
x_target[i,:,:,:,:] = self.__rotate_batch(x_target[i,:,:,:,:], k)
157169
return x_support_set, y_support_set, x_target, y_target
158170

159171

experiments/OneShotBuilder.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def run_training_epoch(self, total_train_batches):
9090
size = x_support_set.size()
9191
x_support_set = x_support_set.view(size[0],size[1],size[4],size[2],size[3])
9292
size = x_target.size()
93-
x_target = x_target.view(size[0], size[3], size[1], size[2])
93+
x_target = x_target.view(size[0],size[1],size[4],size[2],size[3])
9494
if self.isCudaAvailable:
9595
acc, c_loss_value = self.matchingNet(x_support_set.cuda(), y_support_set_one_hot.cuda(),
9696
x_target.cuda(), y_target.cuda())
@@ -160,7 +160,7 @@ def run_validation_epoch(self, total_val_batches):
160160
size = x_support_set.size()
161161
x_support_set = x_support_set.view(size[0], size[1], size[4], size[2], size[3])
162162
size = x_target.size()
163-
x_target = x_target.view(size[0], size[3], size[1], size[2])
163+
x_target = x_target.view(size[0],size[1],size[4],size[2],size[3])
164164
if self.isCudaAvailable:
165165
acc, c_loss_value = self.matchingNet(x_support_set.cuda(), y_support_set_one_hot.cuda(),
166166
x_target.cuda(), y_target.cuda())
@@ -212,7 +212,7 @@ def run_testing_epoch(self, total_test_batches):
212212
size = x_support_set.size()
213213
x_support_set = x_support_set.view(size[0], size[1], size[4], size[2], size[3])
214214
size = x_target.size()
215-
x_target = x_target.view(size[0], size[3], size[1], size[2])
215+
x_target = x_target.view(size[0],size[1],size[4],size[2],size[3])
216216
if self.isCudaAvailable:
217217
acc, c_loss_value = self.matchingNet(x_support_set.cuda(), y_support_set_one_hot.cuda(),
218218
x_target.cuda(), y_target.cuda())

models/MatchingNetwork.py

+24-16
Original file line numberDiff line numberDiff line change
@@ -62,30 +62,38 @@ def forward(self, support_set_images, support_set_labels_one_hot, target_image,
6262
# produce embeddings for support set images
6363
encoded_images = []
6464
for i in np.arange(support_set_images.size(1)):
65-
gen_encode = self.g(support_set_images[:,i,:,:])
65+
gen_encode = self.g(support_set_images[:,i,:,:,:])
6666
encoded_images.append(gen_encode)
6767

6868
# produce embeddings for target images
69-
gen_encode = self.g(target_image)
70-
encoded_images.append(gen_encode)
71-
outputs = torch.stack(encoded_images)
69+
for i in np.arange(target_image.size(1)):
70+
gen_encode = self.g(target_image[:,i,:,:,:])
71+
encoded_images.append(gen_encode)
72+
outputs = torch.stack(encoded_images)
73+
74+
if self.fce:
75+
outputs, hn, cn = self.lstm(outputs)
7276

73-
if self.fce:
74-
outputs, hn, cn = self.lstm(outputs)
77+
# get similarity between support set embeddings and target
78+
similarities = self.dn(support_set=outputs[:-1], input_image=outputs[-1])
79+
similarities = similarities.t()
7580

76-
# get similarity between support set embeddings and target
77-
similarities = self.dn(support_set=outputs[:-1], input_image=outputs[-1])
78-
similarities = similarities.t()
81+
# produce predictions for target probabilities
82+
preds = self.classify(similarities,support_set_y=support_set_labels_one_hot)
7983

80-
# produce predictions for target probabilities
81-
preds = self.classify(similarities,support_set_y=support_set_labels_one_hot)
84+
# calculate accuracy and crossentropy loss
85+
values, indices = preds.max(1)
86+
if i == 0:
87+
accuracy = torch.mean((indices.squeeze() == target_label[:,i]).float())
88+
crossentropy_loss = F.cross_entropy(preds, target_label[:,i].long())
89+
else:
90+
accuracy = accuracy + torch.mean((indices.squeeze() == target_label[:, i]).float())
91+
crossentropy_loss = accuracy + F.cross_entropy(preds, target_label[:, i].long())
8292

83-
# calculate accuracy and crossentropy loss
84-
values, indices = preds.max(1)
85-
accuracy = torch.mean((indices.squeeze() == target_label).float())
86-
crossentropy_loss = F.cross_entropy(preds, target_label.long())
93+
# delete the last target image encoding of encoded_images
94+
encoded_images.pop()
8795

88-
return accuracy, crossentropy_loss
96+
return accuracy/target_image.size(1), crossentropy_loss/target_image.size(1)
8997

9098

9199
class MatchingNetworkTest(unittest.TestCase):

0 commit comments

Comments
 (0)