Skip to content

Commit

Permalink
Fix Static check in Gaussian likelihood (#484)
Browse files Browse the repository at this point in the history
* Fix passing Static obs_stddev to Gaussian likelihood

* Fix intro_to_kernels.py example
  • Loading branch information
stefanocortinovis authored Nov 1, 2024
1 parent 823c931 commit 0816901
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 6 deletions.
2 changes: 1 addition & 1 deletion examples/intro_to_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]: # noqa: F821
prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)

likelihood = gpx.likelihoods.Gaussian(
num_datapoints=D.n, obs_stddev=PositiveReal(value=jnp.array(1e-3), tag="Static")
num_datapoints=D.n, obs_stdev=Static(jnp.array(1e-3))
) # Our function is noise-free, so we set the observation noise's standard deviation to a very small value

no_opt_posterior = prior * likelihood
Expand Down
8 changes: 3 additions & 5 deletions gpjax/likelihoods.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
GHQuadratureIntegrator,
)
from gpjax.parameters import (
Parameter,
PositiveReal,
Static,
)
Expand Down Expand Up @@ -152,10 +151,9 @@ def __init__(
likelihoods. Must be an instance of `AbstractIntegrator`. For the Gaussian likelihood, this defaults to
the `AnalyticalGaussianIntegrator`, as the expected log likelihood can be computed analytically.
"""
if isinstance(obs_stddev, Parameter):
self.obs_stddev = obs_stddev
else:
self.obs_stddev = PositiveReal(jnp.asarray(obs_stddev))
if not isinstance(obs_stddev, (PositiveReal, Static)):
obs_stddev = PositiveReal(jnp.asarray(obs_stddev))
self.obs_stddev = obs_stddev

super().__init__(num_datapoints, integrator)

Expand Down

0 comments on commit 0816901

Please # to comment.