From 08169019fbd2a0583b4277172fa1ce3f404cdc5b Mon Sep 17 00:00:00 2001 From: Stefano <46034160+stefanocortinovis@users.noreply.github.com> Date: Fri, 1 Nov 2024 15:22:31 +0000 Subject: [PATCH] Fix `Static` check in `Gaussian` likelihood (#484) * Fix passing Static obs_stddev to Gaussian likelihood * Fix intro_to_kernels.py example --- examples/intro_to_kernels.py | 2 +- gpjax/likelihoods.py | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/examples/intro_to_kernels.py b/examples/intro_to_kernels.py index b4340ffac..7fabec09e 100644 --- a/examples/intro_to_kernels.py +++ b/examples/intro_to_kernels.py @@ -246,7 +246,7 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]: # noqa: F821 prior = gpx.gps.Prior(mean_function=mean, kernel=kernel) likelihood = gpx.likelihoods.Gaussian( - num_datapoints=D.n, obs_stddev=PositiveReal(value=jnp.array(1e-3), tag="Static") + num_datapoints=D.n, obs_stdev=Static(jnp.array(1e-3)) ) # Our function is noise-free, so we set the observation noise's standard deviation to a very small value no_opt_posterior = prior * likelihood diff --git a/gpjax/likelihoods.py b/gpjax/likelihoods.py index 8336b7c32..ba831c81b 100644 --- a/gpjax/likelihoods.py +++ b/gpjax/likelihoods.py @@ -28,7 +28,6 @@ GHQuadratureIntegrator, ) from gpjax.parameters import ( - Parameter, PositiveReal, Static, ) @@ -152,10 +151,9 @@ def __init__( likelihoods. Must be an instance of `AbstractIntegrator`. For the Gaussian likelihood, this defaults to the `AnalyticalGaussianIntegrator`, as the expected log likelihood can be computed analytically. """ - if isinstance(obs_stddev, Parameter): - self.obs_stddev = obs_stddev - else: - self.obs_stddev = PositiveReal(jnp.asarray(obs_stddev)) + if not isinstance(obs_stddev, (PositiveReal, Static)): + obs_stddev = PositiveReal(jnp.asarray(obs_stddev)) + self.obs_stddev = obs_stddev super().__init__(num_datapoints, integrator)