diff --git a/notebooks/experimental/empirical_ntk_resnet_tf.ipynb b/notebooks/experimental/empirical_ntk_resnet_tf.ipynb index f6e58aea..e4ff31e0 100644 --- a/notebooks/experimental/empirical_ntk_resnet_tf.ipynb +++ b/notebooks/experimental/empirical_ntk_resnet_tf.ipynb @@ -53,20 +53,20 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Xx1xr9v5pyl0", - "outputId": "d6586a6f-f74b-4543-8da7-1c33d0a46af1" + "outputId": "2de00c45-710c-4115-8e04-3cf3777098dc" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ - "GPU 0: A100-SXM4-40GB (UUID: GPU-097b69f8-fbcd-6775-82c1-4cfd355907e9)\n" + "GPU 0: A100-SXM4-40GB (UUID: GPU-00d2130b-454e-7677-9b34-bbe78525d972)\n" ] } ], @@ -79,18 +79,48 @@ "source": [ "# We need at least jaxlib-0.1.73 to avoid certain CUDA bugs when using `implementation=auto`\n", "!pip install -q --upgrade pip\n", - "!pip install --upgrade jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n", - "!pip install -q git+https://www.github.com/google/neural-tangents.git" + "!pip install -q --upgrade jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n", + "\n", + "# TODO(romann): figure out why Colab crashes sometimes if TF is upgraded.\n", + "!pip install -q git+https://www.github.com/deepmind/tf2jax.git --no-deps\n", + "!pip install -q frozendict typing-extensions\n", + "!pip install -q git+https://www.github.com/google/neural-tangents.git --no-deps" ], "metadata": { - "id": "ZlckZChsxTWj" + "id": "ZlckZChsxTWj", + "outputId": "b65c517a-e6a1-4d28-df68-c47d36757a34", + "colab": { + "base_uri": "https://localhost:8080/" + } }, - "execution_count": null, - "outputs": [] + "execution_count": 2, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u001b[K |████████████████████████████████| 2.1 MB 8.7 MB/s \n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m951.0/951.0 kB\u001b[0m \u001b[31m21.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m175.7/175.7 MB\u001b[0m \u001b[31m8.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Building wheel for jax (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0m Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Building wheel for tf2jax (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m99.0/99.0 kB\u001b[0m \u001b[31m5.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0m Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Building wheel for neural-tangents (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] + } + ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": { "id": "LbW8KVnsPfVd" }, @@ -108,7 +138,7 @@ "metadata": { "id": "CqxnhMKDE2Gf" }, - "execution_count": 13, + "execution_count": 4, "outputs": [] }, { @@ -129,7 +159,7 @@ "metadata": { "id": "wdqBa3OQDf97" }, - "execution_count": 14, + "execution_count": 5, "outputs": [] }, { @@ -143,7 +173,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 6, "metadata": { "id": "lPh5LGz9JBK_" }, @@ -194,13 +224,13 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 7, "metadata": { "id": "bwhIZWqxKTlt", "colab": { "base_uri": "https://localhost:8080/" }, - "outputId": "9914d19f-5ccc-46f1-c665-1674d2e11fea" + "outputId": "3befafda-8780-4fdc-f705-62d3b2925e77" }, "outputs": [ { @@ -225,13 +255,13 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "zObT8WnPggFo", - "outputId": "f2d97d14-961c-4b28-f457-28a7e541b1a7" + "outputId": "399e552e-1f94-4b80-cc2a-f43bc7acde2d" }, "outputs": [ { @@ -264,13 +294,13 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 9, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "FW9gJJ4qggFp", - "outputId": "9e1c94f4-ac6b-44b4-be90-dcab39d69a23" + "outputId": "0058e7b6-68a1-4daa-f0c4-e48d260874f4" }, "outputs": [ { @@ -303,13 +333,13 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "gFeWnqGQggFp", - "outputId": "256886ad-a7fa-4476-9cfb-b97c9d98b0b9" + "outputId": "bdf0f87a-e768-4e12-d23a-d377a14894d1" }, "outputs": [ { @@ -342,20 +372,20 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 11, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Q63v1L1aggFp", - "outputId": "c180cf00-3895-4034-cf03-91425436e9b4" + "outputId": "776b99af-f327-48a8-8ab8-63c62f6fc017" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ - "tf.Tensor(0.00024576858, shape=(), dtype=float32) tf.Tensor(0.0006717233, shape=(), dtype=float32) tf.Tensor(0.00086082943, shape=(), dtype=float32)\n" + "tf.Tensor(0.00024811307, shape=(), dtype=float32) tf.Tensor(0.0005964738, shape=(), dtype=float32) tf.Tensor(0.000839914, shape=(), dtype=float32)\n" ] } ], @@ -383,9 +413,9 @@ "base_uri": "https://localhost:8080/" }, "id": "gMsFPSOr3DTE", - "outputId": "53434e57-cece-4446-8306-688eb92ea03a" + "outputId": "2f052256-5dfa-4117-9d9a-b6fc38142ccf" }, - "execution_count": 21, + "execution_count": 12, "outputs": [ { "output_type": "stream", @@ -415,20 +445,20 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "diP7nkBuggFp", - "outputId": "6c0f222b-994c-4836-c24c-dc4dbab0c912" + "outputId": "e7489fdf-9807-4be9-d1ea-5cbfe76b6919" }, 
"outputs": [ { "output_type": "stream", "name": "stdout", "text": [ - "1 loop, best of 5: 317 ms per loop\n" + "1 loop, best of 5: 327 ms per loop\n" ] } ], @@ -439,20 +469,20 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 14, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "wehCdvi2ggFp", - "outputId": "7482da28-ed14-4366-989b-8e5da43507ed" + "outputId": "6ebde0d2-e181-4efb-cb93-09bea9cf9d2b" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ - "1 loop, best of 5: 484 ms per loop\n" + "1 loop, best of 5: 491 ms per loop\n" ] } ], @@ -465,20 +495,20 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 15, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Yrm53akVggFp", - "outputId": "f64a3915-5972-46fb-c857-ffe6d6dedbdf" + "outputId": "ad3293d9-5e6a-475d-90e7-0166c9565013" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ - "1 loop, best of 5: 184 ms per loop\n" + "1 loop, best of 5: 199 ms per loop\n" ] } ], @@ -490,20 +520,20 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 16, "metadata": { "id": "P1QtkBqLggFp", "colab": { "base_uri": "https://localhost:8080/" }, - "outputId": "5d9eb5b5-7c8b-4dbe-a50f-2af774529493" + "outputId": "0dd0f5cd-6d36-4348-c3d1-8f4c0af5e8b3" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ - "10 loops, best of 5: 324 ms per loop\n" + "1 loop, best of 5: 326 ms per loop\n" ] } ], @@ -534,13 +564,13 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 17, "metadata": { "id": "3oHBaAmDhBON", "colab": { "base_uri": "https://localhost:8080/" }, - "outputId": "e6253afa-8f09-404b-85d8-da7898055960" + "outputId": "f4ce0626-0ec8-4431-ea76-a78e00a9bf17" }, "outputs": [ { @@ -565,13 +595,13 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 18, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "sc1bUvL-KrK9", - "outputId": "237c9718-dd5d-45d2-fcfa-e621ee5472f9" + "outputId": "9d9543e4-6378-4bd4-c3e1-ebb1b077b106" }, "outputs": [ { @@ -604,13 +634,13 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 19, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "qdNsmnjOKyp0", - "outputId": "d47dfbec-818f-4868-88e6-1ed51046e3ea" + "outputId": "5d72b27a-60c1-478b-acf2-fa93c8e57c5a" }, "outputs": [ { @@ -643,13 +673,13 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 20, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Iw6HL260K26E", - "outputId": "c69250eb-5d4c-4211-89c5-956241560740" + "outputId": "d46d21ba-b919-45d5-dc53-d2fdb0d8e36c" }, "outputs": [ { @@ -682,20 +712,20 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 21, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "DYG0fV9nOjnd", - "outputId": "e15ad99f-9396-4570-831b-043a0de87dfe" + "outputId": "46825172-1a01-4700-9ff6-00044054b818" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ - "tf.Tensor(0.014025839, shape=(), dtype=float32) tf.Tensor(0.0051410757, shape=(), dtype=float32) tf.Tensor(0.0148538705, shape=(), dtype=float32)\n" + "tf.Tensor(0.016062234, shape=(), dtype=float32) tf.Tensor(0.002566749, shape=(), dtype=float32) tf.Tensor(0.014049936, shape=(), dtype=float32)\n" ] } ], @@ -710,13 +740,13 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 22, "metadata": { "id": 
"DyF4M5_HK5Fk", "colab": { "base_uri": "https://localhost:8080/" }, - "outputId": "00648984-bbb2-4086-9610-38f521360c96" + "outputId": "8a6c727b-01cf-47e3-dd42-fcdafbeec34a" }, "outputs": [ { @@ -740,14 +770,14 @@ "output_type": "stream", "name": "stdout", "text": [ - "WARNING:tensorflow:5 out of the last 84 calls to .converted_fun at 0x7fb3006c1d40> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n" + "WARNING:tensorflow:5 out of the last 29 calls to .converted_fun at 0x7f2fde140dd0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ - "WARNING:tensorflow:5 out of the last 84 calls to .converted_fun at 0x7fb3006c1d40> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n" + "WARNING:tensorflow:5 out of the last 29 calls to .converted_fun at 0x7f2fde140dd0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. 
For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n" ] }, { @@ -769,20 +799,20 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 23, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "8g1IO71LLJlG", - "outputId": "50548640-2e14-4a8a-e54c-a6e16d467427" + "outputId": "be843ed9-ecc6-49d0-dc9a-2e78ddc98846" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ - "1 loop, best of 5: 495 ms per loop\n" + "1 loop, best of 5: 494 ms per loop\n" ] } ], @@ -793,20 +823,20 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 24, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "KEIPcXYRMys2", - "outputId": "9ac83c1b-3c2b-47a8-80ce-43c80ede94ab" + "outputId": "7c26d989-4e1e-4267-d11f-550f0f89f519" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ - "1 loop, best of 5: 276 ms per loop\n" + "1 loop, best of 5: 275 ms per loop\n" ] } ], @@ -818,20 +848,20 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 25, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "gVyUPA8xM1ot", - "outputId": "b82ec2d1-6243-486e-b62d-44b74b14179b" + "outputId": "1b2ac344-8c75-43c0-df91-5e57d4adf895" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ - "1 loop, best of 5: 232 ms per loop\n" + "1 loop, best of 5: 235 ms per loop\n" ] } ], @@ -843,20 +873,20 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 26, "metadata": { "id": "o81r3UHJM34_", "colab": { "base_uri": "https://localhost:8080/" }, - "outputId": "31b346e4-0140-455c-989a-4b63a4709e9b" + "outputId": "44ebbabd-39aa-442c-bdaa-c0eea3b0452d" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ - "1 loop, best of 5: 275 ms per loop\n" + "1 loop, best of 5: 277 ms per loop\n" ] } ], @@ -887,13 +917,13 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 27, "metadata": { "id": "Z4MpdSJ7hFsD", "colab": { "base_uri": "https://localhost:8080/" }, - "outputId": "9e47f10b-2dd9-48bb-a4de-baaf06468df9" + "outputId": "bc88b6f0-0f65-4ef0-9fb7-83a89a6b7e2f" }, "outputs": [ { @@ -918,13 +948,13 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 28, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ynBEtkA7d_2G", - "outputId": "5ef3642f-81da-43c6-a0f5-0488077f9753" + "outputId": "a0919de1-a73c-4724-f944-229d6ae96c6f" }, "outputs": [ { @@ -957,13 +987,13 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 29, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Afniv9k2d_2H", - "outputId": "16363413-13f0-488b-b4a9-07305f8b8a93" + "outputId": "ff5dca85-e0c8-4280-f614-97621c0a4553" }, "outputs": [ { @@ -981,14 +1011,14 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 30, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 433 }, "id": "4Ou5a6y4zH5Q", - "outputId": "c7297d33-c265-4c79-b889-a5ffc58b721c" + "outputId": "39606633-c647-48ae-e0b6-fbd5d126c655" }, "outputs": [ { @@ -1012,29 +1042,29 @@ "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mUnknownError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# 
NTK-vector products - OOM!\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mk_3\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mntk_fn_ntvp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mk_3\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# NTK-vector products - OOM!\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mk_3\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mntk_fn_ntvp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mk_3\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/traceback_utils.py\u001b[0m in \u001b[0;36merror_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 151\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 152\u001b[0m \u001b[0mfiltered_tb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_process_traceback_frames\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__traceback__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 153\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwith_traceback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfiltered_tb\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 154\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 155\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0mfiltered_tb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/execute.py\u001b[0m in \u001b[0;36mquick_execute\u001b[0;34m(op_name, num_outputs, inputs, attrs, ctx, name)\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmessage\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m\" name: \"\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 59\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_status_to_exception\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfrom\u001b[0m 
\u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 60\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mTypeError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 61\u001b[0m keras_symbolic_tensors = [\n", - "\u001b[0;31mUnknownError\u001b[0m: Failed to determine best cudnn convolution algorithm: RESOURCE_EXHAUSTED: Allocating 9437184000 bytes exceeds the memory limit of 4294967296 bytes.\n\nConvolution performance may be suboptimal. To ignore this failure and try to use a fallback algorithm, use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false. Please also file a bug for the root cause of failing autotuning. [Op:__inference_converted_fun_1940145]" + "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/execute.py\u001b[0m in \u001b[0;36mquick_execute\u001b[0;34m(op_name, num_outputs, inputs, attrs, ctx, name)\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mensure_initialized\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,\n\u001b[0;32m---> 59\u001b[0;31m inputs, attrs, num_outputs)\n\u001b[0m\u001b[1;32m 60\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_NotOkStatusException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mUnknownError\u001b[0m: Failed to determine best cudnn convolution algorithm: RESOURCE_EXHAUSTED: Allocating 9437184000 bytes exceeds the memory limit of 4294967296 bytes.\n\nConvolution performance may be suboptimal. To ignore this failure and try to use a fallback algorithm, use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false. Please also file a bug for the root cause of failing autotuning. 
[Op:__inference_converted_fun_1922490]" ] } ], "source": [ "# NTK-vector products - OOM!\n", - "k_3 = ntk_fn_ntvp(x1, x2, params)\n", + "k_2 = ntk_fn_ntvp(x1, x2, params)\n", - "print(k_3.shape)" + "print(k_2.shape)" ] }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 31, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 433 }, "id": "rjXRalQfd_2G", - "outputId": "6357f07b-e1ef-4079-932b-bf31f7a8b046" + "outputId": "3616ffb4-365f-431b-f95c-fd821279c815" }, "outputs": [ { @@ -1058,10 +1088,10 @@ "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mUnknownError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# Jacobian contraction - OOM!\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mk_1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mntk_fn_jacobian_contraction\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mk_1\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# Jacobian contraction - OOM!\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mk_1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mntk_fn_jacobian_contraction\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mk_1\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/traceback_utils.py\u001b[0m in \u001b[0;36merror_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 151\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 152\u001b[0m \u001b[0mfiltered_tb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_process_traceback_frames\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__traceback__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 153\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwith_traceback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfiltered_tb\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 154\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 155\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0mfiltered_tb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", -
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/execute.py\u001b[0m in \u001b[0;36mquick_execute\u001b[0;34m(op_name, num_outputs, inputs, attrs, ctx, name)\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmessage\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m\" name: \"\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 59\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_status_to_exception\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 60\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mTypeError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 61\u001b[0m keras_symbolic_tensors = [\n", - "\u001b[0;31mUnknownError\u001b[0m: Failed to determine best cudnn convolution algorithm: RESOURCE_EXHAUSTED: Allocating 9437184000 bytes exceeds the memory limit of 4294967296 bytes.\n\nConvolution performance may be suboptimal. To ignore this failure and try to use a fallback algorithm, use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false. Please also file a bug for the root cause of failing autotuning. [Op:__inference_converted_fun_2069598]" + "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/execute.py\u001b[0m in \u001b[0;36mquick_execute\u001b[0;34m(op_name, num_outputs, inputs, attrs, ctx, name)\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mensure_initialized\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,\n\u001b[0;32m---> 59\u001b[0;31m inputs, attrs, num_outputs)\n\u001b[0m\u001b[1;32m 60\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_NotOkStatusException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mUnknownError\u001b[0m: Failed to determine best cudnn convolution algorithm: RESOURCE_EXHAUSTED: Allocating 9437184000 bytes exceeds the memory limit of 4294967296 bytes.\n\nConvolution performance may be suboptimal. To ignore this failure and try to use a fallback algorithm, use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false. Please also file a bug for the root cause of failing autotuning. [Op:__inference_converted_fun_2051943]" ] } ],