{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "\n", "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "#export\n", "from exp.nb_05 import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Jump_to notebook introduction in lesson 10 video](https://course.fast.ai/videos/?lesson=10&t=3167)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 1. Early stopping" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Better callback cancellation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Jump_to lesson 10 video](https://course.fast.ai/videos/?lesson=10&t=3230)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# get the data, set the loss function\n", "x_train,y_train,x_valid,y_valid = get_data()\n", "train_ds,valid_ds = Dataset(x_train, y_train),Dataset(x_valid, y_valid)\n", "n_hidden,batch_size = 50,512\n", "n_out = y_train.max().item()+1\n", "loss_func = F.cross_entropy" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# create a DataBunch\n", "data = DataBunch(*get_dls(train_ds, valid_ds, batch_size), n_out)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Slightly refactor Callback() and add three Cancellation callbacks" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "#export\n", "\n", "# Callback() class is slightly refactored from notebook 04_callbacks\n", "class Callback():\n", " \n", " # initialize _order to zero. \n", " _order=0\n", " \n", " # set_runner() method takes a callback as an input\n", " # note that initially self.run is unset -- there is no default value \n", " def set_runner(self, run): \n", " self.run=run\n", " def __getattr__(self, callback_name): \n", " return getattr(self.run, callback_name)\n", " \n", " # set the callback name property\n", " # if the callback doesn't have a name, set the callback name property to 'callback'\n", " @property\n", " def name(self):\n", " name = re.sub(r'Callback$', '', self.__class__.__name__)\n", " return camel2snake(name or 'callback')\n", " \n", " # this is the only modification to the 04_callbacks notebook\n", " # it allows the Callback() class to be called as a function\n", " def __call__(self, callback_name):\n", " f = getattr(self, callback_name, None)\n", " # check this callback name, and return True if it is the requested callback\n", " if f and f(): return True\n", " return False\n", "\n", "# this helper callback is used in Runner()\n", "class TrainEvalCallback(Callback):\n", " \n", " # initialize the epoch, batch, and iteration counters\n", " def begin_fit(self):\n", " self.run.n_epoch_float=0.\n", " self.run.n_batch = 0\n", " self.run.n_iter=0\n", " \n", " # if we are in the training phase, update the epoch and batch counters\n", " def after_batch(self):\n", " if not self.in_train: \n", " return\n", " # each batch represents a fraction of an epoch\n", " self.run.n_epoch_float += 1./self.n_batches\n", " self.run.n_batch += 1\n", " \n", " # execute the training phase\n", " def begin_epoch(self):\n", " self.run.n_epoch_float=self.n_epoch_float\n", " self.model.train()\n", " self.run.in_train=True\n", "\n", " # execute the prediction phase\n", " def begin_validate(self):\n", " self.model.eval()\n", " self.run.in_train=False\n", "\n", "# add three cancellation callbacks\n", "class CancelTrainException(Exception): pass\n", "class CancelEpochException(Exception): pass\n", "class CancelBatchException(Exception): pass" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Refactor Runner() to use the cancellation callbacks" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "#export\n", "class Runner():\n", "\n", " # initialize by setting the stop Flag to False, and constructing a list of callbacks from the inputs\n", " def __init__(self, callbacks=None, callback_funcs=None):\n", " # inputs are two lists: callbacks and callback_funcs\n", " # Q: it's not clear why we need two lists rather than one\n", " # create a list of callbacks from the input callbacks\n", " self.in_train = False\n", " callbacks = listify(callbacks)\n", " # associate each callback_func() to its snake case callback name then append it to the callbacks list\n", " for callback_func in listify(callback_funcs):\n", " callback = callback_func()\n", " setattr(self, callback.name, callback)\n", " callbacks.append(callback)\n", " # set the stopping flag to `False` and append TrainEvalCallback() to the callbacks list\n", " self.stop,self.callbacks = False,[TrainEvalCallback()]+callbacks\n", "\n", " # get the properties of the Learner object\n", " @property\n", " def opt(self): return self.learn.opt\n", " @property\n", " def model(self): return self.learn.model\n", " @property\n", " def loss_func(self): return self.learn.loss_func\n", " @property\n", " def data(self): return self.learn.data\n", " \n", " \n", " # method to process a single batch\n", " def one_batch(self, xb, yb):\n", " try:\n", " self.xb,self.yb = xb,yb\n", " self('begin_batch')\n", " self.pred = self.model(self.xb)\n", " self('after_pred')\n", " self.loss = self.loss_func(self.pred, self.yb)\n", " self('after_loss')\n", " if not self.in_train: return\n", " self.loss.backward()\n", " self('after_backward')\n", " self.opt.step()\n", " self('after_step')\n", " self.opt.zero_grad()\n", " except CancelBatchException: self('after_cancel_batch')\n", " finally: self('after_batch')\n", "\n", " # method to process all batches\n", " def all_batches(self, dataloader):\n", " # total number of batches in an epoch\n", " # self.n_epoch_float = 0.\n", " self.n_batches = len(dataloader)\n", " try:\n", " for xb,yb in dataloader: self.one_batch(xb, yb)\n", " except CancelEpochException: self('after_cancel_epoch')\n", "\n", " # method to process training or validation data\n", " def fit(self, learn, n_epochs,):\n", " self.n_epochs,self.learn,self.loss = n_epochs,learn,tensor(0.)\n", "\n", " try:\n", " for callback in self.callbacks: \n", " callback.set_runner(self)\n", " self('begin_fit')\n", " for epoch_number in range(n_epochs):\n", " self.epoch_number = epoch_number\n", " \n", " \n", " # training phase\n", " if not self('begin_epoch'): \n", " self.all_batches(self.data.train_dl)\n", "\n", " # validation phase\n", " with torch.no_grad(): \n", " if not self('begin_validate'): \n", " self.all_batches(self.data.valid_dl)\n", " self('after_epoch')\n", " \n", " except CancelTrainException: \n", " self('after_cancel_train')\n", " \n", " finally:\n", " # set the `after_fit` state to `True`\n", " self('after_fit')\n", " # erase the Learner object\n", " self.learn = None\n", "\n", " def __call__(self, callback_name):\n", " # __call__ allows an instance of this class to be called as a function\n", " # Q: note clear what this loop is trying to do; it always returns result = False\n", " result = False\n", " for callback in sorted(self.callbacks, key=lambda x: x._order): \n", " result = callback(callback_name) and result\n", " return result" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Refactor TestCallback() to use a cancellation callback" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "class TestCallback(Callback):\n", " _order=1\n", " def after_step(self):\n", " self.n_iter += 1\n", " print(self.n_iter)\n", " if self.n_iter>=10: \n", " raise CancelTrainException()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "learn = create_learner(get_model, loss_func, data)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "run = Runner(callback_funcs=TestCallback)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1\n", "2\n", "3\n", "4\n", "5\n", "6\n", "7\n", "8\n", "9\n", "10\n" ] } ], "source": [ "run.fit(learn, n_epochs = 3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 2. AvgStatsCallback, Recorder, and ParamScheduler Callbacks" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "#export\n", "class AvgStatsCallback(Callback):\n", " def __init__(self, metrics):\n", " self.train_stats,self.valid_stats = AvgStats(metrics,in_train = True),AvgStats(metrics,in_train = False)\n", " \n", " # initialize train_stats and valid_stats\n", " def begin_epoch(self):\n", " self.train_stats.reset()\n", " self.valid_stats.reset()\n", " \n", " # compute and accumulate stats after the loss function has been evaluated\n", " def after_loss(self):\n", " stats = self.train_stats if self.in_train else self.valid_stats\n", " with torch.no_grad(): stats.accumulate(self.run)\n", " \n", " # print stats after the epoch has been processed\n", " def after_epoch(self):\n", " print(self.train_stats)\n", " print(self.valid_stats)\n", " \n", "class Recorder(Callback):\n", " def begin_fit(self):\n", " self.lrs = [[] for _ in self.opt.param_groups]\n", " self.losses = []\n", "\n", " def after_batch(self):\n", " if not self.in_train: return\n", " for pg,lr in zip(self.opt.param_groups,self.lrs): lr.append(pg['lr'])\n", " self.losses.append(self.loss.detach().cpu()) \n", "\n", " def plot_lr (self, pgid=-1): \n", " plt.plot(self.lrs[pgid])\n", " plt.xlabel('iteration')\n", " plt.ylabel('loss')\n", " def plot_loss(self, skip_last=0): # !!!!! not used\n", " plt.plot(self.losses[:len(self.losses)-skip_last])\n", " plt.xlabel('iteration')\n", " plt.ylabel('loss')\n", "\n", " \n", " def plot(self, skip_last=0, pgid=-1):\n", " losses = [o.item() for o in self.losses]\n", " lrs = self.lrs[pgid]\n", " n = len(losses)-skip_last\n", " plt.xscale('log')\n", " plt.plot(lrs[:n], losses[:n])\n", " plt.xlabel('learning rate')\n", " plt.ylabel('loss')\n", "\n", "class ParamScheduler(Callback):\n", " _order=1\n", " def __init__(self, pname, sched_funcs): \n", " self.pname,self.sched_funcs = pname,sched_funcs\n", " \n", " def begin_fit(self):\n", " if not isinstance(self.sched_funcs, (list,tuple)):\n", " self.sched_funcs = [self.sched_funcs] * len(self.opt.param_groups)\n", "\n", " def set_param(self):\n", " assert len(self.opt.param_groups)==len(self.sched_funcs)\n", " for pg,f in zip(self.opt.param_groups,self.sched_funcs):\n", " # !!!!! this is wrong -- \n", " pg[self.pname] = f(self.n_epochs/self.n_epochs)\n", " \n", " def begin_batch(self): \n", " if self.in_train: self.set_param()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 3. Learning Rate Finder" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "NB: You may want to also add something that saves the model before running this, and loads it back after running - otherwise you'll lose your weights!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Jump_to lesson 10 video](https://course.fast.ai/videos/?lesson=10&t=3545)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "class LR_Find(Callback):\n", " _order=1\n", " def __init__(self, max_iter=100, min_lr=1e-6, max_lr=10):\n", " self.max_iter,self.min_lr,self.max_lr = max_iter,min_lr,max_lr\n", " self.best_loss = 1e9\n", " \n", " def begin_batch(self): \n", " if not self.in_train: \n", " return\n", " pos = self.n_iter/self.max_iter\n", " lr = self.min_lr * (self.max_lr/self.min_lr) ** pos\n", " for pg in self.opt.param_groups: pg['lr'] = lr\n", " \n", " def after_step(self):\n", " self.n_iter += 1\n", " if self.n_iter >= self.max_iter or self.loss > self.best_loss*10:\n", " raise CancelTrainException()\n", " if self.loss < self.best_loss: \n", " self.best_loss = self.loss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "NB: In fastai we also use exponential smoothing on the loss. For that reason we check for `best_loss*3` instead of `best_loss*10`." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# instantiate a Learner object with data, loss_func, opt and model\n", "learn = Learner(*get_model(data), loss_func, data) " ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "# instantiate a Runner object using the callback_funcs() input\n", "run = Runner(callback_funcs=[LR_Find, Recorder])" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# run a training/validation loop\n", "run.fit(learn, n_epochs=2)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# plot learning rate vs. loss\n", "run.recorder.plot(skip_last=5)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# plot iteration number vs. loss\n", "run.recorder.plot_lr()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Export" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Converted 05b_early_stopping_jcat.ipynb to exp\\nb_05b.py\n" ] } ], "source": [ "!python notebook2script.py 05b_early_stopping_jcat.ipynb" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 2 }