Skip to content

Commit

Permalink
Ensure LinaerKernel stores ard_num_dims property. (#2635)
Browse files Browse the repository at this point in the history
[Fixes #2633]
  • Loading branch information
gpleiss authored Feb 7, 2025
1 parent da70269 commit b017b9c
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 1 deletion.
2 changes: 1 addition & 1 deletion gpytorch/kernels/linear_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(
variance_constraint: Optional[Interval] = None,
**kwargs,
):
super().__init__(**kwargs)
super().__init__(ard_num_dims=ard_num_dims, **kwargs)
if variance_constraint is None:
variance_constraint = Positive()
self.register_parameter(
Expand Down
1 change: 1 addition & 0 deletions test/kernels/test_linear_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ class TestLinearKernelARD(TestLinearKernel):
def test_kernel_ard(self) -> None:
self.kernel_kwargs = {"ard_num_dims": 2}
kernel = self.create_kernel_no_ard()
self.assertEqual(kernel.ard_num_dims, 2)
self.assertEqual(kernel.variance.shape, torch.Size([1, 2]))

def test_computes_linear_function_rectangular(self):
Expand Down

0 comments on commit b017b9c

Please # to comment.