Skip to content

Commit 18e5d53

Browse files
committed
Created using Colaboratory
1 parent e268869 commit 18e5d53

File tree

1 file changed

+235
-3
lines changed

1 file changed

+235
-3
lines changed

notebook.ipynb

+235-3
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
"kernelspec": {
1313
"name": "python3",
1414
"display_name": "Python 3"
15-
}
15+
},
16+
"accelerator": "GPU"
1617
},
1718
"cells": [
1819
{
@@ -25,6 +26,18 @@
2526
"<a href=\"https://colab.research.google.com/github/gan3sh500/mixmatch-pytorch/blob/master/notebook.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
2627
]
2728
},
29+
{
30+
"cell_type": "markdown",
31+
"metadata": {
32+
"id": "hmES8DZG7pFc",
33+
"colab_type": "text"
34+
},
35+
"source": [
36+
"This notebook tries to implement the MixMatch technique from the [paper](https://arxiv.org/pdf/1905.02249.pdf) MixMatch: A Holistic Approach to Semi-Supervised Learning and recreate their results on CIFAR10 with WideResnet28. \n",
37+
"\n",
38+
"It depends on Pytorch, Numpy and imgaug. The WideResnet28 model code is taken from [meliketoy](https://github.com/meliketoy/wide-resnet.pytorch/blob/master/networks/wide_resnet.py)'s github repository. Hopefully I can train this on Colab. :)"
39+
]
40+
},
2841
{
2942
"cell_type": "code",
3043
"metadata": {
@@ -35,12 +48,22 @@
3548
"source": [
3649
"import torch\n",
3750
"import numpy as np\n",
38-
"import imgaug as ia\n",
3951
"import imgaug.augmenters as iaa"
4052
],
4153
"execution_count": 0,
4254
"outputs": []
4355
},
56+
{
57+
"cell_type": "markdown",
58+
"metadata": {
59+
"id": "Z_V6d_r-8QUi",
60+
"colab_type": "text"
61+
},
62+
"source": [
63+
"Now that we have the basic imports out of the way lets get to it. \n",
64+
"First we shall define the function to get augmented version of a given batch of images. The below function returns the function to do that. "
65+
]
66+
},
4467
{
4568
"cell_type": "code",
4669
"metadata": {
@@ -62,6 +85,16 @@
6285
"execution_count": 0,
6386
"outputs": []
6487
},
88+
{
89+
"cell_type": "markdown",
90+
"metadata": {
91+
"id": "se8HRC8z8byR",
92+
"colab_type": "text"
93+
},
94+
"source": [
95+
"Next we define the sharpening function to sharpen the prediction from the averaged prediction of all the unlabeled augmented images. It does the same thing as applying a temperature within the softmax function but to the probabilities. "
96+
]
97+
},
6598
{
6699
"cell_type": "code",
67100
"metadata": {
@@ -77,6 +110,16 @@
77110
"execution_count": 0,
78111
"outputs": []
79112
},
113+
{
114+
"cell_type": "markdown",
115+
"metadata": {
116+
"id": "IhvvJUKN80lU",
117+
"colab_type": "text"
118+
},
119+
"source": [
120+
"A simple implementation of the [paper](https://arxiv.org/pdf/1710.09412.pdf) mixup: Beyond Empirical Risk Minimization used in this paper as well."
121+
]
122+
},
80123
{
81124
"cell_type": "code",
82125
"metadata": {
@@ -94,6 +137,16 @@
94137
"execution_count": 0,
95138
"outputs": []
96139
},
140+
{
141+
"cell_type": "markdown",
142+
"metadata": {
143+
"id": "HU0JHbCh90o5",
144+
"colab_type": "text"
145+
},
146+
"source": [
147+
"This covers Algorithm 1 from the paper. "
148+
]
149+
},
97150
{
98151
"cell_type": "code",
99152
"metadata": {
@@ -118,6 +171,16 @@
118171
"execution_count": 0,
119172
"outputs": []
120173
},
174+
{
175+
"cell_type": "markdown",
176+
"metadata": {
177+
"id": "dmSvUmiP94zT",
178+
"colab_type": "text"
179+
},
180+
"source": [
181+
"The combined loss for training from the paper."
182+
]
183+
},
121184
{
122185
"cell_type": "code",
123186
"metadata": {
@@ -136,7 +199,176 @@
136199
" def forward(X, U, p, q):\n",
137200
" X_ = np.concatenate([X, U], axis=1)\n",
138201
" y_ = np.concatenate([p, q], axis=1)\n",
139-
" return self.xent(preds[:len(p)], p) + self.mse(preds[len(p):], q)"
202+
" return self.xent(preds[:len(p)], p) + \\\n",
203+
" self.lambda_u * self.mse(preds[len(p):], q)"
204+
],
205+
"execution_count": 0,
206+
"outputs": []
207+
},
208+
{
209+
"cell_type": "markdown",
210+
"metadata": {
211+
"id": "CCqJtpJ--Cik",
212+
"colab_type": "text"
213+
},
214+
"source": [
215+
"Now that we have the MixMatch stuff done, we have a few things to do. Namely, define the WideResnet28 model, write the data and training code and write testing code. \n",
216+
"Let's start with the model. The below is just a copy paste mostly from the wide-resnet.pytorch repo by meliketoy. "
217+
]
218+
},
219+
{
220+
"cell_type": "code",
221+
"metadata": {
222+
"id": "GIkBy3T15P7l",
223+
"colab_type": "code",
224+
"colab": {}
225+
},
226+
"source": [
227+
"def conv3x3(in_planes, out_planes, stride=1):\n",
228+
" return torch.nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n",
229+
" bias=True)"
230+
],
231+
"execution_count": 0,
232+
"outputs": []
233+
},
234+
{
235+
"cell_type": "markdown",
236+
"metadata": {
237+
"id": "Fud8CmEtCaSN",
238+
"colab_type": "text"
239+
},
240+
"source": [
241+
"Will need the below init function later before training."
242+
]
243+
},
244+
{
245+
"cell_type": "code",
246+
"metadata": {
247+
"id": "FZBBH5EYCZhi",
248+
"colab_type": "code",
249+
"colab": {}
250+
},
251+
"source": [
252+
"\n",
253+
"def conv_init(m):\n",
254+
" classname = m.__class__.__name__\n",
255+
" if classname.find('Conv') != -1:\n",
256+
" torch.nn.init.xavier_uniform(m.weight, gain=np.sqrt(2))\n",
257+
" torch.nn.init.constant(m.bias, 0)\n",
258+
" elif classname.find('BatchNorm') != -1:\n",
259+
" torch.nn.init.constant(m.weight, 1)\n",
260+
" torch.nn.init.constant(m.bias, 0)"
261+
],
262+
"execution_count": 0,
263+
"outputs": []
264+
},
265+
{
266+
"cell_type": "markdown",
267+
"metadata": {
268+
"id": "V_gOfar1CeUx",
269+
"colab_type": "text"
270+
},
271+
"source": [
272+
"The basic block for the WideResnet"
273+
]
274+
},
275+
{
276+
"cell_type": "code",
277+
"metadata": {
278+
"id": "QZ068XQR6LZP",
279+
"colab_type": "code",
280+
"colab": {}
281+
},
282+
"source": [
283+
"class WideBasic(torch.nn.Module):\n",
284+
" def __init__(self, in_planes, planes, dropout_rate, stride=1):\n",
285+
" super(WideBasic, self).__init__()\n",
286+
" self.bn1 = torch.nn.BatchNorm2d(in_planes)\n",
287+
" self.bn2 = torch.nn.BatchNorm2d(planes)\n",
288+
" self.conv1 = torch.nn.Conv2d(in_planes, planes, kernel_size=3,\n",
289+
" padding=1, bias=True)\n",
290+
" self.conv2 = torch.nn.Conv2d(planes, planes, kernel_size=3,\n",
291+
" padding=1, bias=True)\n",
292+
" self.dropout = torch.nn.Dropout(p=dropout_rate)\n",
293+
" self.shortcut = torch.nn.Sequential()\n",
294+
" if stride != 1 or in_planes != planes:\n",
295+
" self.shortcut = torch.nn.Sequential(\n",
296+
" torch.nn.Conv2d(in_planes, planes, kernel_size=1,\n",
297+
" stride=stride, bias=True)\n",
298+
" )\n",
299+
"\n",
300+
" def forward(self, x):\n",
301+
" out = self.dropout(self.conv1(torch.nn.functional.relu(self.bn1(x))))\n",
302+
" out = self.conv2(torch.nn.functional.relu(self.bn2(out)))\n",
303+
" return out + self.shortcut(x)"
304+
],
305+
"execution_count": 0,
306+
"outputs": []
307+
},
308+
{
309+
"cell_type": "markdown",
310+
"metadata": {
311+
"id": "wdew7GNoChmh",
312+
"colab_type": "text"
313+
},
314+
"source": [
315+
"Aaand the full model with default params set for CIFAR10."
316+
]
317+
},
318+
{
319+
"cell_type": "code",
320+
"metadata": {
321+
"id": "YvE9l4W27jTx",
322+
"colab_type": "code",
323+
"colab": {}
324+
},
325+
"source": [
326+
"class WideResNet(torch.nn.Module):\n",
327+
" def __init__(self, depth=28, widen_factor=10,\n",
328+
" dropout_rate=0.3, num_classes=10):\n",
329+
" super(WideResNet, self).__init__()\n",
330+
" self.in_planes = 16\n",
331+
" n = (depth - 4) // 6\n",
332+
" k = widen_factor\n",
333+
" nStages = [16, 16*k, 32*k, 64*k]\n",
334+
" self.conv1 = conv3x3(3, nStages[0])\n",
335+
" self.layer1 = self.wide_layer(WideBasic, nStages[1], n, dropout_rate,\n",
336+
" stride=1)\n",
337+
" self.layer2 = self.wide_layer(WideBasic, nStages[2], n, dropout_rate,\n",
338+
" stride=2)\n",
339+
" self.layer3 = self.wide_layer(WideBasic, nStages[3], n, dropout_rate,\n",
340+
" stride=2)\n",
341+
" self.b1 = torch.nn.BatchNorm2d(nStages[3], momentum=0.9)\n",
342+
" self.linear = torch.nn.Linear(nStages[3], num_classes)\n",
343+
" \n",
344+
" def wide_layer(self, block, planes, num_blocks, dropout_rate, stride):\n",
345+
" strides = [stride] + [1] * (num_blocks - 1)\n",
346+
" layers = []\n",
347+
" for stride in strides:\n",
348+
" layers.append(block(self.in_planes, planes, dropout_rate, stride))\n",
349+
" self.in_planes = planes\n",
350+
" return torch.nn.Sequential(*layers)\n",
351+
" \n",
352+
" def forward(self, x):\n",
353+
" out = self.conv1(x)\n",
354+
" out = self.layer3(self.layer2(self.layer1(out)))\n",
355+
" out = torch.nn.functional.relu(self.bn1(out))\n",
356+
" out = torch.nn.functional.avg_pool2d(out, 8)\n",
357+
" out = out.view(out.size(0), -1)\n",
358+
" return self.linear(out)"
359+
],
360+
"execution_count": 0,
361+
"outputs": []
362+
},
363+
{
364+
"cell_type": "code",
365+
"metadata": {
366+
"id": "EjCTPM8wB-dR",
367+
"colab_type": "code",
368+
"colab": {}
369+
},
370+
"source": [
371+
""
140372
],
141373
"execution_count": 0,
142374
"outputs": []

0 commit comments

Comments
 (0)