Fix bug with PeriodicKernel.diag() #1919

Merged: 4 commits, Apr 10, 2022
97 changes: 57 additions & 40 deletions gpytorch/kernels/periodic_kernel.py
@@ -17,9 +17,9 @@ class PeriodicKernel(Kernel):
.. math::

\begin{equation*}
-k_{\text{Periodic}}(\mathbf{x_1}, \mathbf{x_2}) = \exp \left(
+k_{\text{Periodic}}(\mathbf{x}, \mathbf{x'}) = \exp \left(
 -2 \sum_i
-\frac{\sin ^2 \left( \frac{\pi}{p} (\mathbf{x_{1,i}} - \mathbf{x_{2,i}} ) \right)}{\lambda}
+\frac{\sin ^2 \left( \frac{\pi}{p} ({x_{i}} - {x_{i}'} ) \right)}{\lambda}
\right)
\end{equation*}

@@ -28,44 +28,44 @@ class PeriodicKernel(Kernel):
* :math:`p` is the period length parameter.
* :math:`\lambda` is a lengthscale parameter.

-Equation is based on [David Mackay's Introduction to Gaussian Processes equation 47]
-(http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.81.1927&rep=rep1&type=pdf)
-albeit without feature-specific lengthscales and period lengths. The exponential
+Equation is based on `David Mackay's Introduction to Gaussian Processes equation 47`_
+(albeit without feature-specific lengthscales and period lengths). The exponential
coefficient was changed and lengthscale is not squared to maintain backwards compatibility

.. note::

This kernel does not have an `outputscale` parameter. To add a scaling parameter,
decorate this kernel with a :class:`gpytorch.kernels.ScaleKernel`.

-.. note::
-
-This kernel does not have an ARD lengthscale or period length option.
-
-Args:
-:attr:`batch_shape` (torch.Size, optional):
-Set this if you want a separate lengthscale for each
-batch of input data. It should be `b` if :attr:`x1` is a `b x n x d` tensor. Default: `torch.Size([])`.
-:attr:`active_dims` (tuple of ints, optional):
-Set this if you want to compute the covariance of only a few input dimensions. The ints
-corresponds to the indices of the dimensions. Default: `None`.
-:attr:`period_length_prior` (Prior, optional):
-Set this if you want to apply a prior to the period length parameter. Default: `None`.
-:attr:`lengthscale_prior` (Prior, optional):
-Set this if you want to apply a prior to the lengthscale parameter. Default: `None`.
-:attr:`lengthscale_constraint` (Constraint, optional):
-Set this if you want to apply a constraint to the value of the lengthscale. Default: `Positive`.
-:attr:`period_length_constraint` (Constraint, optional):
-Set this if you want to apply a constraint to the value of the period length. Default: `Positive`.
-:attr:`eps` (float):
-The minimum value that the lengthscale/period length can take
-(prevents divide by zero errors). Default: `1e-6`.
-
-Attributes:
-:attr:`lengthscale` (Tensor):
-The lengthscale parameter. Size = `*batch_shape x 1 x 1`.
-:attr:`period_length` (Tensor):
-The period length parameter. Size = `*batch_shape x 1 x 1`.
+:param ard_num_dims: (Default: `None`) Set this if you want a separate lengthscale for each
+input dimension. It should be `d` if :attr:`x1` is a `... x n x d` matrix.
+:type ard_num_dims: int, optional
+:param batch_shape: (Default: `None`) Set this if you want a separate lengthscale for each
+batch of input data. It should be `torch.Size([b1, b2])` for a `b1 x b2 x n x m` kernel output.
+:type batch_shape: torch.Size, optional
+:param active_dims: (Default: `None`) Set this if you want to
+compute the covariance of only a few input dimensions. The ints
+corresponds to the indices of the dimensions.
+:type active_dims: Tuple(int)
+:param period_length_prior: (Default: `None`)
+Set this if you want to apply a prior to the period length parameter.
+:type period_length_prior: ~gpytorch.priors.Prior, optional
+:param period_length_constraint: (Default: `Positive`) Set this if you want
+to apply a constraint to the period length parameter.
+:type period_length_constraint: ~gpytorch.constraints.Interval, optional
+:param lengthscale_prior: (Default: `None`)
+Set this if you want to apply a prior to the lengthscale parameter.
+:type lengthscale_prior: ~gpytorch.priors.Prior, optional
+:param lengthscale_constraint: (Default: `Positive`) Set this if you want
+to apply a constraint to the lengthscale parameter.
+:type lengthscale_constraint: ~gpytorch.constraints.Interval, optional
+:param eps: (Default: 1e-6) The minimum value that the lengthscale can take (prevents divide by zero errors).
+:type eps: float, optional
+
+:var torch.Tensor lengthscale: The lengthscale parameter. Size/shape of parameter depends on the
+:attr:`ard_num_dims` and :attr:`batch_shape` arguments.
+:var torch.Tensor period_length: The period length parameter. Size/shape of parameter depends on the
+:attr:`ard_num_dims` and :attr:`batch_shape` arguments.

Example:
>>> x = torch.randn(10, 5)
@@ -78,6 +78,9 @@ class PeriodicKernel(Kernel):
>>> # Batch: different lengthscale for each batch
>>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.PeriodicKernel(batch_size=2))
>>> covar = covar_module(x) # Output: LazyVariable of size (2 x 10 x 10)

+.. _David Mackay's Introduction to Gaussian Processes equation 47:
+http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.81.1927&rep=rep1&type=pdf
"""

has_lengthscale = True
@@ -92,8 +95,9 @@ def __init__(
if period_length_constraint is None:
period_length_constraint = Positive()

+        ard_num_dims = kwargs.get("ard_num_dims", 1)
         self.register_parameter(
-            name="raw_period_length", parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1, 1))
+            name="raw_period_length", parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1, ard_num_dims))
         )

if period_length_prior is not None:
@@ -122,10 +126,23 @@ def _set_period_length(self, value):
self.initialize(raw_period_length=self.raw_period_length_constraint.inverse_transform(value))

     def forward(self, x1, x2, diag=False, **params):
-        x1_ = x1.div(self.period_length).mul(math.pi)
-        x2_ = x2.div(self.period_length).mul(math.pi)
-        diff = x1_.unsqueeze(-2) - x2_.unsqueeze(-3)
-        res = diff.sin().pow(2).sum(dim=-1).div(self.lengthscale).mul(-2.0).exp_()
+        # Pop this argument so that we can manually sum over dimensions
+        last_dim_is_batch = params.pop("last_dim_is_batch", False)
+        # Get lengthscale
+        lengthscale = self.lengthscale
+
+        x1_ = x1.div(self.period_length / math.pi)
+        x2_ = x2.div(self.period_length / math.pi)
+        # We are automatically overriding last_dim_is_batch here so that we can manually sum over dimensions.
+        diff = self.covar_dist(x1_, x2_, diag=diag, last_dim_is_batch=True, **params)
+
         if diag:
-            res = res.squeeze(0)
-        return res
+            lengthscale = lengthscale[..., 0, :, None]
+        else:
+            lengthscale = lengthscale[..., 0, :, None, None]
+        exp_term = diff.sin().pow(2.0).div(lengthscale).mul(-2.0)
+
+        if not last_dim_is_batch:
+            exp_term = exp_term.sum(dim=(-2 if diag else -3))
+
+        return exp_term.exp()
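
For readers following the change to `forward`, here is a minimal pure-Python sketch of the quantity the docstring formula describes, extended with one period and one lengthscale per input dimension (the ARD case this PR enables). It is not the GPyTorch implementation: the helper names `periodic_kernel`, `full_matrix` and `diag_only`, and the sample values, are assumptions made for illustration. It shows the property the PR title is about, namely that the `diag=True` path should return exactly the diagonal of the full covariance matrix.

```python
import math


def periodic_kernel(x, x_prime, period, lengthscale):
    """k(x, x') = exp(-2 * sum_i sin^2(pi * (x_i - x'_i) / p_i) / lambda_i).

    Hypothetical reference helper; `period` and `lengthscale` hold one
    value per input dimension (ARD).
    """
    total = 0.0
    for xi, xpi, p, lam in zip(x, x_prime, period, lengthscale):
        total += math.sin(math.pi * (xi - xpi) / p) ** 2 / lam
    return math.exp(-2.0 * total)


def full_matrix(xs1, xs2, period, lengthscale):
    # n x m covariance matrix: every point in xs1 against every point in xs2.
    return [[periodic_kernel(a, b, period, lengthscale) for b in xs2] for a in xs1]


def diag_only(xs1, xs2, period, lengthscale):
    # Length-n vector: only matching pairs (xs1[i], xs2[i]) are evaluated.
    return [periodic_kernel(a, b, period, lengthscale) for a, b in zip(xs1, xs2)]


# Illustrative inputs (assumed values): three points in two dimensions.
xs = [[0.0, 1.0], [0.5, -2.0], [3.0, 0.25]]
period = [1.5, 4.0]
lengthscale = [0.7, 2.0]

K = full_matrix(xs, xs, period, lengthscale)
d = diag_only(xs, xs, period, lengthscale)
```

With identical inputs on both sides, every entry of `d` is 1 because each difference is zero, and `d[i]` agrees with `K[i][i]`.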
9 changes: 8 additions & 1 deletion test/kernels/test_periodic_kernel.py
@@ -7,9 +7,16 @@

from gpytorch.kernels import PeriodicKernel
from gpytorch.priors import NormalPrior
+from gpytorch.test.base_kernel_test_case import BaseKernelTestCase


-class TestPeriodicKernel(unittest.TestCase):
+class TestPeriodicKernel(unittest.TestCase, BaseKernelTestCase):
+    def create_kernel_no_ard(self, **kwargs):
+        return PeriodicKernel(**kwargs)
+
+    def create_kernel_ard(self, num_dims, **kwargs):
+        return PeriodicKernel(ard_num_dims=num_dims, **kwargs)
+
     def test_computes_periodic_function(self):
         a = torch.tensor([4, 2, 8], dtype=torch.float).view(3, 1)
         b = torch.tensor([0, 2], dtype=torch.float).view(2, 1)
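
The body of `test_computes_periodic_function` is cut off above, so the following is only a hedged, standalone sketch of how expected values for the inputs `a` and `b` could be computed by hand. The `period` and `lengthscale` values and the helper `periodic_1d` are assumptions; they are not taken from the test file.

```python
import math


def periodic_1d(a, b, period, lengthscale):
    # One-dimensional case of the docstring formula (hypothetical helper).
    return math.exp(-2.0 * math.sin(math.pi * (a - b) / period) ** 2 / lengthscale)


a = [4.0, 2.0, 8.0]  # same values as `a` in the test
b = [0.0, 2.0]       # same values as `b` in the test
period = 5.0         # assumed value, not from the source
lengthscale = 2.0    # assumed value, not from the source

# 3 x 2 matrix of expected covariances.
expected = [[periodic_1d(ai, bj, period, lengthscale) for bj in b] for ai in a]

# Shifting an input by one full period should leave the kernel unchanged.
shifted = [[periodic_1d(ai + period, bj, period, lengthscale) for bj in b] for ai in a]
```

The entry for `a = 2`, `b = 2` equals 1 because the difference is zero, and `shifted` matches `expected` up to floating-point error, which is the periodicity the kernel is named for.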