from datasets import omniglot
import torchvision.transforms as transforms
from PIL import Image
from option import Options
import os.path

import numpy as np
np.random.seed(2191)  # for reproducibility

# LAMBDA FUNCTIONS
filenameToPILImage = lambda x: Image.open(x).convert('L')
PiLImageResize = lambda x: x.resize((28, 28))
np_reshape = lambda x: np.reshape(x, (28, 28, 1))


class OmniglotNShotDataset():
    def __init__(self, batch_size=100, classes_per_set=10, samples_per_class=1):
        """
        Constructs an N-shot Omniglot dataset.
        :param batch_size: Experiment batch_size
        :param classes_per_set: Integer indicating the number of classes per set
        :param samples_per_class: Integer indicating samples per class
        e.g. For a 20-way, 1-shot learning task, use classes_per_set=20 and samples_per_class=1
             For a 5-way, 10-shot learning task, use classes_per_set=5 and samples_per_class=10
        """
        args = Options().parse()

        if not os.path.isfile(os.path.join(args.dataroot, 'data.npy')):
            self.x = omniglot.OMNIGLOT(args.dataroot, download=True,
                                       transform=transforms.Compose([filenameToPILImage,
                                                                     PiLImageResize,
                                                                     np_reshape]))
            # transforms.ToTensor()]))

            # Convert to the format of AntreasAntoniou: [nClasses, nCharacters, 28, 28, 1]
            temp = dict()
            for (img, label) in self.x:
                if label in temp:
                    temp[label].append(img)
                else:
                    temp[label] = [img]
            self.x = []  # Free memory

            for label in temp.keys():
                self.x.append(np.array(temp[label]))
            self.x = np.array(self.x)
            temp = []  # Free memory
            np.save(os.path.join(args.dataroot, 'data.npy'), self.x)
        else:
            self.x = np.load(os.path.join(args.dataroot, 'data.npy'))

        # Shuffle the classes, then split them into train/test/val sets
        shuffle_classes = np.arange(self.x.shape[0])
        np.random.shuffle(shuffle_classes)
        self.x = self.x[shuffle_classes]
        self.x_train, self.x_test, self.x_val = self.x[:1200], self.x[1200:1500], self.x[1500:]
        self.normalization()

        self.batch_size = batch_size
        self.n_classes = self.x.shape[0]
        self.classes_per_set = classes_per_set
        self.samples_per_class = samples_per_class

        self.indexes = {"train": 0, "val": 0, "test": 0}
        self.datasets = {"train": self.x_train, "val": self.x_val, "test": self.x_test}  # original data cached
        self.datasets_cache = {"train": self.load_data_cache(self.datasets["train"]),  # current epoch data cached
                               "val": self.load_data_cache(self.datasets["val"]),
                               "test": self.load_data_cache(self.datasets["test"])}

    def normalization(self):
        """
        Normalizes our data to have a mean of 0 and a std of 1.
        """
        self.mean = np.mean(self.x_train)
        self.std = np.std(self.x_train)
        self.max = np.max(self.x_train)
        self.min = np.min(self.x_train)
        print("train_shape", self.x_train.shape, "test_shape", self.x_test.shape, "val_shape", self.x_val.shape)
        print("before_normalization", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)
        # Normalize all splits using the training-set statistics
        self.x_train = (self.x_train - self.mean) / self.std
        self.x_val = (self.x_val - self.mean) / self.std
        self.x_test = (self.x_test - self.mean) / self.std
        self.mean = np.mean(self.x_train)
        self.std = np.std(self.x_train)
        self.max = np.max(self.x_train)
        self.min = np.min(self.x_train)
        print("after_normalization", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)

    def load_data_cache(self, data_pack):
        """
        Collects 1000 batches of data for N-shot learning.
        :param data_pack: Data pack to use (any one of train, val, test)
        :return: A list with [support_set_x, support_set_y, target_x, target_y] ready to be fed to our networks
        """
        n_samples = self.samples_per_class * self.classes_per_set
        data_cache = []
        for sample in range(1000):
            support_set_x = np.zeros((self.batch_size, n_samples, 28, 28, 1))
            support_set_y = np.zeros((self.batch_size, n_samples))
            target_x = np.zeros((self.batch_size, 28, 28, 1))
            target_y = np.zeros((self.batch_size,), dtype=np.int64)
            for i in range(self.batch_size):
                ind = 0
                pinds = np.random.permutation(n_samples)
                classes = np.random.choice(data_pack.shape[0], self.classes_per_set, False)
                # One of the sampled classes is chosen as the query (target) class
                x_hat_class = np.random.randint(self.classes_per_set)
                for j, cur_class in enumerate(classes):  # each class
                    example_inds = np.random.choice(data_pack.shape[1], self.samples_per_class, False)
                    for eind in example_inds:
                        support_set_x[i, pinds[ind], :, :, :] = data_pack[cur_class][eind]
                        support_set_y[i, pinds[ind]] = j
                        ind += 1
                    if j == x_hat_class:
                        target_x[i, :, :, :] = data_pack[cur_class][np.random.choice(data_pack.shape[1])]
                        target_y[i] = j

            data_cache.append([support_set_x, support_set_y, target_x, target_y])
        return data_cache

    def get_batch(self, dataset_name):
        """
        Gets the next batch from the dataset with the given name.
        :param dataset_name: The name of the dataset (one of "train", "val", "test")
        :return: The next batch as (x_support_set, y_support_set, x_target, y_target)
        """
        # Refill the cache once all pre-generated batches have been consumed
        if self.indexes[dataset_name] >= len(self.datasets_cache[dataset_name]):
            self.indexes[dataset_name] = 0
            self.datasets_cache[dataset_name] = self.load_data_cache(self.datasets[dataset_name])
        next_batch = self.datasets_cache[dataset_name][self.indexes[dataset_name]]
        self.indexes[dataset_name] += 1
        x_support_set, y_support_set, x_target, y_target = next_batch
        return x_support_set, y_support_set, x_target, y_target

    def get_train_batch(self):
        """
        Gets the next training batch.
        :return: The next training batch
        """
        return self.get_batch("train")

    def get_test_batch(self):
        """
        Gets the next test batch.
        :return: The next test batch
        """
        return self.get_batch("test")

    def get_val_batch(self):
        """
        Gets the next validation batch.
        :return: The next validation batch
        """
        return self.get_batch("val")