Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

[WIP] Fix HSGP predictions #780

Merged
merged 7 commits into from
Feb 29, 2024

Conversation

tomicapretto
Copy link
Collaborator

@tomicapretto tomicapretto commented Feb 18, 2024

Fixes predictions when HSGP contains a by variable.

TODO: implement tests?

Edit : it closes #776

@GStechschulte
Copy link
Collaborator

Thanks a lot @tomicapretto 👍🏼

@tomicapretto
Copy link
Collaborator Author

@GStechschulte could you try this?

import bambi as bmb
import numpy as np
import pandas as pd

df = pd.read_csv("tests/data/gam_data.csv")

rng = np.random.default_rng(1234)
df["fac2"] = rng.choice(["a", "b", "c"], size=df.shape[0])

formula = "y ~ 1 + x0 + hsgp(x1, by=fac, m=10, c=2) + hsgp(x1, by=fac2, m=10, c=2)"
model = bmb.Model(formula, df, categorical=["fac"])
idata = model.fit(tune=500, draws=500, target_accept=0.9)

Plot 1

bmb.interpret.plot_predictions(
    model, 
    idata, 
    conditional="x1", 
    subplot_kwargs={"main": "x1", "group": "fac2", "panel": "fac2"},
);

image

Plot 2

bmb.interpret.plot_predictions(
    model, 
    idata, 
    conditional={
        "x1": np.linspace(0, 1, num=100),
        "fac2": ["a", "b", "c"]
    }, 
    legend=False,
    subplot_kwargs={"main": "x1", "group": "fac2", "panel": "fac2"},
);

image

I was expecting to get the second plot with the code for the first plot. I think we got the result we got because we first generate the data, and only then, we use the subplot_kwargs? At that point, it's just too late, you only have one value of fac2

@codecov-commenter
Copy link

codecov-commenter commented Feb 23, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 90.16%. Comparing base (b5b9f09) to head (bdb48d8).
Report is 1 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #780      +/-   ##
==========================================
+ Coverage   89.86%   90.16%   +0.29%     
==========================================
  Files          46       46              
  Lines        3810     3814       +4     
==========================================
+ Hits         3424     3439      +15     
+ Misses        386      375      -11     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@GStechschulte
Copy link
Collaborator

GStechschulte commented Feb 23, 2024

@tomicapretto thanks! Plot 1 is displaying correctly. It is because you are not explicitly passing fac2 to conditional. Which results in, as you stated, a single default value computed for fac2. The single value cannot have any subplots.

This is the behavior both interpret and marginaleffects uses if a covariate was specified in the model, but not passed to conditional.

bmb.interpret.plot_predictions(
    model, 
    idata, 
    conditional=["x1", "fac2"], 
    subplot_kwargs={"main": "x1", "group": "fac2", "panel": "fac2"},
    legend=False
);

image

@tomicapretto tomicapretto marked this pull request as ready for review February 23, 2024 18:55
@tomicapretto
Copy link
Collaborator Author

Thanks @GStechschulte! I think this is done.

I know the test is actually testing many things at the same time, not just the fix. But I think it's not possible to write a test for the fix in particular, and if possible, it would be so complicated.

Copy link
Collaborator

@GStechschulte GStechschulte left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! However, my knowledge on the implementation of HSGP in Bambi is a bit lacking.

@GStechschulte
Copy link
Collaborator

Thanks @GStechschulte! I think this is done.

I know the test is actually testing many things at the same time, not just the fix. But I think it's not possible to write a test for the fix in particular, and if possible, it would be so complicated.

Yup, I agree. It is also nice to have the text for interpret in there.

@GStechschulte GStechschulte merged commit ff685b7 into bambinos:main Feb 29, 2024
4 checks passed
GStechschulte pushed a commit to GStechschulte/bambi that referenced this pull request Mar 1, 2024
* Delete all HSGP slices at the same time

* Make interpret consider kwargs in function calls

* Update code of conduct (bambinos#783)

* Update code of conduct

* update changelog

* Update formulae to >=0.5.3

* start a test for the hsgp and 'by'

* update changelog
GStechschulte added a commit that referenced this pull request Mar 29, 2024
* use bayeux to access a wide range of samplers

* use bayeux to access a wide range of samplers

* add notebook links to family table (#774)

* access methods programatically

* clean bayeux idata to be consistent with pymc model coords

* rename alternative sampler args in tests

* change docstring to reflect bayeux sampler names

* bayeux dependencies are numpyro/jax/jaxlib/blackjax

* rename idata coords and dims to PyMC model

* add JAX based sampler dependencies

* Update code of conduct (#783)

* Update code of conduct

* update changelog

* [WIP] Fix HSGP predictions (#780)

* Delete all HSGP slices at the same time

* Make interpret consider kwargs in function calls

* Update code of conduct (#783)

* Update code of conduct

* update changelog

* Update formulae to >=0.5.3

* start a test for the hsgp and 'by'

* update changelog

* bayeux 0.1.9 updates

* bump bayeux version

* remove TFP methods, optimizers, and resolve pylint errors

* alternative backends docs

* tests for JAX based samplers except TFP

* add TFP backend example

* add TFP MCMC methods

* don't use flowmc, chees, meads for categorical model

* call model.backend.inference_methods to show list of samplers

* docstring changes

* inference_methods attribute and change JAX random seed

* Add FutureWarning to inference_method parameter

* black formatting and resolve pylint errors

* fix package name

* drop 3.9 and add 3.12 to testing matrix

* change Python versions in requires-python and target-version

* remove python 3.11 black target-version

* pin requires-python to <3.13

* pip upgrade setuptools

* Bump PyMC to 5.12

* Upgrade black and pylint

* remove upgrading of setup tools

---------

Co-authored-by: Tomás Capretto <tomicapretto@gmail.com>
# for free to join this conversation on GitHub. Already have an account? # to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

plot_predictions breaks with HSGP
3 participants