diff --git a/CHANGELOG.md b/CHANGELOG.md
index 79538d5db..566ead47a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,8 @@
 
 ### Maintenance and fixes
 
+* Fix bug in predictions with models using HSGP (#780)
+
 ### Documentation
 
 * Our Code of Conduct now includes how to send a report (#783)
diff --git a/bambi/interpret/utils.py b/bambi/interpret/utils.py
index a56e23560..47bc02864 100644
--- a/bambi/interpret/utils.py
+++ b/bambi/interpret/utils.py
@@ -236,11 +236,20 @@ def get_model_covariates(model: Model) -> np.ndarray:
     for term in terms.values():
         if hasattr(term, "components"):
             for component in term.components:
-                # if the component is a function call, use the argument names
+                # if the component is a function call, look for relevant argument names
                 if isinstance(component, Call):
+                    # Add variable names passed as unnamed arguments
                     covariates.append(
                         [arg.name for arg in component.call.args if isinstance(arg, LazyVariable)]
                     )
+                    # Add variable names passed as named arguments
+                    covariates.append(
+                        [
+                            kwarg_value.name
+                            for kwarg_value in component.call.kwargs.values()
+                            if isinstance(kwarg_value, LazyVariable)
+                        ]
+                    )
                 else:
                     covariates.append([component.name])
         elif hasattr(term, "factor"):
diff --git a/bambi/model_components.py b/bambi/model_components.py
index f4691e5e2..44c781127 100644
--- a/bambi/model_components.py
+++ b/bambi/model_components.py
@@ -239,11 +239,12 @@ def predict_common(
             X = np.delete(X, term_slice, axis=1)
 
         # Add HSGP components contribution to the linear predictor
+        hsgp_slices = []
         for term_name, term in self.hsgp_terms.items():
             # Extract data for the HSGP component from the design matrix
             term_slice = self.design.common.slices[term_name]
             x_slice = X[:, term_slice]
-            X = np.delete(X, term_slice, axis=1)
+            hsgp_slices.append(term_slice)
             term_aliased_name = get_aliased_name(term)
             hsgp_to_stack_dims = (f"{term_aliased_name}_weights_dim",)
 
@@ -288,6 +289,12 @@
             # Add contribution to the linear predictor
             linear_predictor += hsgp_contribution
 
+        # Remove columns of X that are associated with HSGP contributions
+        # All the slices _must be_ deleted at the same time. Otherwise the slice objects don't
+        # reflect the right columns of X at the time they're used
+        if hsgp_slices:
+            X = np.delete(X, np.r_[tuple(hsgp_slices)], axis=1)
+
         if self.common_terms or self.intercept_term:
             # Create DataArray
             X_terms = [get_aliased_name(term) for term in self.common_terms.values()]
diff --git a/pyproject.toml b/pyproject.toml
index 1203e3fdc..c7620d891 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,7 +18,7 @@ maintainers = [
 
 dependencies = [
     "arviz>=0.12.0",
-    "formulae>=0.5.0",
+    "formulae>=0.5.3",
     "graphviz",
     "pandas>=1.0.0",
     "pymc>=5.5.0",
diff --git a/tests/test_hsgp.py b/tests/test_hsgp.py
index 30bf5ce1c..770c70cc5 100644
--- a/tests/test_hsgp.py
+++ b/tests/test_hsgp.py
@@ -300,3 +300,35 @@ def test_minimal_1d_predicts(data_1d_single_group):
     new_idata = model.predict(idata, data=new_data, kind="pps", inplace=False)
     assert new_idata.posterior_predictive["y"].dims == ("chain", "draw", "y_obs")
     assert new_idata.posterior_predictive["y"].to_numpy().shape == (2, 500, 10)
+
+
+def test_multiple_hsgp_and_by(data_1d_multiple_groups):
+    rng = np.random.default_rng(1234)
+    df = data_1d_multiple_groups.copy()
+    df["fac2"] = rng.choice(["a", "b", "c"], size=df.shape[0])
+
+    formula = "y ~ 1 + x0 + hsgp(x1, by=fac, m=10, c=2) + hsgp(x1, by=fac2, m=10, c=2)"
+    model = bmb.Model(
+        formula=formula,
+        data=df,
+        categorical=["fac"],
+    )
+    idata = model.fit(tune=400, draws=200, target_accept=0.9)
+
+    bmb.interpret.plot_predictions(
+        model,
+        idata,
+        conditional="x1",
+        subplot_kwargs={"main": "x1", "group": "fac2", "panel": "fac2"},
+    )
+
+    bmb.interpret.plot_predictions(
+        model,
+        idata,
+        conditional={
+            "x1": np.linspace(0, 1, num=100),
+            "fac2": ["a", "b", "c"]
+        },
+        legend=False,
+        subplot_kwargs={"main": "x1", "group": "fac2", "panel": "fac2"},
+    )