@@ -84,15 +84,15 @@ def normalization(self):
84
84
self .max = np .max (self .x_train )
85
85
self .min = np .min (self .x_train )
86
86
print ("train_shape" , self .x_train .shape , "test_shape" , self .x_test .shape , "val_shape" , self .x_val .shape )
87
- print ("before_normalization" , "mean" , self .mean , "max" , self .max , "min" , self .min , "std" , self .std )
87
+ # print("before_normalization", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)
88
88
self .x_train = (self .x_train - self .mean ) / self .std
89
89
self .x_val = (self .x_val - self .mean ) / self .std
90
90
self .x_test = (self .x_test - self .mean ) / self .std
91
- self .mean = np .mean (self .x_train )
92
- self .std = np .std (self .x_train )
93
- self .max = np .max (self .x_train )
94
- self .min = np .min (self .x_train )
95
- print ("after_normalization" , "mean" , self .mean , "max" , self .max , "min" , self .min , "std" , self .std )
91
+ # self.mean = np.mean(self.x_train)
92
+ # self.std = np.std(self.x_train)
93
+ # self.max = np.max(self.x_train)
94
+ # self.min = np.min(self.x_train)
95
+ # print("after_normalization", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)
96
96
97
97
def load_data_cache (self , data_pack ):
98
98
"""
@@ -105,22 +105,34 @@ def load_data_cache(self, data_pack):
105
105
for sample in range (1000 ):
106
106
support_set_x = np .zeros ((self .batch_size , n_samples , 28 , 28 , 1 ))
107
107
support_set_y = np .zeros ((self .batch_size , n_samples ))
108
- target_x = np .zeros ((self .batch_size , 28 , 28 , 1 ), dtype = np .int )
109
- target_y = np .zeros ((self .batch_size ,), dtype = np .int )
108
+ target_x = np .zeros ((self .batch_size , self . samples_per_class , 28 , 28 , 1 ), dtype = np .int )
109
+ target_y = np .zeros ((self .batch_size , self . samples_per_class ), dtype = np .int )
110
110
for i in range (self .batch_size ):
111
- ind = 0
112
111
pinds = np .random .permutation (n_samples )
113
112
classes = np .random .choice (data_pack .shape [0 ], self .classes_per_set , False )
114
- x_hat_class = np .random .randint (self .classes_per_set )
113
+ # select 1-shot or 5-shot classes for test with repetition
114
+ x_hat_class = np .random .choice (classes , self .samples_per_class , True )
115
+ pinds_test = np .random .permutation (self .samples_per_class )
116
+ ind = 0
117
+ ind_test = 0
115
118
for j , cur_class in enumerate (classes ): # each class
116
- example_inds = np .random .choice (data_pack .shape [1 ], self .samples_per_class , False )
117
- for eind in example_inds :
119
+ if cur_class in x_hat_class :
120
+ # Count number of times this class is inside the meta-test
121
+ n_test_samples = np .sum (cur_class == x_hat_class )
122
+ example_inds = np .random .choice (data_pack .shape [1 ], self .samples_per_class + n_test_samples , False )
123
+ else :
124
+ example_inds = np .random .choice (data_pack .shape [1 ], self .samples_per_class , False )
125
+
126
+ # meta-training
127
+ for eind in example_inds [:self .samples_per_class ]:
118
128
support_set_x [i , pinds [ind ], :, :, :] = data_pack [cur_class ][eind ]
119
129
support_set_y [i , pinds [ind ]] = j
120
- ind += 1
121
- if j == x_hat_class :
122
- target_x [i , :, :, :] = data_pack [cur_class ][np .random .choice (data_pack .shape [1 ])]
123
- target_y [i ] = j
130
+ ind = ind + 1
131
+ # meta-test
132
+ for eind in example_inds [self .samples_per_class :]:
133
+ target_x [i , pinds_test [ind_test ], :, :, :] = data_pack [cur_class ][eind ]
134
+ target_y [i , pinds_test [ind_test ]] = j
135
+ ind_test = ind_test + 1
124
136
125
137
data_cache .append ([support_set_x , support_set_y , target_x , target_y ])
126
138
return data_cache
@@ -149,11 +161,11 @@ def get_batch(self,str_type, rotate_flag = False):
149
161
if rotate_flag :
150
162
k = int (np .random .uniform (low = 0 , high = 4 ))
151
163
# Iterate over the sequence. Extract batches.
152
- for i in np .arange (x_support_set .shape [1 ]):
153
- x_support_set [:, i ,:,:,:] = self .__rotate_batch (x_support_set [:, i ,:,:,:],k )
164
+ for i in np .arange (x_support_set .shape [0 ]):
165
+ x_support_set [i ,:,:,:,: ] = self .__rotate_batch (x_support_set [i ,: ,:,:,:],k )
154
166
# Rotate all the batch of the target images
155
- x_target = self . __rotate_batch (x_target , k )
156
-
167
+ for i in np . arange (x_target . shape [ 0 ]):
168
+ x_target [ i ,:,:,:,:] = self . __rotate_batch ( x_target [ i ,:,:,:,:], k )
157
169
return x_support_set , y_support_set , x_target , y_target
158
170
159
171
0 commit comments