From 64b9ae8bf0a430795e437aac4f03074c6f88920e Mon Sep 17 00:00:00 2001
From: Abhinav Rao
Date: Sun, 3 Dec 2023 00:16:02 -0500
Subject: [PATCH] hw-6 corrected kernel_func issue

---
 .../homework-06-checkpoint.ipynb              | 9884 +++++++++++++++++
 lecturebook/homework/homework-06.ipynb        |    4 +-
 2 files changed, 9886 insertions(+), 2 deletions(-)
 create mode 100644 lecturebook/homework/.ipynb_checkpoints/homework-06-checkpoint.ipynb

diff --git a/lecturebook/homework/.ipynb_checkpoints/homework-06-checkpoint.ipynb b/lecturebook/homework/.ipynb_checkpoints/homework-06-checkpoint.ipynb
new file mode 100644
index 00000000..164bf707
--- /dev/null
+++ b/lecturebook/homework/.ipynb_checkpoints/homework-06-checkpoint.ipynb
@@ -0,0 +1,9884 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Homework 6\n",
+    "\n",
+    "## References\n",
+    "\n",
+    "+ Lectures 21-23 (inclusive).\n",
+    "\n",
+    "\n",
+    "## Instructions\n",
+    "\n",
+    "+ Type your name and email in the \"Student details\" section below.\n",
+    "+ Develop the code and generate the figures you need to solve the problems using this notebook.\n",
+    "+ For the answers that require a mathematical proof or derivation, type them using LaTeX. If you have never written LaTeX before and you find it exceedingly difficult, we will likely accept handwritten solutions.\n",
+    "+ The total homework points are 100. Please note that the problems are not weighted equally."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "If on Google Colab, install the following packages:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!pip install gpytorch"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 94,
+   "metadata": {
+    "tags": [
+     "hide-input"
+    ]
+   },
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "import matplotlib.pyplot as plt\n",
+    "%matplotlib inline\n",
+    "import matplotlib_inline\n",
+    "matplotlib_inline.backend_inline.set_matplotlib_formats('svg')\n",
+    "import seaborn as sns\n",
+    "sns.set_context(\"paper\")\n",
+    "sns.set_style(\"ticks\")\n",
+    "\n",
+    "import scipy\n",
+    "import scipy.stats as st\n",
+    "import urllib.request\n",
+    "import os\n",
+    "\n",
+    "import torch\n",
+    "\n",
+    "def download(\n",
+    "    url : str,\n",
+    "    local_filename : str = None\n",
+    "):\n",
+    "    \"\"\"Download a file from a url.\n",
+    "    \n",
+    "    Arguments\n",
+    "    url            -- The url we want to download.\n",
+    "    local_filename -- The filename to write to. If not\n",
+    "                      specified, it is inferred from the url.\n",
+    "    \"\"\"\n",
+    "    if local_filename is None:\n",
+    "        local_filename = os.path.basename(url)\n",
+    "    urllib.request.urlretrieve(url, local_filename)\n",
+    "\n",
+    "def sample_functions(mean_func, kernel_func, num_samples=10, num_test=100, nugget=1e-3):\n",
+    "    \"\"\"Sample functions from a Gaussian process.\n",
+    "\n",
+    "    Arguments:\n",
+    "    mean_func   -- the mean function. It must be a callable that takes a tensor\n",
+    "                   of shape (num_test, dim) and returns a tensor of shape (num_test,).\n",
+    "    kernel_func -- the covariance function. It must be a gpytorch kernel whose\n",
+    "                   forward method takes two tensors of shape (num_test, dim) and\n",
+    "                   returns a tensor of shape (num_test, num_test).\n",
+    "    num_samples -- the number of samples to take. Defaults to 10.\n",
+    "    num_test    -- the number of test points. Defaults to 100.\n",
+    "    nugget      -- a small number required for stability. Defaults to 1e-3.\n",
+    "    \"\"\"\n",
+    "    X = torch.linspace(0, 1, num_test)[:, None]\n",
+    "    m = mean_func(X)\n",
+    "    C = kernel_func.forward(X, X) + nugget * torch.eye(X.shape[0])\n",
+    "    L = torch.linalg.cholesky(C)\n",
+    "    fig, ax = plt.subplots()\n",
+    "    ax.plot(X, m.detach(), label='mean')\n",
+    "    for i in range(num_samples):\n",
+    "        # Sample f = m + L z, with z standard normal.\n",
+    "        z = torch.randn(X.shape[0], 1)\n",
+    "        f = m[:, None] + L @ z\n",
+    "        ax.plot(X.flatten(), f.detach().flatten(), color=sns.color_palette()[1], linewidth=0.5,\n",
+    "                label='sample' if i == 0 else None\n",
+    "        )\n",
+    "    plt.legend(loc='best', frameon=False)\n",
+    "    ax.set_xlabel('$x$')\n",
+    "    ax.set_ylabel('$y$')\n",
+    "    ax.set_ylim(-5, 5)\n",
+    "    sns.despine(trim=True);\n",
+    "\n",
+    "\n",
+    "import gpytorch\n",
+    "from gpytorch.kernels import RBFKernel, ScaleKernel\n",
+    "\n",
+    "class ExactGP(gpytorch.models.ExactGP):\n",
+    "    def __init__(self,\n",
+    "        train_x,\n",
+    "        train_y,\n",
+    "        likelihood=gpytorch.likelihoods.GaussianLikelihood(),\n",
+    "        mean_module=gpytorch.means.ConstantMean(),\n",
+    "        covar_module=ScaleKernel(RBFKernel())\n",
+    "    ):\n",
+    "        super().__init__(train_x, train_y, likelihood)\n",
+    "        self.mean_module = mean_module\n",
+    "        self.covar_module = covar_module\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        mean_x = self.mean_module(x)\n",
+    "        covar_x = self.covar_module(x)\n",
+    "        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)\n",
+    "\n",
+    "\n",
+    "def plot_1d_regression(\n",
+    "    x_star,\n",
+    "    model,\n",
+    "    ax=None,\n",
+    "    f_true=None,\n",
+    "    num_samples=10,\n",
+    "    xlabel='$x$',\n",
+    "    ylabel='$y$'\n",
+    "):\n",
+    "    \"\"\"Plot the posterior predictive.\n",
+    "    \n",
+    "    Arguments\n",
+    "    x_star -- The test points on which to evaluate.\n",
+    "    model  -- The trained model.\n",
+    "    \n",
+    "    Keyword Arguments\n",
+    "    ax          -- An axes object to write on.\n",
+    "    f_true      -- The true function.\n",
+    "    num_samples -- The number of samples.\n",
+    "    xlabel      -- The x-axis label.\n",
+    "    ylabel      -- The y-axis label.\n",
+    "    \"\"\"\n",
+    "    f_star = model(x_star)\n",
+    "    m_star = f_star.mean\n",
+    "    v_star = f_star.variance\n",
+    "    y_star = model.likelihood(f_star)\n",
+    "    yv_star = y_star.variance\n",
+    "\n",
+    "    f_lower = (\n",
+    "        m_star - 2.0 * torch.sqrt(v_star)\n",
+    "    )\n",
+    "    f_upper = (\n",
+    "        m_star + 2.0 * torch.sqrt(v_star)\n",
+    "    )\n",
+    "    \n",
+    "    y_lower = m_star - 2.0 * torch.sqrt(yv_star)\n",
+    "    y_upper = m_star + 2.0 * torch.sqrt(yv_star)\n",
+    "\n",
+    "    if ax is None:\n",
+    "        fig, ax = plt.subplots()\n",
+    "    \n",
+    "    ax.plot(model.train_inputs[0].flatten().detach(),\n",
+    "            model.train_targets.detach(),\n",
+    "            'k.',\n",
+    "            markersize=1,\n",
+    "            markeredgewidth=2,\n",
+    "            label='Observations'\n",
+    "    )\n",
+    "\n",
+    "    ax.plot(\n",
+    "        x_star,\n",
+    "        m_star.detach(),\n",
+    "        lw=2,\n",
+    "        label='Posterior mean',\n",
+    "        color=sns.color_palette()[0]\n",
+    "    )\n",
+    "    \n",
+    "    ax.fill_between(\n",
+    "        x_star.flatten().detach(),\n",
+    "        f_lower.flatten().detach(),\n",
+    "        f_upper.flatten().detach(),\n",
+    "        alpha=0.5,\n",
+    "        label='Epistemic uncertainty',\n",
+    "        color=sns.color_palette()[0]\n",
+    "    )\n",
+    "\n",
+    "    ax.fill_between(\n",
+    "        x_star.detach().flatten(),\n",
+    "        y_lower.detach().flatten(),\n",
+    "        f_lower.detach().flatten(),\n",
+    "        color=sns.color_palette()[1],\n",
+    "        alpha=0.5,\n",
+    "        label='Aleatory uncertainty'\n",
+    "    )\n",
+    "    ax.fill_between(\n",
+    "        x_star.detach().flatten(),\n",
+    "        f_upper.detach().flatten(),\n",
+    "        y_upper.detach().flatten(),\n",
+    "        color=sns.color_palette()[1],\n",
+    "        alpha=0.5,\n",
+    "        label=None\n",
+    "    )\n",
+    "\n",
+    "    if f_true is not None:\n",
+    "        ax.plot(\n",
+    "            x_star,\n",
+    "            f_true(x_star),\n",
+    "            'm-.',\n",
+    "            label='True function'\n",
+    "        )\n",
+    "    \n",
+    "    if num_samples > 0:\n",
+    "        f_post_samples = f_star.sample(\n",
+    "            sample_shape=torch.Size([num_samples])\n",
+    "        )\n",
+    "        ax.plot(\n",
+    "            x_star.numpy(),\n",
+    "            f_post_samples.T.detach().numpy(),\n",
+    "            color=\"red\",\n",
+    "            lw=0.5\n",
+    "        )\n",
+    "        # This is just to add the legend entry\n",
+    "        ax.plot(\n",
+    "            [],\n",
+    "            [],\n",
+    "            color=\"red\",\n",
+    "            lw=0.5,\n",
+    "            label=\"Posterior samples\"\n",
+    "        )\n",
+    "    \n",
+    "    ax.set_xlabel(xlabel)\n",
+    "    ax.set_ylabel(ylabel)\n",
+    "\n",
+    "    plt.legend(loc='best', frameon=False)\n",
+    "    sns.despine(trim=True)\n",
+    "    \n",
+    "    return dict(m_star=m_star, v_star=v_star, ax=ax)\n",
+    "\n",
+    "\n",
+    "def train(model, train_x, train_y, n_iter=10, lr=0.1):\n",
+    "    \"\"\"Train the model.\n",
+    "\n",
+    "    Arguments\n",
+    "    model   -- The model to train.\n",
+    "    train_x -- The training inputs.\n",
+    "    train_y -- The training labels.\n",
+    "    n_iter  -- The number of iterations.\n",
+    "    lr      -- The learning rate.\n",
+    "    \"\"\"\n",
+    "    model.train()\n",
+    "    optimizer = torch.optim.LBFGS(model.parameters(), lr=lr, line_search_fn='strong_wolfe')\n",
+    "    likelihood = model.likelihood\n",
+    "    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)\n",
+    "    def closure():\n",
+    "        # L-BFGS evaluates the loss (and its gradient) several times per\n",
+    "        # step, so both are computed inside a closure.\n",
+    "        optimizer.zero_grad()\n",
+    "        output = model(train_x)\n",
+    "        loss = -mll(output, train_y)\n",
+    "        loss.backward()\n",
+    "        return loss\n",
+    "    for i in range(n_iter):\n",
+    "        loss = optimizer.step(closure)\n",
+    "        print(f'Iter {i + 1:3d}/{n_iter} - Loss: {loss.item():.3f}')\n",
+    "    model.eval()\n"
+   ]
+  },
s^2\\exp\\left\\{-\\frac{(x-x')^2}{2\\ell^2}\\right\\},\n", + "$$\n", + "\n", + "with variance:\n", + "\n", + "$$\n", + "s^2 = k(x,x) = \\mathbb{V}[f(x)] = 4,\n", + "$$\n", + "\n", + "and lengthscale $\\ell = 0.1$.\n", + "We chose the variance to be 4.0 so that with (about) 95% probability, the values of $f(x)$ are between -4 and 4." + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-10-23T22:32:47.091636\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.8.0, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import torch\n", + "import gpytorch\n", + "from gpytorch.kernels import RBFKernel, ScaleKernel\n", + "\n", + "# Define the covariance function\n", + "k = ScaleKernel(RBFKernel())\n", + "k.outputscale = 4.0\n", + "k.base_kernel.lengthscale = 0.1\n", + "\n", + "# Define the mean function\n", + "mean = gpytorch.means.ConstantMean()\n", + "mean.constant = 0.0\n", + "\n", + "# Sample functions\n", + "sample_functions(mean, k, nugget=1e-4)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Part B - Super smooth function with known ultra-small length scale\n", + "\n", + "Assume that you hold the following beliefs\n", + "+ You know that $f(x)$ has as many derivatives as you want and they are all continuous\n", + "+ You don't know if $f(x)$ has a specific trend.\n", + "+ You think that $f(x)$ has \"wiggles\" that are approximatly of size $\\Delta x=0.05$.\n", + "+ You think that $f(x)$ is between -3 and 3.\n", + "\n", + "**Answer:**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Your code here" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Part C - Continuous function with known length scale\n", + "\n", + "Assume that you hold the following beliefs\n", + "+ You know that $f(x)$ is continuous, nowhere differentiable.\n", + "+ You don't know if $f(x)$ has a specific trend.\n", + "+ You think that $f(x)$ has \"wiggles\" that are approximately of size $\\Delta x=0.1$.\n", + "+ You think that $f(x)$ is between -5 and 5.\n", + "\n", + "Hint: Use ``gpytorch.kernels.MaternKernel`` with $\\nu=1/2$.\n", + "\n", + "**Answer:**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Your code here" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Part D - Smooth periodic function with known length scale\n", + "\n", + "Assume that you hold the following beliefs\n", + "+ You know that $f(x)$ is smooth.\n", + "+ You know that $f(x)$ is periodic with period 0.1.\n", + "+ You don't know if $f(x)$ has a specific trend.\n", + "+ You think that $f(x)$ has \"wiggles\" that are approximately of size $\\Delta x=0.5$ of the period.\n", + "+ You think that $f(x)$ is between -5 and 5.\n", + "\n", + "Hint: Use ``gpytorch.kernels.PeriodicKernel``.\n", + "\n", + "**Answer:**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Your code here" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Part E - Smooth periodic function with known length scale\n", + "\n", + "Assume that you hold the following beliefs\n", + "+ You know that $f(x)$ is smooth.\n", + "+ You know that $f(x)$ is periodic with period 0.1.\n", + "+ You don't know if $f(x)$ has a specific trend.\n", + "+ You think that $f(x)$ has \"wiggles\" that are approximately of size $\\Delta x=0.1$ of the period (**the only thing that is different compared to D**).\n", + "+ You think that $f(x)$ is between -5 and 5.\n", + "\n", + "Hint: Use ``gpytorch.kernels.PeriodicKernel``.\n", + "\n", + "\n", + "**Answer:**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Your code here" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Part F - The sum of two functions\n", + "\n", + "Assume 
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Part F - The sum of two functions\n",
+    "\n",
+    "Assume that you hold the following beliefs\n",
+    "+ You know that $f(x) = f_1(x) + f_2(x)$, where:\n",
+    "    - $f_1(x)$ is smooth with variance 2 and length scale 0.5\n",
+    "    - $f_2(x)$ is continuous, nowhere differentiable, with variance 0.1 and length scale 0.1\n",
+    "\n",
+    "Hint: You must create a new covariance function that is the sum of two other covariance functions."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Your code here"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Part G - The product of two functions\n",
+    "\n",
+    "Assume that you hold the following beliefs\n",
+    "+ You know that $f(x) = f_1(x)f_2(x)$, where:\n",
+    "    - $f_1(x)$ is smooth, periodic (period = 0.1), length scale 0.1 (relative to the period), and variance 2.\n",
+    "    - $f_2(x)$ is smooth with length scale 0.5 and variance 1.\n",
+    "\n",
+    "Hint: You must create a new covariance function that is the product of two other covariance functions."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Your code here"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Problem 2\n",
+    "\n",
+    "The National Oceanic and Atmospheric Administration (NOAA) has been measuring the levels of atmospheric CO2 at Mauna Loa, Hawaii. The measurements start in March 1958 and extend to January 2016.\n",
+    "The data can be found [here](http://www.esrl.noaa.gov/gmd/ccgg/trends/data.html).\n",
+    "The Python cell below downloads and plots the data set."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 53,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "url = \"https://github.com/PredictiveScienceLab/data-analytics-se/raw/master/lecturebook/data/mauna_loa_co2.txt\"\n",
+    "download(url)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 54,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data = np.loadtxt('mauna_loa_co2.txt')"
+   ]
+  },
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "#load data \n", + "t = data[:, 2] #time (in decimal dates)\n", + "y = data[:, 4] #CO2 level (mole fraction in dry air, micromol/mol, abbreviated as ppm)\n", + "fig, ax = plt.subplots(1, 1)\n", + "ax.plot(t, y, '.', markersize=1)\n", + "ax.set_xlabel('$t$ (year)')\n", + "ax.set_ylabel('$y$ (CO2 level in ppm)')\n", + "sns.despine(trim=True);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Overall, we observe a steady growth of CO2 levels. The wiggles correspond to seasonal changes. Since most of the population inhabits the northern hemisphere, fuel consumption increases during the northern winters, and CO2 emissions follow. Our goal is to study this dataset with Gaussian process regression. Specifically, we would like to predict the evolution of the CO2 levels from Feb 2018 to Feb 2028 and quantify our uncertainty about this prediction.\n", + "\n", + "Working with a scaled version of the inputs and outputs is always a good idea. We are going to scale the times as follows:\n", + "\n", + "$$\n", + "t_s = t - t_{\\min}.\n", + "$$\n", + "\n", + "So, time is still in fractional years, but we start counting at zero instead of 1950.\n", + "We scale the $y$'s as:\n", + "\n", + "$$\n", + "y_s = \\frac{y - y_{\\min}}{y_{\\max}-y_{\\min}}.\n", + "$$\n", + "\n", + "This takes all the $y$ between $0$ and $1$.\n", + "Here is what the scaled data look like:" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-10-23T22:52:35.548519\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.8.0, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "t_s = t - t.min()\n", + "y_s = (y - y.min()) / (y.max() - y.min())\n", + "fig, ax = plt.subplots(1, 1)\n", + "ax.plot(t_s, y_s, '.', markersize=1)\n", + "ax.set_xlabel('$t_s$ (Scaled year)')\n", + "ax.set_ylabel('$y_s$ (Scaled CO2 level)')\n", + "sns.despine(trim=True);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Work with the scaled data in what follows as you develop your model.\n", + "Scale back to the original units for your final predictions." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Part A - Naive approach\n", + "\n", + "Use a zero mean Gaussian process with a squared exponential covariance function to fit the data and make the required prediction (ten years after the last observation).\n", + "\n", + "**Answer:**\n", + "\n", + "**Again, this is done for you so that you have a concrete example of what is requested.**" + ] + }, + { + "cell_type": "code", + "execution_count": 86, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(0.8545, grad_fn=)\n", + "tensor(0.7392, grad_fn=)\n", + "tensor(-0.5164, grad_fn=)\n", + "tensor(-1.7390, grad_fn=)\n", + "tensor(-2.1109, grad_fn=)\n", + "tensor(-2.2523, grad_fn=)\n", + "tensor(-2.0013, grad_fn=)\n", + "tensor(-2.2894, grad_fn=)\n", + "tensor(-2.3039, grad_fn=)\n", + "tensor(-2.3159, grad_fn=)\n", + "tensor(-2.3302, grad_fn=)\n", + "tensor(-2.3335, grad_fn=)\n", + "tensor(-2.2837, grad_fn=)\n", + "tensor(-2.3380, grad_fn=)\n", + "tensor(-2.3401, grad_fn=)\n", + "tensor(-2.3443, grad_fn=)\n", + "tensor(-2.3464, grad_fn=)\n", + "tensor(-2.3477, grad_fn=)\n", + "tensor(-2.3481, grad_fn=)\n", + "tensor(-2.3505, grad_fn=)\n", + "tensor(-2.3518, grad_fn=)\n", + "tensor(-2.3526, grad_fn=)\n", + "tensor(-2.3527, grad_fn=)\n", + "tensor(-2.3529, grad_fn=)\n", + "tensor(-2.3531, grad_fn=)\n", + "Iter 1/10 - Loss: 0.854\n", + "tensor(-2.3531, grad_fn=)\n", + "tensor(-2.3537, grad_fn=)\n", + "tensor(-2.3538, grad_fn=)\n", + "tensor(-2.3541, grad_fn=)\n", + "tensor(-2.3541, grad_fn=)\n", + "tensor(-2.3543, grad_fn=)\n", + "tensor(-2.3541, grad_fn=)\n", + "tensor(-2.3542, grad_fn=)\n", + "tensor(-2.3540, grad_fn=)\n", + "tensor(-2.3539, grad_fn=)\n", + "tensor(-2.3543, grad_fn=)\n", + "tensor(-2.3543, grad_fn=)\n", + "Iter 2/10 - Loss: -2.353\n", + "tensor(-2.3543, grad_fn=)\n", + "tensor(-2.3541, grad_fn=)\n", + "tensor(-2.3542, grad_fn=)\n", + "tensor(-2.3540, grad_fn=)\n", + "tensor(-2.3539, grad_fn=)\n", + "tensor(-2.3543, grad_fn=)\n", + "tensor(-2.3543, grad_fn=)\n", + "Iter 3/10 - Loss: -2.354\n", + "tensor(-2.3543, grad_fn=)\n", + "tensor(-2.3541, grad_fn=)\n", + "tensor(-2.3542, grad_fn=)\n", + "tensor(-2.3540, grad_fn=)\n", + "tensor(-2.3539, grad_fn=)\n", + "tensor(-2.3543, grad_fn=)\n", + "tensor(-2.3543, grad_fn=)\n", + "Iter 4/10 - Loss: -2.354\n", + "tensor(-2.3543, grad_fn=)\n", + "tensor(-2.3541, grad_fn=)\n", + "tensor(-2.3542, grad_fn=)\n", + "tensor(-2.3540, grad_fn=)\n", + "tensor(-2.3539, grad_fn=)\n", + "tensor(-2.3543, grad_fn=)\n", + "tensor(-2.3543, grad_fn=)\n", + "Iter 5/10 - Loss: -2.354\n", + "tensor(-2.3543, grad_fn=)\n", + "tensor(-2.3541, grad_fn=)\n", + "tensor(-2.3542, grad_fn=)\n", + "tensor(-2.3540, grad_fn=)\n", + "tensor(-2.3539, grad_fn=)\n", + "tensor(-2.3543, grad_fn=)\n", + "tensor(-2.3543, grad_fn=)\n", + "Iter 6/10 - Loss: -2.354\n", + "tensor(-2.3543, grad_fn=)\n", + "tensor(-2.3541, grad_fn=)\n", + 
"tensor(-2.3542, grad_fn=)\n", + "tensor(-2.3540, grad_fn=)\n", + "tensor(-2.3539, grad_fn=)\n", + "tensor(-2.3543, grad_fn=)\n", + "tensor(-2.3543, grad_fn=)\n", + "Iter 7/10 - Loss: -2.354\n", + "tensor(-2.3543, grad_fn=)\n", + "tensor(-2.3541, grad_fn=)\n", + "tensor(-2.3542, grad_fn=)\n", + "tensor(-2.3540, grad_fn=)\n", + "tensor(-2.3539, grad_fn=)\n", + "tensor(-2.3543, grad_fn=)\n", + "tensor(-2.3543, grad_fn=)\n", + "Iter 8/10 - Loss: -2.354\n", + "tensor(-2.3543, grad_fn=)\n", + "tensor(-2.3541, grad_fn=)\n", + "tensor(-2.3542, grad_fn=)\n", + "tensor(-2.3540, grad_fn=)\n", + "tensor(-2.3539, grad_fn=)\n", + "tensor(-2.3543, grad_fn=)\n", + "tensor(-2.3543, grad_fn=)\n", + "Iter 9/10 - Loss: -2.354\n", + "tensor(-2.3543, grad_fn=)\n", + "tensor(-2.3541, grad_fn=)\n", + "tensor(-2.3542, grad_fn=)\n", + "tensor(-2.3540, grad_fn=)\n", + "tensor(-2.3539, grad_fn=)\n", + "tensor(-2.3543, grad_fn=)\n", + "tensor(-2.3543, grad_fn=)\n", + "Iter 10/10 - Loss: -2.354\n" + ] + } + ], + "source": [ + "cov_module = ScaleKernel(RBFKernel())\n", + "mean_module = gpytorch.means.ConstantMean()\n", + "train_x = torch.from_numpy(t_s).float()\n", + "train_y = torch.from_numpy(y_s).float()\n", + "naive_model = ExactGP(\n", + " train_x,\n", + " train_y,\n", + " mean_module=mean_module,\n", + " covar_module=cov_module\n", + ")\n", + "train(naive_model, train_x, train_y)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Predict everything:" + ] + }, + { + "cell_type": "code", + "execution_count": 95, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-10-23T23:09:18.265247\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.8.0, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "x_star = torch.linspace(0, 100, 100)\n", + "plot_1d_regression(model=naive_model, x_star=x_star, \n", + " xlabel='$t_s$ (Scaled year)', ylabel='$y_s$ (Scaled CO2 level)');" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice that the squared exponential covariance captures the long terms but fails to capture the seasonal fluctuations. The seasonal fluctuations are treated as noise. This is wrong. You will have to fix this in the next part." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Part B - Improving the prior covariance\n", + "\n", + "Now, use the ideas of Problem 1 to develop a covariance function that exhibits the following characteristics visible in the data (call $f(x)$ the scaled CO2 level.\n", + "+ $f(x)$ is smooth.\n", + "+ $f(x)$ has a clear trend with a multi-year length scale.\n", + "+ $f(x)$ has seasonal fluctuations with a period of one year.\n", + "+ $f(x)$ exhibits small fluctuations within its period.\n", + "\n", + "There is more than one correct answer.\n", + "\n", + "**Answer:**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cov_module = # Your choice of covariance here\n", + "mean_module = # Your choice of mean here\n", + "model = ExactGP(\n", + " train_x,\n", + " train_y,\n", + " mean_module=mean_module,\n", + " covar_module=cov_module\n", + ")\n", + "train(model, train_x, train_y)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Plot using the following block:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_1d_regression(model=naive_model, x_star=train_x);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Part C - Predicting the future\n", + "\n", + "How does your model predict the future? Why is it better than the naive model?\n", + "\n", + "**Answer:**\n", + "*Your answer here*\n", + "

" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Part D - Bayesian information criterion\n", + "\n", + "As we have seen in earlier lectures, the Bayesian informationc criterion (BIC), see [this](https://en.wikipedia.org/wiki/Bayesian_information_criterion), can bse used to compare two models.\n", + "The criterion says that one should:\n", + "+ fit the models with maximum likelihood,\n", + "+ and compute the quantity:\n", + "\n", + "$$\n", + "\\text{BIC} = d\\ln(n) - 2\\ln(\\hat{L}),\n", + "$$\n", + "\n", + "where $d$ is the number of model parameters, and $\\hat{L}$ the maximum likelihood.\n", + "+ pick the model with the smallest BIC.\n", + "\n", + "Use BIC to show that the model you constructed in Part C is indeed better than the naïve model of Part A.\n", + "\n", + "**Answer:**" + ] + }, + { + "cell_type": "code", + "execution_count": 112, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Parameter containing:\n", + " tensor([-7.8281], requires_grad=True),\n", + " Parameter containing:\n", + " tensor(0.8690, requires_grad=True),\n", + " Parameter containing:\n", + " tensor(-1.2034, requires_grad=True),\n", + " Parameter containing:\n", + " tensor([[32.5616]], requires_grad=True)]" + ] + }, + "execution_count": 112, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Hint: You can find the parameters of a model like this\n", + "list(naive_model.hyperparameters())" + ] + }, + { + "cell_type": "code", + "execution_count": 111, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4\n" + ] + } + ], + "source": [ + "m = sum(p.numel() for p in naive_model.hyperparameters())\n", + "print(m)" + ] + }, + { + "cell_type": "code", + "execution_count": 116, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(2.3863, grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/ibilion/.pyenv/versions/3.11.6/lib/python3.11/site-packages/gpytorch/models/exact_gp.py:284: GPInputWarning: The input matches the stored training data. 
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Your code here"
+   ]
+  }
+ ],
+ "metadata": {
+  "celltoolbar": "Tags",
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}

diff --git a/lecturebook/homework/homework-06.ipynb b/lecturebook/homework/homework-06.ipynb
index c55d62fb..164bf707 100644
--- a/lecturebook/homework/homework-06.ipynb
+++ b/lecturebook/homework/homework-06.ipynb
@@ -89,7 +89,7 @@
     "    \"\"\"\n",
     "    X = torch.linspace(0, 1, num_test)[:, None]\n",
     "    m = mean_func(X)\n",
-    "    C = k.forward(X, X) + nugget * torch.eye(X.shape[0])\n",
+    "    C = kernel_func.forward(X, X) + nugget * torch.eye(X.shape[0])\n",
     "    L = torch.linalg.cholesky(C)\n",
     "    fig, ax = plt.subplots()\n",
     "    ax.plot(X, m.detach(), label='mean')\n",
@@ -9876,7 +9876,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.6"
+   "version": "3.11.5"
  }
 },
 "nbformat": 4,