-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
Implement nutpie as an external sampler #7719
base: main
Are you sure you want to change the base?
Conversation
- Change max_tree_depth to maxdepth to match NutPie's API - Fix coords_and_dims_for_inferencedata function call - Handle progressbar parameter correctly - Skip conversion to InferenceData since NutPie already returns one - Simplify convergence checks to handle NutPie's different structure 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
It's on my stack :) |
NUTPIE_AVAILABLE = False | ||
|
||
|
||
class NutPie(ExternalSampler): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Something is missing here, there will definitely have to be changes in the sample
function.
special handling in PyMC's sampling loops. | ||
""" | ||
|
||
is_external = True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The base class already tells us this is an ExternalSampler, no need for is_external
?
logger = logging.getLogger("pymc") | ||
|
||
try: | ||
import nutpie |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nutpie should only be imported if the step method is used/instantiated, so as to avoid import time penalty
|
||
# Create a NutPie model | ||
logger.info("Compiling NutPie model") | ||
nutpie_model = nutpie.compile_pymc_model( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's better to just call nutpie.sample
?
# Convert to InferenceData | ||
if idata_kwargs is None: | ||
idata_kwargs = {} | ||
|
||
# Extract relevant variables and data for InferenceData | ||
coords, dims = coords_and_dims_for_inferencedata(model) | ||
constants_data = find_constants(model) | ||
observed_data = find_observations(model) | ||
|
||
# Always include sampler stats | ||
if "include_sampler_stats" not in idata_kwargs: | ||
idata_kwargs["include_sampler_stats"] = True | ||
|
||
# NutPie already returns an InferenceData object | ||
idata = nutpie_trace | ||
|
||
# Set tuning steps attribute if possible | ||
try: | ||
idata.posterior.attrs["tuning_steps"] = tune | ||
except (AttributeError, KeyError): | ||
logger.warning("Could not set tuning_steps attribute on InferenceData") | ||
|
||
# Skip compute_convergence_checks for now | ||
# NutPie's InferenceData structure is different from PyMC's expectations | ||
|
||
return idata |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think all of this is handled by nutpie
Enum indicating competence level for this variable | ||
""" | ||
if var.dtype in continuous_types and has_grad: | ||
return Competence.IDEAL |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It shouldn't be IDEAL
|
||
|
||
@pytest.mark.skipif(not NUTPIE_AVAILABLE, reason="NutPie not installed") | ||
def test_nutpie_jax_backend(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reuse existing tests instead of defining new ones
The new test file would have to be added to the yaml file (see failing pre-commit), and I'm sure they are failing looking at it |
Summary
Test plan
tests/step_methods/test_external.py
to verify NutPie integration works🤖 Generated with Claude Code
📚 Documentation preview 📚: https://pymc--7719.org.readthedocs.build/en/7719/