Commit 2fa81d7

Committed on Jul 17, 2017
First Implementation

1 parent 71d2ed9 commit 2fa81d7
29 files changed: +811 -0 lines changed

LICENSE.md (+24 lines)
@@ -0,0 +1,24 @@
MIT License

Copyright (c) 2017 Albert Berenguel

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

1. The above copyright notice and this permission notice shall be included in
   all copies or substantial portions of the Software.
2. Original authors' names are not deleted.
3. The authors' names are not used to endorse or promote products derived
   from this software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

data/DistanceNetwork/input_image.npy (8.08 KB, binary file not shown)
data/DistanceNetwork/similarities.npy (2.58 KB, binary file not shown)
data/DistanceNetwork/support_set.npy (160 KB, binary file not shown)
data/gen_encode.npy (8.08 KB, binary file not shown)
data/lstm/encoded_images_after.npy (8.08 KB, binary file not shown)
data/lstm/encoded_images_before.npy (168 KB, binary file not shown)
data/lstm/output_state_bw.npy (168 KB, binary file not shown)
data/lstm/output_state_fw.npy (8.08 KB, binary file not shown)
data/preds.npy (2.58 KB, binary file not shown)
data/similarities.npy (2.58 KB, binary file not shown)
data/softmax_similarities.npy (2.58 KB, binary file not shown)
data/support_set_y.npy (50.1 KB, binary file not shown)
data/target_image.npy (98.1 KB, binary file not shown)
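The committed data/*.npy files above are fixture arrays (encoder inputs and outputs, LSTM states, similarity and prediction vectors), presumably saved to compare this PyTorch port against reference outputs. A minimal inspection sketch, assuming only that the paths listed above exist relative to the repository root:

import numpy as np

# Print shape and dtype of each committed fixture array (paths from the listing above).
for path in ['data/gen_encode.npy', 'data/similarities.npy',
             'data/softmax_similarities.npy', 'data/preds.npy',
             'data/support_set_y.npy', 'data/target_image.npy']:
    arr = np.load(path)
    print(path, arr.shape, arr.dtype)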

datasets/__init__.py (whitespace-only changes)
datasets/__init__.pyc (107 Bytes, binary file not shown)

datasets/omniglot.py (+121 lines)
@@ -0,0 +1,121 @@
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Albert Berenguel
## Computer Vision Center (CVC). Universitat Autonoma de Barcelona
## Email: aberenguel@cvc.uab.es
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree.
##
## Acknowledgments:
## https://github.com/ludc. Using some parts of his Omniglot code.
## https://github.com/AntreasAntoniou. Using some parts of his Omniglot code.
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

from __future__ import print_function
import torch.utils.data as data
import os
import os.path
import errno


class OMNIGLOT(data.Dataset):
    urls = [
        'https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip',
        'https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip'
    ]
    raw_folder = 'raw'
    processed_folder = 'processed'
    training_file = 'training.pt'
    test_file = 'test.pt'

    '''
    The items are (filename, category). The index of all the categories can be found in self.idx_classes.
    Args:
    - root: the directory where the dataset will be stored
    - transform: how to transform the input
    - target_transform: how to transform the target
    - download: whether to download the dataset if it is not found locally
    '''
    def __init__(self, root, transform=None, target_transform=None, download=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.'
                               + ' You can use download=True to download it')

        self.all_items = find_classes(os.path.join(self.root, self.processed_folder))
        self.idx_classes = index_classes(self.all_items)

    def __getitem__(self, index):
        filename = self.all_items[index][0]
        img = str.join('/', [self.all_items[index][2], filename])

        target = self.idx_classes[self.all_items[index][1]]
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.all_items)

    def _check_exists(self):
        return os.path.exists(os.path.join(self.root, self.processed_folder, "images_evaluation")) and \
               os.path.exists(os.path.join(self.root, self.processed_folder, "images_background"))

    def download(self):
        from six.moves import urllib
        import zipfile

        if self._check_exists():
            return

        # Create the raw and processed folders, ignoring "already exists" errors.
        try:
            os.makedirs(os.path.join(self.root, self.raw_folder))
            os.makedirs(os.path.join(self.root, self.processed_folder))
        except OSError as e:
            if e.errno == errno.EEXIST:
                pass
            else:
                raise

        # Download both zip files and extract them into the processed folder.
        for url in self.urls:
            print('== Downloading ' + url)
            data = urllib.request.urlopen(url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(self.root, self.raw_folder, filename)
            with open(file_path, 'wb') as f:
                f.write(data.read())
            file_processed = os.path.join(self.root, self.processed_folder)
            print("== Unzip from " + file_path + " to " + file_processed)
            zip_ref = zipfile.ZipFile(file_path, 'r')
            zip_ref.extractall(file_processed)
            zip_ref.close()
        print("Download finished.")


def find_classes(root_dir):
    # Walk the extracted dataset and collect (filename, "alphabet/character", directory) triples.
    retour = []
    for (root, dirs, files) in os.walk(root_dir):
        for f in files:
            if f.endswith("png"):
                r = root.split('/')
                lr = len(r)
                retour.append((f, r[lr - 2] + "/" + r[lr - 1], root))
    print("== Found %d items " % len(retour))
    return retour


def index_classes(items):
    # Map each "alphabet/character" string to a contiguous integer class index.
    idx = {}
    for i in items:
        if not i[1] in idx:
            idx[i[1]] = len(idx)
    print("== Found %d classes" % len(idx))
    return idx
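A usage sketch (not part of this commit): with download=True the class fetches and unzips both Omniglot splits on first use, and without a transform each item is the image's file path plus its class index. The 'omniglot_root' directory name here is hypothetical.

from datasets.omniglot import OMNIGLOT

# Hypothetical root directory; download=True fetches and extracts both zips on first use.
dataset = OMNIGLOT('omniglot_root', download=True)
print(len(dataset))            # number of (filename, category) items found
img_path, target = dataset[0]  # with no transform, img_path is the raw path string
print(img_path, target)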

datasets/omniglot.pyc (3.83 KB, binary file not shown)

datasets/omniglotNShot.py (+158 lines)
@@ -0,0 +1,158 @@
from datasets import omniglot
import torchvision.transforms as transforms
from PIL import Image
from option import Options
import os.path

import numpy as np
np.random.seed(2191)  # for reproducibility

# LAMBDA FUNCTIONS
filenameToPILImage = lambda x: Image.open(x).convert('L')
PiLImageResize = lambda x: x.resize((28, 28))
np_reshape = lambda x: np.reshape(x, (28, 28, 1))


class OmniglotNShotDataset():
    def __init__(self, batch_size=100, classes_per_set=10, samples_per_class=1):
        """
        Constructs an N-shot Omniglot dataset.
        :param batch_size: Experiment batch_size
        :param classes_per_set: Integer indicating the number of classes per set
        :param samples_per_class: Integer indicating samples per class
        e.g. For a 20-way, 1-shot learning task, use classes_per_set=20 and samples_per_class=1
             For a 5-way, 10-shot learning task, use classes_per_set=5 and samples_per_class=10
        """
        args = Options().parse()

        if not os.path.isfile(os.path.join(args.dataroot, 'data.npy')):
            self.x = omniglot.OMNIGLOT(args.dataroot, download=True,
                                       transform=transforms.Compose([filenameToPILImage,
                                                                     PiLImageResize,
                                                                     np_reshape]))

            # Convert to the format of AntreasAntoniou: [nClasses, nCharacters, 28, 28, 1]
            temp = dict()
            for (img, label) in self.x:
                if label in temp:
                    temp[label].append(img)
                else:
                    temp[label] = [img]
            self.x = []  # Free memory

            for classes in temp.keys():
                # Index temp by the label itself; the original temp[temp.keys()[classes]]
                # only worked by accident when labels were exactly 0..N-1.
                self.x.append(np.array(temp[classes]))
            self.x = np.array(self.x)
            temp = []  # Free memory
            np.save(os.path.join(args.dataroot, 'data.npy'), self.x)
        else:
            self.x = np.load(os.path.join(args.dataroot, 'data.npy'))

        shuffle_classes = np.arange(self.x.shape[0])
        np.random.shuffle(shuffle_classes)
        self.x = self.x[shuffle_classes]
        self.x_train, self.x_test, self.x_val = self.x[:1200], self.x[1200:1500], self.x[1500:]
        self.normalization()

        self.batch_size = batch_size
        self.n_classes = self.x.shape[0]
        self.classes_per_set = classes_per_set
        self.samples_per_class = samples_per_class

        self.indexes = {"train": 0, "val": 0, "test": 0}
        self.datasets = {"train": self.x_train, "val": self.x_val, "test": self.x_test}  # original data cached
        self.datasets_cache = {"train": self.load_data_cache(self.datasets["train"]),  # current epoch data cached
                               "val": self.load_data_cache(self.datasets["val"]),
                               "test": self.load_data_cache(self.datasets["test"])}

    def normalization(self):
        """
        Normalizes our data to have a mean of 0 and a std of 1.
        """
        self.mean = np.mean(self.x_train)
        self.std = np.std(self.x_train)
        self.max = np.max(self.x_train)
        self.min = np.min(self.x_train)
        print("train_shape", self.x_train.shape, "test_shape", self.x_test.shape, "val_shape", self.x_val.shape)
        print("before_normalization", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)
        self.x_train = (self.x_train - self.mean) / self.std
        self.x_val = (self.x_val - self.mean) / self.std
        self.x_test = (self.x_test - self.mean) / self.std
        self.mean = np.mean(self.x_train)
        self.std = np.std(self.x_train)
        self.max = np.max(self.x_train)
        self.min = np.min(self.x_train)
        print("after_normalization", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)

    def load_data_cache(self, data_pack):
        """
        Collects 1000 batches of data for N-shot learning.
        :param data_pack: Data pack to use (any one of train, val, test)
        :return: A list with [support_set_x, support_set_y, target_x, target_y] ready to be fed to our networks
        """
        n_samples = self.samples_per_class * self.classes_per_set
        data_cache = []
        for sample in range(1000):
            support_set_x = np.zeros((self.batch_size, n_samples, 28, 28, 1))
            support_set_y = np.zeros((self.batch_size, n_samples))
            # Target images hold normalized floats, so they must not use an integer dtype
            # (the original dtype=np.int silently truncated the normalized pixel values).
            target_x = np.zeros((self.batch_size, 28, 28, 1))
            target_y = np.zeros((self.batch_size,), dtype=np.int)
            for i in range(self.batch_size):
                ind = 0
                pinds = np.random.permutation(n_samples)
                classes = np.random.choice(data_pack.shape[0], self.classes_per_set, False)
                x_hat_class = np.random.randint(self.classes_per_set)
                for j, cur_class in enumerate(classes):  # each class
                    example_inds = np.random.choice(data_pack.shape[1], self.samples_per_class, False)
                    for eind in example_inds:
                        support_set_x[i, pinds[ind], :, :, :] = data_pack[cur_class][eind]
                        support_set_y[i, pinds[ind]] = j
                        ind += 1
                    if j == x_hat_class:
                        target_x[i, :, :, :] = data_pack[cur_class][np.random.choice(data_pack.shape[1])]
                        target_y[i] = j

            data_cache.append([support_set_x, support_set_y, target_x, target_y])
        return data_cache

    def get_batch(self, dataset_name):
        """
        Gets the next batch from the named dataset, refilling the cache when exhausted.
        :param dataset_name: The name of the dataset (one of "train", "val", "test")
        :return: The next [support_set_x, support_set_y, target_x, target_y] batch
        """
        if self.indexes[dataset_name] >= len(self.datasets_cache[dataset_name]):
            self.indexes[dataset_name] = 0
            self.datasets_cache[dataset_name] = self.load_data_cache(self.datasets[dataset_name])
        next_batch = self.datasets_cache[dataset_name][self.indexes[dataset_name]]
        self.indexes[dataset_name] += 1
        x_support_set, y_support_set, x_target, y_target = next_batch
        return x_support_set, y_support_set, x_target, y_target

    def get_train_batch(self):
        """
        Get next training batch
        :return: Next training batch
        """
        return self.get_batch("train")

    def get_test_batch(self):
        """
        Get next test batch
        :return: Next test batch
        """
        return self.get_batch("test")

    def get_val_batch(self):
        """
        Get next val batch
        :return: Next val batch
        """
        return self.get_batch("val")
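A usage sketch (not part of this commit), showing the shapes allocated by load_data_cache for one 5-way, 1-shot episode batch; note the constructor also calls Options().parse(), so it assumes the script was launched with a valid --dataroot argument.

from datasets.omniglotNShot import OmniglotNShotDataset

# Draw one N-way, k-shot batch and check its shapes (assumes --dataroot is set).
data = OmniglotNShotDataset(batch_size=32, classes_per_set=5, samples_per_class=1)
x_support, y_support, x_target, y_target = data.get_train_batch()
print(x_support.shape)  # (32, 5, 28, 28, 1): batch x (classes*samples) x H x W x C
print(y_support.shape)  # (32, 5): class index (0..4) of each support example
print(x_target.shape)   # (32, 28, 28, 1): one query image per episode
print(y_target.shape)   # (32,): class index of each query within its support set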

datasets/omniglotNShot.pyc (5.71 KB, binary file not shown)

main.py (+46 lines)
@@ -0,0 +1,46 @@
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Albert Berenguel
## Computer Vision Center (CVC). Universitat Autonoma de Barcelona
## Email: aberenguel@cvc.uab.es
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree.
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

from datasets import omniglotNShot
from option import Options

# Dummy test
import torch
from torch.autograd import Variable
import torch.nn as nn

#args = Options().parse()
#a = omniglotNShot.OmniglotNShotDataset()

from models.BidirectionalLSTM import BidirectionalLSTM

'''
# Function to make dummy data.
def datagen(batch_size, seq_length, vector_dim):
    return torch.rand(seq_length, batch_size, vector_dim)

samples = 100000
batch_size = 32
sequence_len = 20
vector_dim = 64
layer_sizes = [100, 100, 100]

lstm = BidirectionalLSTM(layer_sizes=layer_sizes, batch_size=batch_size, vector_dim=vector_dim).cuda()

for sample in range(samples):
    input = Variable(datagen(batch_size, sequence_len, vector_dim).cuda(), requires_grad=True)
    hidden, output = lstm(input)

b = 0
'''
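models/BidirectionalLSTM.py is included in this commit's 29 files but not shown on this page. Going only by the commented dummy test above (a constructor taking layer_sizes, batch_size and vector_dim, and a forward pass over a [seq_len, batch, vector_dim] tensor returning a (hidden, output) pair), a compatible module could be sketched as below; the actual implementation in the repo may differ.

import torch
import torch.nn as nn

class BidirectionalLSTM(nn.Module):
    # Sketch of a module matching the call sites in the dummy test above;
    # treating layer_sizes[0] as the hidden size is an assumption.
    def __init__(self, layer_sizes, batch_size, vector_dim):
        super(BidirectionalLSTM, self).__init__()
        self.batch_size = batch_size
        self.lstm = nn.LSTM(input_size=vector_dim,
                            hidden_size=layer_sizes[0],
                            num_layers=len(layer_sizes),
                            bidirectional=True)

    def forward(self, inputs):
        # inputs: [seq_len, batch, vector_dim]; output concatenates both directions.
        output, (hn, cn) = self.lstm(inputs)
        return (hn, cn), output

With layer_sizes = [100, 100, 100] and vector_dim = 64, output would then have shape [seq_len, batch, 200] (100 units per direction).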
