"""
<a href="https://colab.research.google.com/github/gan3sh500/mixmatch-pytorch/blob/master/notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
This notebook tries to implement the MixMatch technique from the [paper](https://arxiv.org/pdf/1905.02249.pdf) MixMatch: A Holistic Approach to Semi-Supervised Learning and recreate their results on CIFAR10 with WideResnet28.
It depends on PyTorch, NumPy, and imgaug. The WideResnet28 model code is taken from [meliketoy](https://github.com/meliketoy/wide-resnet.pytorch/blob/master/networks/wide_resnet.py)'s GitHub repository. Hopefully I can train this on Colab with a Tesla T4. :)
"""
!nvidia-smi

import torch
import numpy as np
import imgaug.augmenters as iaa

"""Now that we have the basic imports out of the way lets get to it.
23
+
First we shall define the function to get augmented version of a given batch of images. The below function returns the function to do that.
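
"""A minimal sketch of such a factory, assuming a simple imgaug pipeline of random flips and small translations (the specific ops here are an assumption, not the notebook's original choice):"""

def get_augmenter():
    # Hypothetical augmentations: horizontal flips plus small pixel shifts.
    seq = iaa.Sequential([
        iaa.Fliplr(0.5),
        iaa.Affine(translate_px={"x": (-4, 4), "y": (-4, 4)}),
    ])
    def augment(x):
        # imgaug operates on batches of HWC uint8 numpy arrays.
        return seq.augment_images(x)
    return augment
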
"""Next we define the sharpening function to sharpen the prediction from the averaged prediction of all the unlabeled augmented images. It does the same thing as applying a temperature within the softmax function but to the probabilities."""

def sharpen(x, T):
    temp = x ** (1 / T)
    return temp / temp.sum(axis=1, keepdims=True)

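"""For instance, with T=0.5 sharpening squares each probability and renormalises, pushing mass toward the largest entry (a quick sanity check, not from the original notebook):"""

p = np.array([[0.4, 0.3, 0.3]])
print(sharpen(p, T=0.5))  # -> [[0.4706 0.2647 0.2647]]
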
"""A simple implementation of the [paper](https://arxiv.org/pdf/1710.09412.pdf) mixup: Beyond Empirical Risk Minimization used in this paper as well."""

def mixup(x1, x2, y1, y2, alpha):
    beta = np.random.beta(alpha, alpha)
    # Per the MixMatch paper, take max(beta, 1 - beta) so the mix stays
    # closer to the first pair of arguments.
    beta = max(beta, 1 - beta)
    x = beta * x1 + (1 - beta) * x2
    y = beta * y1 + (1 - beta) * y2
    return x, y

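"""A toy illustration with made-up values: mixing all-ones with all-zeros fills x with beta, and since beta >= 0.5 the result stays closer to the first pair of arguments:"""

x, y = mixup(np.ones((2, 4)), np.zeros((2, 4)),
             np.array([[1., 0.], [1., 0.]]), np.array([[0., 1.], [0., 1.]]),
             alpha=0.75)
print(x[0], y[0])  # e.g. [0.87 0.87 0.87 0.87] [0.87 0.13]
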
"""This covers Algorithm 1 from the paper."""

def mixmatch(x, y, u, model, augment_fn, T=0.5, K=2, alpha=0.75):
    # Augment the labeled batch once and the unlabeled batch K times.
    xb = augment_fn(x)
    ub = [augment_fn(u) for _ in range(K)]
    # Guess labels for the unlabeled data: average the model's predictions
    # over the K augmentations, then sharpen (assuming the model can consume
    # the augmented batches, as elsewhere in this notebook).
    qb = sharpen(sum(model(u_i) for u_i in ub) / K, T)
    Ux = np.concatenate(ub, axis=0)
    Uy = np.concatenate([qb] * K, axis=0)
    # Shuffle labeled and unlabeled examples together to get the mixup partners W.
    indices = np.random.permutation(len(xb) + len(Ux))
    Wx = np.concatenate([xb, Ux], axis=0)[indices]
    Wy = np.concatenate([y, Uy], axis=0)[indices]
    X, p = mixup(xb, Wx[:len(xb)], y, Wy[:len(xb)], alpha)
    U, q = mixup(Ux, Wx[len(xb):], Uy, Wy[len(xb):], alpha)
    return X, U, p, q

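"""Given a labeled batch (x, y) with one-hot labels, an unlabeled batch u, and an augmentation function, a MixMatch step would be invoked like this (a hypothetical call; the actual training loop comes later):"""

# X, p: mixed labeled batch and targets; U, q: mixed unlabeled batch
# (K copies) and its sharpened targets.
# X, U, p, q = mixmatch(x, y, u, model, augment_fn, T=0.5, K=2, alpha=0.75)
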
"""The combined loss for training from the paper."""

class MixMatchLoss(torch.nn.Module):
    def __init__(self, lambda_u=100):
        # super().__init__() must run before any submodules are assigned.
        super(MixMatchLoss, self).__init__()
        self.lambda_u = lambda_u
        self.xent = torch.nn.CrossEntropyLoss()
        self.mse = torch.nn.MSELoss()

    def forward(self, X, U, p, q, model):
        # Concatenate the labeled and unlabeled batches along the batch
        # dimension so a single forward pass covers both.
        X_ = torch.cat([X, U], dim=0)
        preds = model(X_)
        return self.xent(preds[:len(p)], p) + \
               self.lambda_u * self.mse(preds[len(p):], q)

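"""A hypothetical call, assuming X, U, p, q come straight from mixmatch as tensors on the same device as the model:"""

# loss_fn = MixMatchLoss(lambda_u=100)
# loss = loss_fn(X, U, p, q, model)
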
"""Now that we have the MixMatch stuff done, we have a few things to do. Namely, define the WideResnet28 model, write the data and training code and write testing code.
81
+
Let's start with the model. The below is just a copy paste mostly from the wide-resnet.pytorch repo by meliketoy.
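
"""The full model definition lives in the linked file; as a condensed sketch, the core building block looks roughly like this, assuming the standard pre-activation wide-resnet layout (details abridged):"""

import torch.nn.functional as F

class WideBasic(torch.nn.Module):
    # Pre-activation residual block: BN -> ReLU -> conv, twice, with dropout
    # in between and a 1x1 conv shortcut when shapes change.
    def __init__(self, in_planes, planes, dropout_rate, stride=1):
        super(WideBasic, self).__init__()
        self.bn1 = torch.nn.BatchNorm2d(in_planes)
        self.conv1 = torch.nn.Conv2d(in_planes, planes, kernel_size=3,
                                     padding=1, bias=True)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.bn2 = torch.nn.BatchNorm2d(planes)
        self.conv2 = torch.nn.Conv2d(planes, planes, kernel_size=3,
                                     stride=stride, padding=1, bias=True)
        self.shortcut = torch.nn.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = torch.nn.Sequential(
                torch.nn.Conv2d(in_planes, planes, kernel_size=1,
                                stride=stride, bias=True))

    def forward(self, x):
        out = self.dropout(self.conv1(F.relu(self.bn1(x))))
        out = self.conv2(F.relu(self.bn2(out)))
        return out + self.shortcut(x)
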
"""That about covers all the code we need for train and test loaders. Now we can start the training and evaluation. Let's see if all of this works or is just a mess. Going to add basically this same training code from meliketoy's repo but with the MixMatchLoss."""

def test(model, test_gen, test_iters):
    acc = []
    for i, (x, y) in enumerate(test_gen):
        x = to_torch(x)
        pred = model(x).to('cpu').argmax(axis=1)
        acc.append(np.mean(pred == y.argmax(axis=1)))
        if i == test_iters:
            break
    print('Accuracy was : {}'.format(np.mean(acc)))


def report(loss_history):
    print('Average loss in last epoch was : {}'.format(np.mean(loss_history)))
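
"""A sketch of a single training step wiring the pieces together (the optimizer and batch variables are assumptions; the notebook's exact loop is not shown here):"""

def train_step(model, loss_fn, optimizer, augment_fn, x, y, u):
    # One MixMatch iteration: build mixed batches, compute the combined
    # loss, and take an optimizer step.
    X, U, p, q = mixmatch(x, y, u, model, augment_fn)
    optimizer.zero_grad()
    loss = loss_fn(X, U, p, q, model)
    loss.backward()
    optimizer.step()
    return loss.item()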