diff --git a/week6/week6_final_project_image_captioning_clean.ipynb b/week6/week6_final_project_image_captioning_clean.ipynb index 11be9da..4bd08ed 100644 --- a/week6/week6_final_project_image_captioning_clean.ipynb +++ b/week6/week6_final_project_image_captioning_clean.ipynb @@ -33,8 +33,7 @@ "ExecuteTime": { "end_time": "2017-09-17T12:30:35.584796Z", "start_time": "2017-09-17T12:30:35.581343Z" - }, - "collapsed": true + } }, "outputs": [], "source": [ @@ -47,9 +46,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "collapsed": true - }, + "metadata": {}, "outputs": [], "source": [ "download_utils.link_all_keras_resources()" @@ -62,8 +59,7 @@ "ExecuteTime": { "end_time": "2017-09-17T14:32:05.229736Z", "start_time": "2017-09-17T14:31:56.495874Z" - }, - "collapsed": true + } }, "outputs": [], "source": [ @@ -88,6 +84,47 @@ "import tqdm_utils" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Prepare the storage for model checkpoints" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Leave USE_GOOGLE_DRIVE = False if you're running locally!\n", + "# We recommend to set USE_GOOGLE_DRIVE = True in Google Colab!\n", + "# If set to True, we will mount Google Drive, so that you can restore your checkpoint \n", + "# and continue trainig even if your previous Colab session dies.\n", + "# If set to True, follow on-screen instructions to access Google Drive (you must have a Google account).\n", + "USE_GOOGLE_DRIVE = False\n", + "\n", + "def mount_google_drive():\n", + " from google.colab import drive\n", + " mount_directory = \"/content/gdrive\"\n", + " drive.mount(mount_directory)\n", + " drive_root = mount_directory + \"/\" + list(filter(lambda x: x[0] != '.', os.listdir(mount_directory)))[0] + \"/colab\"\n", + " return drive_root\n", + "\n", + "CHECKPOINT_ROOT = \"\"\n", + "if USE_GOOGLE_DRIVE:\n", + " CHECKPOINT_ROOT = mount_google_drive() + \"/\"\n", + "\n", + "def get_checkpoint_path(epoch=None):\n", + " if epoch is None:\n", + " return os.path.abspath(CHECKPOINT_ROOT + \"weights\")\n", + " else:\n", + " return os.path.abspath(CHECKPOINT_ROOT + \"weights_{}\".format(epoch))\n", + " \n", + "# example of checkpoint dir\n", + "print(get_checkpoint_path(10))" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -99,9 +136,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "collapsed": true - }, + "metadata": {}, "outputs": [], "source": [ "grader = grading.Grader(assignment_key=\"NEDBg6CgEee8nQ6uE8a7OA\", \n", @@ -111,9 +146,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "collapsed": true - }, + "metadata": {}, "outputs": [], "source": [ "# token expires every 30 min\n", @@ -138,9 +171,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "collapsed": true - }, + "metadata": {}, "outputs": [], "source": [ "# we downloaded them for you, just link them here\n", @@ -170,8 +201,7 @@ "ExecuteTime": { "end_time": "2017-09-17T14:32:09.629321Z", "start_time": "2017-09-17T14:32:09.627108Z" - }, - "collapsed": true + } }, "outputs": [], "source": [ @@ -185,8 +215,7 @@ "ExecuteTime": { "end_time": "2017-09-17T14:32:09.836606Z", "start_time": "2017-09-17T14:32:09.831028Z" - }, - "collapsed": true + } }, "outputs": [], "source": [ @@ -247,8 +276,7 @@ "ExecuteTime": { "end_time": "2017-09-17T14:32:12.621413Z", "start_time": "2017-09-17T14:32:11.986281Z" - }, - "collapsed": true + } }, "outputs": [], "source": [ @@ -269,8 +297,7 @@ "ExecuteTime": { "end_time": "2017-09-17T14:32:21.515330Z", "start_time": "2017-09-17T14:32:21.400879Z" - }, - "collapsed": true + } }, "outputs": [], "source": [ @@ -292,8 +319,7 @@ "ExecuteTime": { "end_time": "2017-09-17T14:32:24.897276Z", "start_time": "2017-09-17T14:32:22.942805Z" - }, - "collapsed": true + } }, "outputs": [], "source": [ @@ -326,8 +352,7 @@ "ExecuteTime": { "end_time": "2017-09-17T14:42:06.492565Z", "start_time": "2017-09-17T14:42:06.245458Z" - }, - "collapsed": true + } }, "outputs": [], "source": [ @@ -363,8 +388,7 @@ "ExecuteTime": { "end_time": "2017-09-17T14:43:40.637447Z", "start_time": "2017-09-17T14:43:40.633717Z" - }, - "collapsed": true + } }, "outputs": [], "source": [ @@ -379,8 +403,7 @@ "ExecuteTime": { "end_time": "2017-09-17T14:43:40.932131Z", "start_time": "2017-09-17T14:43:40.891187Z" - }, - "collapsed": true + } }, "outputs": [], "source": [ @@ -445,8 +468,7 @@ "ExecuteTime": { "end_time": "2017-09-17T14:43:44.824532Z", "start_time": "2017-09-17T14:43:41.264769Z" - }, - "collapsed": true + } }, "outputs": [], "source": [ @@ -463,8 +485,7 @@ "ExecuteTime": { "end_time": "2017-09-17T14:43:53.206639Z", "start_time": "2017-09-17T14:43:44.826028Z" - }, - "collapsed": true + } }, "outputs": [], "source": [ @@ -489,8 +510,7 @@ "ExecuteTime": { "end_time": "2017-09-17T16:11:52.425546Z", "start_time": "2017-09-17T16:11:52.414004Z" - }, - "collapsed": true + } }, "outputs": [], "source": [ @@ -524,8 +544,7 @@ "ExecuteTime": { "end_time": "2017-09-17T16:12:02.051692Z", "start_time": "2017-09-17T16:12:02.045821Z" - }, - "collapsed": true + } }, "outputs": [], "source": [ @@ -541,9 +560,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "collapsed": true - }, + "metadata": {}, "outputs": [], "source": [ "# you can make submission with answers so far to check yourself at this stage\n", @@ -602,8 +619,7 @@ "ExecuteTime": { "end_time": "2017-09-17T16:33:04.453351Z", "start_time": "2017-09-17T16:33:04.449675Z" - }, - "collapsed": true + } }, "outputs": [], "source": [ @@ -622,8 +638,7 @@ "ExecuteTime": { "end_time": "2017-09-17T16:38:46.296544Z", "start_time": "2017-09-17T16:38:46.290670Z" - }, - "collapsed": true + } }, "outputs": [], "source": [ @@ -663,8 +678,7 @@ "ExecuteTime": { "end_time": "2017-09-17T16:38:48.300312Z", "start_time": "2017-09-17T16:38:48.128590Z" - }, - "collapsed": true + } }, "outputs": [], "source": [ @@ -746,9 +760,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "collapsed": true - }, + "metadata": {}, "outputs": [], "source": [ "# define optimizer operation to minimize the loss\n", @@ -766,9 +778,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "collapsed": true - }, + "metadata": {}, "outputs": [], "source": [ "## GRADED PART, DO NOT CHANGE!\n", @@ -781,9 +791,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "collapsed": true - }, + "metadata": {}, "outputs": [], "source": [ "# you can make submission with answers so far to check yourself at this stage\n", @@ -805,8 +813,7 @@ "ExecuteTime": { "end_time": "2017-09-17T14:43:59.397913Z", "start_time": "2017-09-17T14:43:58.913391Z" - }, - "collapsed": true + } }, "outputs": [], "source": [ @@ -821,8 +828,7 @@ "ExecuteTime": { "end_time": "2017-09-17T14:43:59.529548Z", "start_time": "2017-09-17T14:43:59.399567Z" - }, - "collapsed": true + } }, "outputs": [], "source": [ @@ -860,8 +866,7 @@ "ExecuteTime": { "end_time": "2017-09-17T14:44:00.437338Z", "start_time": "2017-09-17T14:44:00.434472Z" - }, - "collapsed": true + } }, "outputs": [], "source": [ @@ -878,15 +883,13 @@ "ExecuteTime": { "end_time": "2017-09-17T14:44:01.497022Z", "start_time": "2017-09-17T14:44:00.962013Z" - }, - "collapsed": true + } }, "outputs": [], "source": [ "# you can load trained weights here\n", - "# you can load \"weights_{epoch}\" and continue training\n", "# uncomment the next line if you need to load weights\n", - "# saver.restore(s, os.path.abspath(\"weights\"))" + "# saver.restore(s, get_checkpoint_path(epoch=4))" ] }, { @@ -904,7 +907,6 @@ "end_time": "2017-09-17T12:42:16.120494Z", "start_time": "2017-09-17T12:31:03.779162Z" }, - "collapsed": true, "scrolled": true }, "outputs": [], @@ -943,7 +945,7 @@ " print('Epoch: {}, train loss: {}, val loss: {}'.format(epoch, train_loss, val_loss))\n", "\n", " # save weights after finishing epoch\n", - " saver.save(s, os.path.abspath(\"weights_{}\".format(epoch)))\n", + " saver.save(s, get_checkpoint_path(epoch))\n", " \n", "print(\"Finished!\")" ] @@ -951,9 +953,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "collapsed": true - }, + "metadata": {}, "outputs": [], "source": [ "## GRADED PART, DO NOT CHANGE!\n", @@ -965,9 +965,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "collapsed": true - }, + "metadata": {}, "outputs": [], "source": [ "# you can make submission with answers so far to check yourself at this stage\n", @@ -981,8 +979,7 @@ "ExecuteTime": { "end_time": "2017-09-17T12:42:16.399349Z", "start_time": "2017-09-17T12:42:16.122158Z" - }, - "collapsed": true + } }, "outputs": [], "source": [ @@ -1015,13 +1012,12 @@ "ExecuteTime": { "end_time": "2017-09-17T12:42:16.535481Z", "start_time": "2017-09-17T12:42:16.400830Z" - }, - "collapsed": true + } }, "outputs": [], "source": [ - "# save graph weights to file!\n", - "saver.save(s, os.path.abspath(\"weights\"))" + "# save last graph weights to file!\n", + "saver.save(s, get_checkpoint_path())" ] }, { @@ -1047,15 +1043,14 @@ "ExecuteTime": { "end_time": "2017-09-17T14:44:22.546086Z", "start_time": "2017-09-17T14:44:16.029331Z" - }, - "collapsed": true + } }, "outputs": [], "source": [ "class final_model:\n", " # CNN encoder\n", " encoder, preprocess_for_model = get_cnn_encoder()\n", - " saver.restore(s, os.path.abspath(\"weights\")) # keras applications corrupt our graph, so we restore trained weights\n", + " saver.restore(s, get_checkpoint_path()) # keras applications corrupt our graph, so we restore trained weights\n", " \n", " # containers for current lstm state\n", " lstm_c = tf.Variable(tf.zeros([1, LSTM_UNITS]), name=\"cell\")\n", @@ -1096,8 +1091,7 @@ "ExecuteTime": { "end_time": "2017-09-17T17:27:17.828681Z", "start_time": "2017-09-17T17:27:17.820029Z" - }, - "collapsed": true + } }, "outputs": [], "source": [ @@ -1115,8 +1109,7 @@ "ExecuteTime": { "end_time": "2017-09-17T14:44:22.575410Z", "start_time": "2017-09-17T14:44:22.547785Z" - }, - "collapsed": true + } }, "outputs": [], "source": [ @@ -1163,8 +1156,7 @@ "ExecuteTime": { "end_time": "2017-09-17T17:44:15.525786Z", "start_time": "2017-09-17T17:44:15.238979Z" - }, - "collapsed": true + } }, "outputs": [], "source": [ @@ -1197,7 +1189,6 @@ "end_time": "2017-09-17T15:07:47.191185Z", "start_time": "2017-09-17T15:06:44.121069Z" }, - "collapsed": true, "scrolled": true }, "outputs": [], @@ -1222,8 +1213,7 @@ "ExecuteTime": { "end_time": "2017-09-17T17:42:56.055265Z", "start_time": "2017-09-17T17:42:54.242164Z" - }, - "collapsed": true + } }, "outputs": [], "source": [ @@ -1236,9 +1226,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "collapsed": true - }, + "metadata": {}, "outputs": [], "source": [ "apply_model_to_image_raw_bytes(open(\"portal-cake-10.jpg\", \"rb\").read())" @@ -1274,9 +1262,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "collapsed": true - }, + "metadata": {}, "outputs": [], "source": [ "### YOUR EXAMPLES HERE ###"