[Bug] get_fantasy_likelihood method broken for DirichletClassificationLikelihood #2579

Open
SaiAakash opened this issue Sep 7, 2024 · 0 comments
@SaiAakash (Contributor)

🐛 Bug

Conditioning a multi-class classification model that uses `DirichletClassificationLikelihood` on new observations via `ExactGP.get_fantasy_model` raises a `RuntimeError`.

To reproduce

**Code snippet to reproduce**

```python
import torch
import numpy as np
import gpytorch

from gpytorch.models import ExactGP
from gpytorch.likelihoods import DirichletClassificationLikelihood
from gpytorch.means import ConstantMean
from gpytorch.kernels import ScaleKernel, RBFKernel


def gen_data(num_data, seed=2019):
    torch.random.manual_seed(seed)

    x = torch.randn(num_data, 1)
    y = torch.randn(num_data, 1)

    u = torch.rand(1)
    data_fn = lambda x, y: 1 * torch.sin(0.15 * u * 3.1415 * (x + y)) + 1
    latent_fn = data_fn(x, y)
    z = torch.round(latent_fn).long().squeeze()
    return torch.cat((x, y), dim=1), z, data_fn


train_x, train_y, genfn = gen_data(500)


# We will use the simplest form of GP model, exact inference
class DirichletGPModel(ExactGP):
    def __init__(self, train_x, train_y, likelihood, num_classes):
        super(DirichletGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = ConstantMean(batch_shape=torch.Size((num_classes,)))
        self.covar_module = ScaleKernel(
            RBFKernel(batch_shape=torch.Size((num_classes,))),
            batch_shape=torch.Size((num_classes,)),
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


# initialize likelihood and model
# we let the DirichletClassificationLikelihood compute the targets for us
likelihood = DirichletClassificationLikelihood(train_y, learn_additional_noise=True)
model = DirichletGPModel(
    train_x,
    likelihood.transformed_targets,
    likelihood,
    num_classes=likelihood.num_classes,
)

# Training loop
# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(
    model.parameters(), lr=0.1
)  # Includes GaussianLikelihood parameters

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

for i in range(50):
    # Zero gradients from previous iteration
    optimizer.zero_grad()
    # Output from model
    output = model(train_x)
    # Calc loss and backprop gradients
    loss = -mll(output, likelihood.transformed_targets).sum()
    loss.backward()
    if i % 5 == 0:
        print(
            "Iter %d/%d - Loss: %.3f   lengthscale: %.3f   noise: %.3f"
            % (
                i + 1,
                50,
                loss.item(),
                model.covar_module.base_kernel.lengthscale.mean().item(),
                model.likelihood.second_noise_covar.noise.mean().item(),
            )
        )
    optimizer.step()


model.eval()
likelihood.eval()

with gpytorch.settings.fast_pred_var(), torch.no_grad():
    test_dist = model(train_x)

    pred_means = test_dist.loc

# Fantasize on new observations
new_xy, new_z, genfn = gen_data(20, seed=2000)
_, new_z, num_classes = likelihood._prepare_targets(new_z.unsqueeze(0))
updated_model = model.get_fantasy_model(new_xy, new_z.T)
```

**Stack trace/error message**

```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[39], line 99
     97 new_xy, new_z, genfn = gen_data(20, seed=2000)
     98 _, new_z, num_classes = likelihood._prepare_targets(new_z.unsqueeze(0))
---> 99 updated_model = model.get_fantasy_model(new_xy, new_z.T)

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/gpytorch/models/exact_gp.py:238, in ExactGP.get_fantasy_model(self, inputs, targets, **kwargs)
    235 self.train_targets = old_train_targets
    236 self.likelihood = old_likelihood
--> 238 new_model.likelihood = old_likelihood.get_fantasy_likelihood(**fantasy_kwargs)
    239 new_model.prediction_strategy = old_pred_strat.get_fantasy_strategy(
    240     inputs, targets, full_inputs, full_targets, full_output, **fantasy_kwargs
    241 )
    243 # if the fantasies are at the same points, we need to expand the inputs for the new model

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/gpytorch/likelihoods/gaussian_likelihood.py:439, in DirichletClassificationLikelihood.get_fantasy_likelihood(self, **kwargs)
    435 def get_fantasy_likelihood(self, **kwargs: Any) -> "DirichletClassificationLikelihood":
    436     # we assume that the number of classes does not change.
    438     if "targets" not in kwargs:
--> 439         raise RuntimeError("FixedNoiseGaussianLikelihood.fantasize requires a `targets` kwarg")
    441     old_noise_covar = self.noise_covar
    442     self.noise_covar = None  # pyre-fixme[8]

RuntimeError: FixedNoiseGaussianLikelihood.fantasize requires a `targets` kwarg
```

Expected Behavior

`get_fantasy_model` should return the updated model with the fantasized likelihood.

System information

  • GPyTorch version: 1.12
  • PyTorch version: 2.4.0
  • OS: macOS Sonoma 14.5

Additional context

I can see that a `RuntimeError` is raised in the `get_fantasy_likelihood` method of `DirichletClassificationLikelihood` whenever `targets` is missing from `kwargs`. However, I can't see `targets` being used anywhere in that method. It is also not possible to pass a kwarg called `targets`: `get_fantasy_likelihood` is called inside the `get_fantasy_model` method of `ExactGP`, which already takes a separate `targets` argument, and a single function call cannot receive two arguments named `targets`.
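The call pattern described above can be reproduced in plain Python, without GPyTorch. `FakeExactGP` and `FakeLikelihood` are hypothetical stand-ins, not GPyTorch classes. The sketch assumes, as the traceback suggests, that `get_fantasy_model(inputs, targets, **kwargs)` forwards only its extra keyword arguments to `get_fantasy_likelihood`.

```python
# Hypothetical stand-ins that mimic the call pattern seen in the traceback.
# They are NOT the GPyTorch implementations.


class FakeLikelihood:
    def get_fantasy_likelihood(self, **kwargs):
        # Mirrors the guard in the traceback: a `targets` kwarg is required.
        if "targets" not in kwargs:
            raise RuntimeError("fantasize requires a `targets` kwarg")
        return self


class FakeExactGP:
    def __init__(self):
        self.likelihood = FakeLikelihood()

    def get_fantasy_model(self, inputs, targets, **kwargs):
        # Assumption: only the extra keyword arguments are forwarded; the
        # positional `targets` is consumed here and never reaches the likelihood.
        return self.likelihood.get_fantasy_likelihood(**kwargs)


def try_call(fn):
    """Run fn and return the name of the exception it raises, or "ok"."""
    try:
        fn()
        return "ok"
    except Exception as exc:  # report only the exception type
        return type(exc).__name__


model = FakeExactGP()

# 1. Normal call: `targets` binds to the positional parameter, so the
#    likelihood never sees it and the guard fires.
result_plain = try_call(lambda: model.get_fantasy_model([0.0], [1.0]))

# 2. Repeating `targets` as a keyword collides with the positional parameter
#    of the same name, so Python rejects the call before the body runs.
result_kwarg = try_call(lambda: model.get_fantasy_model([0.0], [1.0], targets=[1.0]))

print(result_plain)  # RuntimeError
print(result_kwarg)  # TypeError
```

Under that assumption, a normal call trips the guard. Repeating `targets` as a keyword fails even earlier, with Python's `TypeError` for a duplicated argument. Neither route gets past the check.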

SaiAakash added the bug label on Sep 7, 2024.