Skip to content

Commit 0b7833d

Browse files
authored
Merge pull request gan3sh500#4 from fperdigon/master
Fixing stupid mistakes.
2 parents 83fdf39 + 4f5402b commit 0b7833d

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

notebook.ipynb

+3-3
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@
199199
"def mixmatch(x, y, u, model, augment_fn, T=0.5, K=2, alpha=0.75):\n",
200200
" xb = augment_fn(x)\n",
201201
" ub = [augment_fn(u) for _ in range(K)]\n",
202-
" qb = sharpen(sum(map(lambda i: model(i), ub)) / K)\n",
202+
" qb = sharpen(sum(map(lambda i: model(i), ub)) / K, T)\n",
203203
" Ux = np.concatenate(ub, axis=0)\n",
204204
" Uy = np.concatenate([qb for _ in range(K)], axis=0)\n",
205205
" indices = np.random.shuffle(np.arange(len(xb) + len(Ux)))\n",
@@ -237,7 +237,7 @@
237237
" self.mse = torch.nn.MSELoss()\n",
238238
" super(MixMatchLoss, self).__init__()\n",
239239
" \n",
240-
" def forward(X, U, p, q, model):\n",
240+
" def forward(self, X, U, p, q, model):\n",
241241
" X_ = np.concatenate([X, U], axis=1)\n",
242242
" preds = model(X_)\n",
243243
" return self.xent(preds[:len(p)], p) + \\\n",
@@ -482,4 +482,4 @@
482482
"outputs": []
483483
}
484484
]
485-
}
485+
}

0 commit comments

Comments
 (0)