-
-
Notifications
You must be signed in to change notification settings - Fork 132
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
[WIP] Fix HSGP predictions #780
[WIP] Fix HSGP predictions #780
Conversation
Thanks a lot @tomicapretto 👍🏼 |
* Update code of conduct * update changelog
…nto fix_hsgp_prediction
@GStechschulte could you try this? import bambi as bmb
import numpy as np
import pandas as pd
df = pd.read_csv("tests/data/gam_data.csv")
rng = np.random.default_rng(1234)
df["fac2"] = rng.choice(["a", "b", "c"], size=df.shape[0])
formula = "y ~ 1 + x0 + hsgp(x1, by=fac, m=10, c=2) + hsgp(x1, by=fac2, m=10, c=2)"
model = bmb.Model(formula, df, categorical=["fac"])
idata = model.fit(tune=500, draws=500, target_accept=0.9) Plot 1 bmb.interpret.plot_predictions(
model,
idata,
conditional="x1",
subplot_kwargs={"main": "x1", "group": "fac2", "panel": "fac2"},
); Plot 2 bmb.interpret.plot_predictions(
model,
idata,
conditional={
"x1": np.linspace(0, 1, num=100),
"fac2": ["a", "b", "c"]
},
legend=False,
subplot_kwargs={"main": "x1", "group": "fac2", "panel": "fac2"},
); I was expecting to get the second plot with the code for the first plot. I think we got the result we got because we first generate the data, and only then, we use the |
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #780 +/- ##
==========================================
+ Coverage 89.86% 90.16% +0.29%
==========================================
Files 46 46
Lines 3810 3814 +4
==========================================
+ Hits 3424 3439 +15
+ Misses 386 375 -11 ☔ View full report in Codecov by Sentry. |
@tomicapretto thanks! Plot 1 is displaying correctly. It is because you are not explicitly passing This is the behavior both bmb.interpret.plot_predictions(
model,
idata,
conditional=["x1", "fac2"],
subplot_kwargs={"main": "x1", "group": "fac2", "panel": "fac2"},
legend=False
); |
Thanks @GStechschulte! I think this is done. I know the test is actually testing many things at the same time, not just the fix. But I think it's not possible to write a test for the fix in particular, and if possible, it would be so complicated. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! However, my knowledge on the implementation of HSGP in Bambi is a bit lacking.
Yup, I agree. It is also nice to have the text for |
* Delete all HSGP slices at the same time * Make interpret consider kwargs in function calls * Update code of conduct (bambinos#783) * Update code of conduct * update changelog * Update formulae to >=0.5.3 * start a test for the hsgp and 'by' * update changelog
* use bayeux to access a wide range of samplers * use bayeux to access a wide range of samplers * add notebook links to family table (#774) * access methods programatically * clean bayeux idata to be consistent with pymc model coords * rename alternative sampler args in tests * change docstring to reflect bayeux sampler names * bayeux dependencies are numpyro/jax/jaxlib/blackjax * rename idata coords and dims to PyMC model * add JAX based sampler dependencies * Update code of conduct (#783) * Update code of conduct * update changelog * [WIP] Fix HSGP predictions (#780) * Delete all HSGP slices at the same time * Make interpret consider kwargs in function calls * Update code of conduct (#783) * Update code of conduct * update changelog * Update formulae to >=0.5.3 * start a test for the hsgp and 'by' * update changelog * bayeux 0.1.9 updates * bump bayeux version * remove TFP methods, optimizers, and resolve pylint errors * alternative backends docs * tests for JAX based samplers except TFP * add TFP backend example * add TFP MCMC methods * don't use flowmc, chees, meads for categorical model * call model.backend.inference_methods to show list of samplers * docstring changes * inference_methods attribute and change JAX random seed * Add FutureWarning to inference_method parameter * black formatting and resolve pylint errors * fix package name * drop 3.9 and add 3.12 to testing matrix * change Python versions in requires-python and target-version * remove python 3.11 black target-version * pin requires-python to <3.13 * pip upgrade setuptools * Bump PyMC to 5.12 * Upgrade black and pylint * remove upgrading of setup tools --------- Co-authored-by: Tomás Capretto <tomicapretto@gmail.com>
Fixes predictions when HSGP contains a
by
variable.get_model_covariates
so it also looks at the named arguments of function calls.TODO: implement tests?Edit : it closes #776