diff --git a/gpytorch/kernels/periodic_kernel.py b/gpytorch/kernels/periodic_kernel.py index 207ed88b7..54131c2b1 100644 --- a/gpytorch/kernels/periodic_kernel.py +++ b/gpytorch/kernels/periodic_kernel.py @@ -17,9 +17,9 @@ class PeriodicKernel(Kernel): .. math:: \begin{equation*} - k_{\text{Periodic}}(\mathbf{x_1}, \mathbf{x_2}) = \exp \left( + k_{\text{Periodic}}(\mathbf{x}, \mathbf{x'}) = \exp \left( -2 \sum_i - \frac{\sin ^2 \left( \frac{\pi}{p} (\mathbf{x_{1,i}} - \mathbf{x_{2,i}} ) \right)}{\lambda} + \frac{\sin ^2 \left( \frac{\pi}{p} ({x_{i}} - {x_{i}'} ) \right)}{\lambda} \right) \end{equation*} @@ -28,9 +28,8 @@ class PeriodicKernel(Kernel): * :math:`p` is the period length parameter. * :math:`\lambda` is a lengthscale parameter. - Equation is based on [David Mackay's Introduction to Gaussian Processes equation 47] - (http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.81.1927&rep=rep1&type=pdf) - albeit without feature-specific lengthscales and period lengths. The exponential + Equation is based on `David Mackay's Introduction to Gaussian Processes equation 47`_ + (albeit without feature-specific lengthscales and period lengths). The exponential coefficient was changed and lengthscale is not squared to maintain backwards compatibility .. note:: @@ -42,30 +41,30 @@ class PeriodicKernel(Kernel): This kernel does not have an ARD lengthscale or period length option. - Args: - :attr:`batch_shape` (torch.Size, optional): - Set this if you want a separate lengthscale for each - batch of input data. It should be `b` if :attr:`x1` is a `b x n x d` tensor. Default: `torch.Size([])`. - :attr:`active_dims` (tuple of ints, optional): - Set this if you want to compute the covariance of only a few input dimensions. The ints - corresponds to the indices of the dimensions. Default: `None`. - :attr:`period_length_prior` (Prior, optional): - Set this if you want to apply a prior to the period length parameter. Default: `None`. - :attr:`lengthscale_prior` (Prior, optional): - Set this if you want to apply a prior to the lengthscale parameter. Default: `None`. - :attr:`lengthscale_constraint` (Constraint, optional): - Set this if you want to apply a constraint to the value of the lengthscale. Default: `Positive`. - :attr:`period_length_constraint` (Constraint, optional): - Set this if you want to apply a constraint to the value of the period length. Default: `Positive`. - :attr:`eps` (float): - The minimum value that the lengthscale/period length can take - (prevents divide by zero errors). Default: `1e-6`. - - Attributes: - :attr:`lengthscale` (Tensor): - The lengthscale parameter. Size = `*batch_shape x 1 x 1`. - :attr:`period_length` (Tensor): - The period length parameter. Size = `*batch_shape x 1 x 1`. + :param batch_shape: (Default: `None`) Set this if you want a separate lengthscale for each + batch of input data. It should be `torch.Size([b1, b2])` for a `b1 x b2 x n x m` kernel output. + :type batch_shape: torch.Size, optional + :param active_dims: (Default: `None`) Set this if you want to + compute the covariance of only a few input dimensions. The ints + corresponds to the indices of the dimensions. + :type active_dims: Tuple(int) + :param period_length_prior: (Default: `None`) + Set this if you want to apply a prior to the period length parameter. + :type period_length_prior: ~gpytorch.priors.Prior, optional + :param period_length_constraint: (Default: `Positive`) Set this if you want + to apply a constraint to the period length parameter. + :type period_length_constraint: ~gpytorch.constraints.Interval, optional + :param lengthscale_prior: (Default: `None`) + Set this if you want to apply a prior to the lengthscale parameter. + :type lengthscale_prior: ~gpytorch.priors.Prior, optional + :param lengthscale_constraint: (Default: `Positive`) Set this if you want + to apply a constraint to the lengthscale parameter. + :type lengthscale_constraint: ~gpytorch.constraints.Interval, optional + :param eps: (Default: 1e-6) The minimum value that the lengthscale can take (prevents divide by zero errors). + :type eps: float, optional + + :var torch.Tensor lengthscale: The lengthscale parameter. Size = `*batch_shape x 1 x 1`. + :var torch.Tensor period_length: The period length parameter. Size = `*batch_shape x 1 x 1`. Example: >>> x = torch.randn(10, 5) @@ -78,6 +77,9 @@ class PeriodicKernel(Kernel): >>> # Batch: different lengthscale for each batch >>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.PeriodicKernel(batch_size=2)) >>> covar = covar_module(x) # Output: LazyVariable of size (2 x 10 x 10) + + .. _David Mackay's Introduction to Gaussian Processes equation 47: + http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.81.1927&rep=rep1&type=pdf """ has_lengthscale = True @@ -122,10 +124,22 @@ def _set_period_length(self, value): self.initialize(raw_period_length=self.raw_period_length_constraint.inverse_transform(value)) def forward(self, x1, x2, diag=False, **params): - x1_ = x1.div(self.period_length).mul(math.pi) - x2_ = x2.div(self.period_length).mul(math.pi) - diff = x1_.unsqueeze(-2) - x2_.unsqueeze(-3) - res = diff.sin().pow(2).sum(dim=-1).div(self.lengthscale).mul(-2.0).exp_() + # Pop this argument so that we can manually sum over dimensions + last_dim_is_batch = params.pop("last_dim_is_batch", False) + # Get lengthscale + lengthscale = self.lengthscale + + x1_ = x1.div(self.period_length / math.pi) + x2_ = x2.div(self.period_length / math.pi) + diff = self.covar_dist(x1_, x2_, diag=diag, last_dim_is_batch=True, **params) + + sin_sq = diff.sin().pow(2.0) + if last_dim_is_batch: + lengthscale = lengthscale[..., None, :, :] + else: + sin_sq = sin_sq.sum(dim=-3) if diag: - res = res.squeeze(0) + lengthscale = lengthscale.squeeze(-1) + + res = sin_sq.div(lengthscale).mul(-2.0).exp() return res diff --git a/test/kernels/test_periodic_kernel.py b/test/kernels/test_periodic_kernel.py index 71f1a449a..2155a819f 100644 --- a/test/kernels/test_periodic_kernel.py +++ b/test/kernels/test_periodic_kernel.py @@ -7,9 +7,16 @@ from gpytorch.kernels import PeriodicKernel from gpytorch.priors import NormalPrior +from gpytorch.test.base_kernel_test_case import BaseKernelTestCase -class TestPeriodicKernel(unittest.TestCase): +class TestPeriodicKernel(unittest.TestCase, BaseKernelTestCase): + def create_kernel_no_ard(self, **kwargs): + return PeriodicKernel(**kwargs) + + def create_kernel_ard(self, num_dims, **kwargs): + return PeriodicKernel(**kwargs) + def test_computes_periodic_function(self): a = torch.tensor([4, 2, 8], dtype=torch.float).view(3, 1) b = torch.tensor([0, 2], dtype=torch.float).view(2, 1)