|
12 | 12 | "kernelspec": {
|
13 | 13 | "name": "python3",
|
14 | 14 | "display_name": "Python 3"
|
15 |
| - } |
| 15 | + }, |
| 16 | + "accelerator": "GPU" |
16 | 17 | },
|
17 | 18 | "cells": [
|
18 | 19 | {
|
|
25 | 26 | "<a href=\"https://colab.research.google.com/github/gan3sh500/mixmatch-pytorch/blob/master/notebook.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
26 | 27 | ]
|
27 | 28 | },
|
| 29 | + { |
| 30 | + "cell_type": "markdown", |
| 31 | + "metadata": { |
| 32 | + "id": "hmES8DZG7pFc", |
| 33 | + "colab_type": "text" |
| 34 | + }, |
| 35 | + "source": [ |
| 36 | + "This notebook tries to implement the MixMatch technique from the [paper](https://arxiv.org/pdf/1905.02249.pdf) MixMatch: A Holistic Approach to Semi-Supervised Learning and recreate their results on CIFAR10 with WideResnet28. \n", |
| 37 | + "\n", |
| 38 | + "It depends on Pytorch, Numpy and imgaug. The WideResnet28 model code is taken from [meliketoy](https://github.com/meliketoy/wide-resnet.pytorch/blob/master/networks/wide_resnet.py)'s github repository. Hopefully I can train this on Colab. :)" |
| 39 | + ] |
| 40 | + }, |
28 | 41 | {
|
29 | 42 | "cell_type": "code",
|
30 | 43 | "metadata": {
|
|
35 | 48 | "source": [
|
36 | 49 | "import torch\n",
|
37 | 50 | "import numpy as np\n",
|
38 |
| - "import imgaug as ia\n", |
39 | 51 | "import imgaug.augmenters as iaa"
|
40 | 52 | ],
|
41 | 53 | "execution_count": 0,
|
42 | 54 | "outputs": []
|
43 | 55 | },
|
| 56 | + { |
| 57 | + "cell_type": "markdown", |
| 58 | + "metadata": { |
| 59 | + "id": "Z_V6d_r-8QUi", |
| 60 | + "colab_type": "text" |
| 61 | + }, |
| 62 | + "source": [ |
| 63 | + "Now that we have the basic imports out of the way lets get to it. \n", |
| 64 | + "First we shall define the function to get augmented version of a given batch of images. The below function returns the function to do that. " |
| 65 | + ] |
| 66 | + }, |
44 | 67 | {
|
45 | 68 | "cell_type": "code",
|
46 | 69 | "metadata": {
|
|
62 | 85 | "execution_count": 0,
|
63 | 86 | "outputs": []
|
64 | 87 | },
|
| 88 | + { |
| 89 | + "cell_type": "markdown", |
| 90 | + "metadata": { |
| 91 | + "id": "se8HRC8z8byR", |
| 92 | + "colab_type": "text" |
| 93 | + }, |
| 94 | + "source": [ |
| 95 | + "Next we define the sharpening function to sharpen the prediction from the averaged prediction of all the unlabeled augmented images. It does the same thing as applying a temperature within the softmax function but to the probabilities. " |
| 96 | + ] |
| 97 | + }, |
65 | 98 | {
|
66 | 99 | "cell_type": "code",
|
67 | 100 | "metadata": {
|
|
77 | 110 | "execution_count": 0,
|
78 | 111 | "outputs": []
|
79 | 112 | },
|
| 113 | + { |
| 114 | + "cell_type": "markdown", |
| 115 | + "metadata": { |
| 116 | + "id": "IhvvJUKN80lU", |
| 117 | + "colab_type": "text" |
| 118 | + }, |
| 119 | + "source": [ |
| 120 | + "A simple implementation of the [paper](https://arxiv.org/pdf/1710.09412.pdf) mixup: Beyond Empirical Risk Minimization used in this paper as well." |
| 121 | + ] |
| 122 | + }, |
80 | 123 | {
|
81 | 124 | "cell_type": "code",
|
82 | 125 | "metadata": {
|
|
94 | 137 | "execution_count": 0,
|
95 | 138 | "outputs": []
|
96 | 139 | },
|
| 140 | + { |
| 141 | + "cell_type": "markdown", |
| 142 | + "metadata": { |
| 143 | + "id": "HU0JHbCh90o5", |
| 144 | + "colab_type": "text" |
| 145 | + }, |
| 146 | + "source": [ |
| 147 | + "This covers Algorithm 1 from the paper. " |
| 148 | + ] |
| 149 | + }, |
97 | 150 | {
|
98 | 151 | "cell_type": "code",
|
99 | 152 | "metadata": {
|
|
118 | 171 | "execution_count": 0,
|
119 | 172 | "outputs": []
|
120 | 173 | },
|
| 174 | + { |
| 175 | + "cell_type": "markdown", |
| 176 | + "metadata": { |
| 177 | + "id": "dmSvUmiP94zT", |
| 178 | + "colab_type": "text" |
| 179 | + }, |
| 180 | + "source": [ |
| 181 | + "The combined loss for training from the paper." |
| 182 | + ] |
| 183 | + }, |
121 | 184 | {
|
122 | 185 | "cell_type": "code",
|
123 | 186 | "metadata": {
|
|
136 | 199 | " def forward(X, U, p, q):\n",
|
137 | 200 | " X_ = np.concatenate([X, U], axis=1)\n",
|
138 | 201 | " y_ = np.concatenate([p, q], axis=1)\n",
|
139 |
| - " return self.xent(preds[:len(p)], p) + self.mse(preds[len(p):], q)" |
| 202 | + " return self.xent(preds[:len(p)], p) + \\\n", |
| 203 | + " self.lambda_u * self.mse(preds[len(p):], q)" |
| 204 | + ], |
| 205 | + "execution_count": 0, |
| 206 | + "outputs": [] |
| 207 | + }, |
| 208 | + { |
| 209 | + "cell_type": "markdown", |
| 210 | + "metadata": { |
| 211 | + "id": "CCqJtpJ--Cik", |
| 212 | + "colab_type": "text" |
| 213 | + }, |
| 214 | + "source": [ |
| 215 | + "Now that we have the MixMatch stuff done, we have a few things to do. Namely, define the WideResnet28 model, write the data and training code and write testing code. \n", |
| 216 | + "Let's start with the model. The below is just a copy paste mostly from the wide-resnet.pytorch repo by meliketoy. " |
| 217 | + ] |
| 218 | + }, |
| 219 | + { |
| 220 | + "cell_type": "code", |
| 221 | + "metadata": { |
| 222 | + "id": "GIkBy3T15P7l", |
| 223 | + "colab_type": "code", |
| 224 | + "colab": {} |
| 225 | + }, |
| 226 | + "source": [ |
| 227 | + "def conv3x3(in_planes, out_planes, stride=1):\n", |
| 228 | + " return torch.nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n", |
| 229 | + " bias=True)" |
| 230 | + ], |
| 231 | + "execution_count": 0, |
| 232 | + "outputs": [] |
| 233 | + }, |
| 234 | + { |
| 235 | + "cell_type": "markdown", |
| 236 | + "metadata": { |
| 237 | + "id": "Fud8CmEtCaSN", |
| 238 | + "colab_type": "text" |
| 239 | + }, |
| 240 | + "source": [ |
| 241 | + "Will need the below init function later before training." |
| 242 | + ] |
| 243 | + }, |
| 244 | + { |
| 245 | + "cell_type": "code", |
| 246 | + "metadata": { |
| 247 | + "id": "FZBBH5EYCZhi", |
| 248 | + "colab_type": "code", |
| 249 | + "colab": {} |
| 250 | + }, |
| 251 | + "source": [ |
| 252 | + "\n", |
| 253 | + "def conv_init(m):\n", |
| 254 | + " classname = m.__class__.__name__\n", |
| 255 | + " if classname.find('Conv') != -1:\n", |
| 256 | + " torch.nn.init.xavier_uniform(m.weight, gain=np.sqrt(2))\n", |
| 257 | + " torch.nn.init.constant(m.bias, 0)\n", |
| 258 | + " elif classname.find('BatchNorm') != -1:\n", |
| 259 | + " torch.nn.init.constant(m.weight, 1)\n", |
| 260 | + " torch.nn.init.constant(m.bias, 0)" |
| 261 | + ], |
| 262 | + "execution_count": 0, |
| 263 | + "outputs": [] |
| 264 | + }, |
| 265 | + { |
| 266 | + "cell_type": "markdown", |
| 267 | + "metadata": { |
| 268 | + "id": "V_gOfar1CeUx", |
| 269 | + "colab_type": "text" |
| 270 | + }, |
| 271 | + "source": [ |
| 272 | + "The basic block for the WideResnet" |
| 273 | + ] |
| 274 | + }, |
| 275 | + { |
| 276 | + "cell_type": "code", |
| 277 | + "metadata": { |
| 278 | + "id": "QZ068XQR6LZP", |
| 279 | + "colab_type": "code", |
| 280 | + "colab": {} |
| 281 | + }, |
| 282 | + "source": [ |
| 283 | + "class WideBasic(torch.nn.Module):\n", |
| 284 | + " def __init__(self, in_planes, planes, dropout_rate, stride=1):\n", |
| 285 | + " super(WideBasic, self).__init__()\n", |
| 286 | + " self.bn1 = torch.nn.BatchNorm2d(in_planes)\n", |
| 287 | + " self.bn2 = torch.nn.BatchNorm2d(planes)\n", |
| 288 | + " self.conv1 = torch.nn.Conv2d(in_planes, planes, kernel_size=3,\n", |
| 289 | + " padding=1, bias=True)\n", |
| 290 | + " self.conv2 = torch.nn.Conv2d(planes, planes, kernel_size=3,\n", |
| 291 | + " padding=1, bias=True)\n", |
| 292 | + " self.dropout = torch.nn.Dropout(p=dropout_rate)\n", |
| 293 | + " self.shortcut = torch.nn.Sequential()\n", |
| 294 | + " if stride != 1 or in_planes != planes:\n", |
| 295 | + " self.shortcut = torch.nn.Sequential(\n", |
| 296 | + " torch.nn.Conv2d(in_planes, planes, kernel_size=1,\n", |
| 297 | + " stride=stride, bias=True)\n", |
| 298 | + " )\n", |
| 299 | + "\n", |
| 300 | + " def forward(self, x):\n", |
| 301 | + " out = self.dropout(self.conv1(torch.nn.functional.relu(self.bn1(x))))\n", |
| 302 | + " out = self.conv2(torch.nn.functional.relu(self.bn2(out)))\n", |
| 303 | + " return out + self.shortcut(x)" |
| 304 | + ], |
| 305 | + "execution_count": 0, |
| 306 | + "outputs": [] |
| 307 | + }, |
| 308 | + { |
| 309 | + "cell_type": "markdown", |
| 310 | + "metadata": { |
| 311 | + "id": "wdew7GNoChmh", |
| 312 | + "colab_type": "text" |
| 313 | + }, |
| 314 | + "source": [ |
| 315 | + "Aaand the full model with default params set for CIFAR10." |
| 316 | + ] |
| 317 | + }, |
| 318 | + { |
| 319 | + "cell_type": "code", |
| 320 | + "metadata": { |
| 321 | + "id": "YvE9l4W27jTx", |
| 322 | + "colab_type": "code", |
| 323 | + "colab": {} |
| 324 | + }, |
| 325 | + "source": [ |
| 326 | + "class WideResNet(torch.nn.Module):\n", |
| 327 | + " def __init__(self, depth=28, widen_factor=10,\n", |
| 328 | + " dropout_rate=0.3, num_classes=10):\n", |
| 329 | + " super(WideResNet, self).__init__()\n", |
| 330 | + " self.in_planes = 16\n", |
| 331 | + " n = (depth - 4) // 6\n", |
| 332 | + " k = widen_factor\n", |
| 333 | + " nStages = [16, 16*k, 32*k, 64*k]\n", |
| 334 | + " self.conv1 = conv3x3(3, nStages[0])\n", |
| 335 | + " self.layer1 = self.wide_layer(WideBasic, nStages[1], n, dropout_rate,\n", |
| 336 | + " stride=1)\n", |
| 337 | + " self.layer2 = self.wide_layer(WideBasic, nStages[2], n, dropout_rate,\n", |
| 338 | + " stride=2)\n", |
| 339 | + " self.layer3 = self.wide_layer(WideBasic, nStages[3], n, dropout_rate,\n", |
| 340 | + " stride=2)\n", |
| 341 | + " self.b1 = torch.nn.BatchNorm2d(nStages[3], momentum=0.9)\n", |
| 342 | + " self.linear = torch.nn.Linear(nStages[3], num_classes)\n", |
| 343 | + " \n", |
| 344 | + " def wide_layer(self, block, planes, num_blocks, dropout_rate, stride):\n", |
| 345 | + " strides = [stride] + [1] * (num_blocks - 1)\n", |
| 346 | + " layers = []\n", |
| 347 | + " for stride in strides:\n", |
| 348 | + " layers.append(block(self.in_planes, planes, dropout_rate, stride))\n", |
| 349 | + " self.in_planes = planes\n", |
| 350 | + " return torch.nn.Sequential(*layers)\n", |
| 351 | + " \n", |
| 352 | + " def forward(self, x):\n", |
| 353 | + " out = self.conv1(x)\n", |
| 354 | + " out = self.layer3(self.layer2(self.layer1(out)))\n", |
| 355 | + " out = torch.nn.functional.relu(self.bn1(out))\n", |
| 356 | + " out = torch.nn.functional.avg_pool2d(out, 8)\n", |
| 357 | + " out = out.view(out.size(0), -1)\n", |
| 358 | + " return self.linear(out)" |
| 359 | + ], |
| 360 | + "execution_count": 0, |
| 361 | + "outputs": [] |
| 362 | + }, |
| 363 | + { |
| 364 | + "cell_type": "code", |
| 365 | + "metadata": { |
| 366 | + "id": "EjCTPM8wB-dR", |
| 367 | + "colab_type": "code", |
| 368 | + "colab": {} |
| 369 | + }, |
| 370 | + "source": [ |
| 371 | + "" |
140 | 372 | ],
|
141 | 373 | "execution_count": 0,
|
142 | 374 | "outputs": []
|
|
0 commit comments