Commit e30802d

CIFAR-10_experiments.py was created. This file is just a Python version of the current notebook; this way it is easier to make commits on git.

1 parent 27a056e commit e30802d

1 file changed: +221 -0 lines changed

CIFAR-10_experiments.py

# -*- coding: utf-8 -*-
"""notebook.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/13d6r8OZIrU9a_UycsEczdpfF7Exg42G9

<a href="https://colab.research.google.com/github/gan3sh500/mixmatch-pytorch/blob/master/notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook tries to implement the MixMatch technique from the [paper](https://arxiv.org/pdf/1905.02249.pdf) MixMatch: A Holistic Approach to Semi-Supervised Learning and recreate their results on CIFAR10 with WideResnet28.

It depends on PyTorch, NumPy and imgaug. The WideResnet28 model code is taken from [meliketoy](https://github.com/meliketoy/wide-resnet.pytorch/blob/master/networks/wide_resnet.py)'s github repository. Hopefully I can train this on Colab with a Tesla T4. :)
"""

# !nvidia-smi  # shell command from the notebook; run it separately to check the GPU

import torch
import numpy as np
import imgaug.augmenters as iaa

"""Now that we have the basic imports out of the way, let's get to it.

First we shall define the function to get augmented versions of a given batch of images. The function below returns a function that does that.
"""

def get_augmenter():
    seq = iaa.Sequential([
        iaa.Crop(px=(0, 16)),
        iaa.Fliplr(0.5),
        iaa.GaussianBlur(sigma=(0, 3.0))
    ])
    def augment(images):
        # imgaug expects NHWC arrays, so move channels last and back to NCHW afterwards
        return seq.augment_images(images.transpose(0, 2, 3, 1)).transpose(0, 3, 1, 2)
    return augment
"""Next we define the sharpening function to sharpen the prediction from the averaged prediction of all the unlabeled augmented images. It does the same thing as applying a temperature within the softmax function but to the probabilities."""
37+
38+
def sharpen(x, T):
39+
temp = x**(1/T)
40+
return temp / temp.sum(axis=1, keepdims=True)
41+
42+
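
"""As a quick sanity check (not part of the original notebook): sharpening a guess of [0.6, 0.4] with T=0.5 pushes it towards [0.69, 0.31], and lower temperatures push it further towards a one-hot vector."""

# p = np.array([[0.6, 0.4]])
# sharpen(p, 0.5)   # ~[[0.692, 0.308]]
# sharpen(p, 0.25)  # ~[[0.835, 0.165]]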
"""A simple implementation of the [paper](https://arxiv.org/pdf/1710.09412.pdf) mixup: Beyond Empirical Risk Minimization used in this paper as well."""
43+
44+
def mixup(x1, x2, y1, y2, alpha):
45+
beta = np.random.beta(alpha, -alpha)
46+
x = beta * x1 + (1 - beta) * x2
47+
y = beta * y1 + (1 - beta) * y2
48+
return x, y
49+
50+
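
"""For example (not in the original notebook), mixing the one-hot labels [1, 0] and [0, 1] with beta = 0.75 gives the soft label [0.75, 0.25]; the corresponding images are blended with the same coefficient."""

# y1, y2 = np.array([1., 0.]), np.array([0., 1.])
# 0.75 * y1 + (1 - 0.75) * y2  # -> array([0.75, 0.25])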
"""This covers Algorithm 1 from the paper."""
51+
52+
def mixmatch(x, y, u, model, augment_fn, T=0.5, K=2, alpha=0.75):
53+
xb = augment_fn(x)
54+
ub = [augment_fn(u) for _ in range(K)]
55+
qb = sharpen(sum(map(lambda i: model(i), ub)) / K, T)
56+
Ux = np.concatenate(ub, axis=0)
57+
Uy = np.concatenate([qb for _ in range(K)], axis=0)
58+
indices = np.random.shuffle(np.arange(len(xb) + len(Ux)))
59+
Wx = np.concatenate([Ux, xb], axis=0)[indices]
60+
Wy = np.concatenate([qb, y], axis=0)[indices]
61+
X, p = mixup(xb, Wx[:len(xb)], y, Wy[:len(xb)], alpha)
62+
U, q = mixup(Ux, Wx[len(xb):], Uy, Wy[len(xb):], alpha)
63+
return X, U, p, q
64+
65+
"""The combined loss for training from the paper."""
66+
67+
class MixMatchLoss(torch.nn.Module):
68+
def __init__(self, lambda_u=100):
69+
self.lambda_u = lambda_u
70+
self.xent = torch.nn.CrossEntropyLoss()
71+
self.mse = torch.nn.MSELoss()
72+
super(MixMatchLoss, self).__init__()
73+
74+
def forward(self, X, U, p, q, model):
75+
X_ = np.concatenate([X, U], axis=1)
76+
preds = model(X_)
77+
return self.xent(preds[:len(p)], p) + \
78+
self.lambda_u * self.mse(preds[len(p):], q)
79+
80+
"""Now that we have the MixMatch stuff done, we have a few things to do. Namely, define the WideResnet28 model, write the data and training code and write testing code.
81+
Let's start with the model. The below is just a copy paste mostly from the wide-resnet.pytorch repo by meliketoy.
82+
"""
83+
84+
def conv3x3(in_planes, out_planes, stride=1):
85+
return torch.nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
86+
bias=True)
87+
88+
"""Will need the below init function later before training."""
89+
90+
def conv_init(m):
91+
classname = m.__class__.__name__
92+
if classname.find('Conv') != -1:
93+
torch.nn.init.xavier_uniform(m.weight, gain=np.sqrt(2))
94+
torch.nn.init.constant(m.bias, 0)
95+
elif classname.find('BatchNorm') != -1:
96+
torch.nn.init.constant(m.weight, 1)
97+
torch.nn.init.constant(m.bias, 0)
98+
99+
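
"""The usual pattern (a sketch, not from the original notebook) is to apply it recursively to the model before training."""

# model = WideResNet()
# model.apply(conv_init)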
"""The basic block for the WideResnet"""
100+
101+
class WideBasic(torch.nn.Module):
102+
def __init__(self, in_planes, planes, dropout_rate, stride=1):
103+
super(WideBasic, self).__init__()
104+
self.bn1 = torch.nn.BatchNorm2d(in_planes)
105+
self.bn2 = torch.nn.BatchNorm2d(planes)
106+
self.conv1 = torch.nn.Conv2d(in_planes, planes, kernel_size=3,
107+
padding=1, bias=True)
108+
self.conv2 = torch.nn.Conv2d(planes, planes, kernel_size=3,
109+
padding=1, bias=True)
110+
self.dropout = torch.nn.Dropout(p=dropout_rate)
111+
self.shortcut = torch.nn.Sequential()
112+
if stride != 1 or in_planes != planes:
113+
self.shortcut = torch.nn.Sequential(
114+
torch.nn.Conv2d(in_planes, planes, kernel_size=1,
115+
stride=stride, bias=True)
116+
)
117+
118+
def forward(self, x):
119+
out = self.dropout(self.conv1(torch.nn.functional.relu(self.bn1(x))))
120+
out = self.conv2(torch.nn.functional.relu(self.bn2(out)))
121+
return out + self.shortcut(x)
122+
123+
"""Aaand the full model with default params set for CIFAR10."""
124+
125+
class WideResNet(torch.nn.Module):
126+
def __init__(self, depth=28, widen_factor=10,
127+
dropout_rate=0.3, num_classes=10):
128+
super(WideResNet, self).__init__()
129+
self.in_planes = 16
130+
n = (depth - 4) // 6
131+
k = widen_factor
132+
nStages = [16, 16*k, 32*k, 64*k]
133+
self.conv1 = conv3x3(3, nStages[0])
134+
self.layer1 = self.wide_layer(WideBasic, nStages[1], n, dropout_rate,
135+
stride=1)
136+
self.layer2 = self.wide_layer(WideBasic, nStages[2], n, dropout_rate,
137+
stride=2)
138+
self.layer3 = self.wide_layer(WideBasic, nStages[3], n, dropout_rate,
139+
stride=2)
140+
self.b1 = torch.nn.BatchNorm2d(nStages[3], momentum=0.9)
141+
self.linear = torch.nn.Linear(nStages[3], num_classes)
142+
143+
def wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
144+
strides = [stride] + [1] * (num_blocks - 1)
145+
layers = []
146+
for stride in strides:
147+
layers.append(block(self.in_planes, planes, dropout_rate, stride))
148+
self.in_planes = planes
149+
return torch.nn.Sequential(*layers)
150+
151+
def forward(self, x):
152+
out = self.conv1(x)
153+
out = self.layer3(self.layer2(self.layer1(out)))
154+
out = torch.nn.functional.relu(self.bn1(out))
155+
out = torch.nn.functional.avg_pool2d(out, 8)
156+
out = out.view(out.size(0), -1)
157+
return self.linear(out)
158+
159+
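
"""A quick shape check (not in the original notebook): with the defaults above, a WideResNet-28-10 maps a batch of CIFAR-10 images to one logit per class."""

# model = WideResNet(depth=28, widen_factor=10)
# model(torch.randn(4, 3, 32, 32)).shape  # -> torch.Size([4, 10])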
"""Now that we have the model let's write train and test loaders so that we can pass the model and the data to the MixMatchLoss."""
160+
161+
def basic_generator(x, y=None, batch_size=32, shuffle=True):
162+
i = 0
163+
all_indices = np.random.shuffle(np.arange(len(x))) if shuffle else \
164+
np.arange(len(x))
165+
while(True):
166+
indices = all_indices[i:i+batch_size]
167+
if y is not None:
168+
yield x[indices], y[indices]
169+
yield x[indices]
170+
i = (i + batch_size) % len(x)
171+
172+

def mixmatch_wrapper(x, y, u, model, batch_size=32):
    augment_fn = get_augmenter()
    train_generator = basic_generator(x, y, batch_size)
    unlabeled_generator = basic_generator(u, batch_size=batch_size)
    while True:
        xi, yi = next(train_generator)
        ui = next(unlabeled_generator)
        yield mixmatch(xi, yi, ui, model, augment_fn)

def to_torch(*args, device='cuda'):
    # cast to float32 since the numpy pipeline can hand back float64 arrays
    convert_fn = lambda x: torch.from_numpy(x).float().to(device)
    return list(map(convert_fn, args))
"""That about covers all the code we need for train and test loaders. Now we can start the training and evaluation. Let's see if all of this works or is just a mess. Going to add basically this same training code from meliketoy's repo but with the MixMatchLoss."""
186+
187+
def test(model, test_gen, test_iters):
188+
acc = []
189+
for i, (x, y) in enumerate(test_gen):
190+
x = to_torch(x)
191+
pred = model(x).to('cpu').argmax(axis=1)
192+
acc.append(np.mean(pred == y.argmax(axis=1)))
193+
if i == test_iters:
194+
break
195+
print('Accuracy was : {}'.format(np.mean(acc)))
196+
197+

def report(loss_history):
    print('Average loss in last epoch was : {}'.format(np.mean(loss_history)))
    return []

def save(model, iter, train_iters):
    # name the checkpoint after the number of completed epochs (train_iters steps each)
    torch.save(model.state_dict(), 'model_{}.pth'.format(iter // train_iters))

def run(model, train_gen, test_gen, epochs, train_iters, test_iters, device,
        lr=0.002):
    # Adam with lr=0.002, as in the MixMatch paper
    optim = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = MixMatchLoss()
    loss_history = []
    for i, (x, u, p, q) in enumerate(train_gen):
        if i % train_iters == 0:
            loss_history = report(loss_history)
            test(model, test_gen, test_iters)
            save(model, i, train_iters)
            if i // train_iters == epochs:
                return
        else:
            optim.zero_grad()
            x, u, p, q = to_torch(x, u, p, q, device=device)
            loss = loss_fn(x, u, p, q, model)
            loss.backward()
            optim.step()
            loss_history.append(loss.item())

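"""A rough end-to-end usage sketch (not part of the original notebook, and untested). It assumes torchvision is available for loading CIFAR-10, keeps 4000 labeled examples as one of the splits evaluated in the paper, and wires everything through the helpers above; names like num_labeled are illustrative only."""

# import torchvision
#
# train_set = torchvision.datasets.CIFAR10('.', train=True, download=True)
# x_all = train_set.data.transpose(0, 3, 1, 2).astype(np.float32) / 255.  # NCHW in [0, 1]
# y_all = np.eye(10, dtype=np.float32)[np.array(train_set.targets)]       # one-hot labels
# num_labeled = 4000
# x, y, u = x_all[:num_labeled], y_all[:num_labeled], x_all[num_labeled:]
#
# test_set = torchvision.datasets.CIFAR10('.', train=False, download=True)
# x_test = test_set.data.transpose(0, 3, 1, 2).astype(np.float32) / 255.
# y_test = np.eye(10, dtype=np.float32)[np.array(test_set.targets)]
#
# model = WideResNet().to('cuda')
# model.apply(conv_init)
# train_gen = mixmatch_wrapper(x, y, u, model, batch_size=32)
# test_gen = basic_generator(x_test, y_test, batch_size=32)
# run(model, train_gen, test_gen, epochs=100, train_iters=1000,
#     test_iters=10, device='cuda')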