Skip to content

Commit

Permalink
Merge pull request #2605 from SaiAakash/fix_fantasy_model
Browse files Browse the repository at this point in the history
Detach `new_covar_cache` to enable JIT tracing of models after fantasization
  • Loading branch information
Balandat authored Jan 25, 2025
2 parents 4156bf4 + 3437156 commit 0bace60
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion gpytorch/models/exact_prediction_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,10 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_
# now update the root and root inverse
new_lt = self.lik_train_train_covar.cat_rows(fant_train_covar, fant_fant_covar)
new_root = new_lt.root_decomposition().root
new_covar_cache = new_lt.root_inv_decomposition().root
if settings.detach_test_caches.on():
new_covar_cache = new_lt.root_inv_decomposition().root.detach()
else:
new_covar_cache = new_lt.root_inv_decomposition().root

# Expand inputs accordingly if necessary (for fantasies at the same points)
if full_inputs[0].dim() <= full_targets.dim():
Expand Down

0 comments on commit 0bace60

Please # to comment.