[WIP] Fix HSGP predictions (bambinos#780)
* Delete all HSGP slices at the same time

* Make interpret consider kwargs in function calls

* Update code of conduct (bambinos#783)

* Update code of conduct

* update changelog

* Update formulae to >=0.5.3

* start a test for the hsgp and 'by'

* update changelog
tomicapretto authored and GStechschulte committed Mar 1, 2024
1 parent 74b4e8b commit 47bb161
Showing 5 changed files with 53 additions and 3 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,8 @@

### Maintenance and fixes

* Fix bug in predictions with models using HSGP (#780)

### Documentation

* Our Code of Conduct now includes how to send a report (#783)
11 changes: 10 additions & 1 deletion bambi/interpret/utils.py
@@ -236,11 +236,20 @@ def get_model_covariates(model: Model) -> np.ndarray:
    for term in terms.values():
        if hasattr(term, "components"):
            for component in term.components:
-               # if the component is a function call, use the argument names
+               # if the component is a function call, look for relevant argument names
                if isinstance(component, Call):
+                   # Add variable names passed as unnamed arguments
                    covariates.append(
                        [arg.name for arg in component.call.args if isinstance(arg, LazyVariable)]
                    )
+                   # Add variable names passed as named arguments
+                   covariates.append(
+                       [
+                           kwarg_value.name
+                           for kwarg_value in component.call.kwargs.values()
+                           if isinstance(kwarg_value, LazyVariable)
+                       ]
+                   )
                else:
                    covariates.append([component.name])
        elif hasattr(term, "factor"):
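For context, a minimal, self-contained sketch of why keyword arguments must be scanned too (the toy classes below are illustrative stand-ins, not formulae's actual objects): in a term such as hsgp(x1, by=fac, m=10, c=2), x1 arrives as a positional argument while fac only appears as the value of the by= keyword, so a lookup restricted to call.args would miss it.

# Illustrative sketch only: ToyLazyVariable and ToyCall mimic the shape of the
# objects the code above inspects; they are not formulae's real classes.
from dataclasses import dataclass, field

@dataclass
class ToyLazyVariable:
    name: str

@dataclass
class ToyCall:
    args: list = field(default_factory=list)
    kwargs: dict = field(default_factory=dict)

def covariates_in_call(call):
    # Variables passed as unnamed (positional) arguments
    names = [arg.name for arg in call.args if isinstance(arg, ToyLazyVariable)]
    # Variables passed as named (keyword) arguments -- the case this commit adds
    names += [value.name for value in call.kwargs.values() if isinstance(value, ToyLazyVariable)]
    return names

# hsgp(x1, by=fac, m=10, c=2): x1 is positional, fac only shows up as a kwarg value
call = ToyCall(
    args=[ToyLazyVariable("x1")],
    kwargs={"by": ToyLazyVariable("fac"), "m": 10, "c": 2},
)
print(covariates_in_call(call))  # ['x1', 'fac']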
9 changes: 8 additions & 1 deletion bambi/model_components.py
@@ -239,11 +239,12 @@ def predict_common(
            X = np.delete(X, term_slice, axis=1)

        # Add HSGP components contribution to the linear predictor
+       hsgp_slices = []
        for term_name, term in self.hsgp_terms.items():
            # Extract data for the HSGP component from the design matrix
            term_slice = self.design.common.slices[term_name]
            x_slice = X[:, term_slice]
-           X = np.delete(X, term_slice, axis=1)
+           hsgp_slices.append(term_slice)
            term_aliased_name = get_aliased_name(term)
            hsgp_to_stack_dims = (f"{term_aliased_name}_weights_dim",)
@@ -288,6 +289,12 @@ def predict_common(
            # Add contribution to the linear predictor
            linear_predictor += hsgp_contribution

+       # Remove columns of X that are associated with HSGP contributions
+       # All the slices _must be_ deleted at the same time. Otherwise the slice objects don't
+       # reflect the right columns of X at the time they're used
+       if hsgp_slices:
+           X = np.delete(X, np.r_[tuple(hsgp_slices)], axis=1)
+
        if self.common_terms or self.intercept_term:
            # Create DataArray
            X_terms = [get_aliased_name(term) for term in self.common_terms.values()]
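A short standalone example of the failure mode the new comment describes (the array and slices are made up for illustration): deleting one slice of X at a time shifts the remaining columns, so slice objects computed against the original design matrix stop pointing at the right columns, while np.r_ expands every slice into a single index array that np.delete can remove in one call.

import numpy as np

X = np.arange(20).reshape(4, 5)           # toy design matrix with columns 0..4
hsgp_slices = [slice(1, 2), slice(3, 5)]  # made-up column ranges for two HSGP terms

# Buggy pattern: deleting inside the loop shifts the columns, so the second
# slice no longer matches the columns it was computed for
X_wrong = X.copy()
for term_slice in hsgp_slices:
    X_wrong = np.delete(X_wrong, term_slice, axis=1)
print(X_wrong.shape)  # (4, 3) -- a column that should have been removed survives

# Fixed pattern (as in the diff above): expand all slices and delete them together
X_right = np.delete(X, np.r_[tuple(hsgp_slices)], axis=1)
print(X_right.shape)  # (4, 2) -- only the non-HSGP columns remain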
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -18,7 +18,7 @@ maintainers = [

dependencies = [
"arviz>=0.12.0",
"formulae>=0.5.0",
"formulae>=0.5.3",
"graphviz",
"pandas>=1.0.0",
"pymc>=5.5.0",
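The dependency bump above can be confirmed in an environment with a quick check (a tiny sketch; 0.5.3 is simply the pin from this commit, not a claim about any particular install):

from importlib.metadata import version

# formulae >= 0.5.3 is the minimum version required after this commit
print(version("formulae"))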
32 changes: 32 additions & 0 deletions tests/test_hsgp.py
@@ -300,3 +300,35 @@ def test_minimal_1d_predicts(data_1d_single_group):
    new_idata = model.predict(idata, data=new_data, kind="pps", inplace=False)
    assert new_idata.posterior_predictive["y"].dims == ("chain", "draw", "y_obs")
    assert new_idata.posterior_predictive["y"].to_numpy().shape == (2, 500, 10)


def test_multiple_hsgp_and_by(data_1d_multiple_groups):
    rng = np.random.default_rng(1234)
    df = data_1d_multiple_groups.copy()
    df["fac2"] = rng.choice(["a", "b", "c"], size=df.shape[0])

    formula = "y ~ 1 + x0 + hsgp(x1, by=fac, m=10, c=2) + hsgp(x1, by=fac2, m=10, c=2)"
    model = bmb.Model(
        formula=formula,
        data=df,
        categorical=["fac"],
    )
    idata = model.fit(tune=400, draws=200, target_accept=0.9)

    bmb.interpret.plot_predictions(
        model,
        idata,
        conditional="x1",
        subplot_kwargs={"main": "x1", "group": "fac2", "panel": "fac2"},
    )

    bmb.interpret.plot_predictions(
        model,
        idata,
        conditional={
            "x1": np.linspace(0, 1, num=100),
            "fac2": ["a", "b", "c"],
        },
        legend=False,
        subplot_kwargs={"main": "x1", "group": "fac2", "panel": "fac2"},
    )
