Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Implement nutpie as an external sampler #7719

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open

Conversation

twiecki
Copy link
Member

@twiecki twiecki commented Mar 10, 2025

Summary

  • Fix parameter name mapping in NutPie integration (max_tree_depth → maxdepth)
  • Correctly handle InferenceData object returned by NutPie
  • Fix error in coords_and_dims_for_inferencedata function call
  • Handle progressbar parameter correctly
  • Simplify convergence checks to handle NutPie's different structure

Test plan

  • Run tests in tests/step_methods/test_external.py to verify NutPie integration works
  • Fixed parameter naming bugs with NutPie API

🤖 Generated with Claude Code


📚 Documentation preview 📚: https://pymc--7719.org.readthedocs.build/en/7719/

- Change max_tree_depth to maxdepth to match NutPie's API
- Fix coords_and_dims_for_inferencedata function call
- Handle progressbar parameter correctly
- Skip conversion to InferenceData since NutPie already returns one
- Simplify convergence checks to handle NutPie's different structure

🤖 Generated with [Claude Code](https://claude.ai/code)
Co-Authored-By: Claude <noreply@anthropic.com>
@twiecki twiecki requested a review from ricardoV94 March 11, 2025 10:18
@twiecki twiecki changed the title Fix NutPie external sampler parameter mapping Implement nutpie as an external sampler Mar 22, 2025
@twiecki
Copy link
Member Author

twiecki commented Mar 22, 2025

@ricardoV94

@ricardoV94
Copy link
Member

@ricardoV94

It's on my stack :)

NUTPIE_AVAILABLE = False


class NutPie(ExternalSampler):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be Nutpie. "pie" is not capitalized in the nutpie docs, anyway.

image

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something is missing here, there will definitely have to be changes in the sample function.

special handling in PyMC's sampling loops.
"""

is_external = True
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The base class already tells us this is an ExternalSampler, no need for is_external?

logger = logging.getLogger("pymc")

try:
import nutpie
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nutpie should only be imported if the step method is used/instantiated, so as to avoid import time penalty


# Create a NutPie model
logger.info("Compiling NutPie model")
nutpie_model = nutpie.compile_pymc_model(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's better to just call nutpie.sample?

Comment on lines +197 to +222
# Convert to InferenceData
if idata_kwargs is None:
idata_kwargs = {}

# Extract relevant variables and data for InferenceData
coords, dims = coords_and_dims_for_inferencedata(model)
constants_data = find_constants(model)
observed_data = find_observations(model)

# Always include sampler stats
if "include_sampler_stats" not in idata_kwargs:
idata_kwargs["include_sampler_stats"] = True

# NutPie already returns an InferenceData object
idata = nutpie_trace

# Set tuning steps attribute if possible
try:
idata.posterior.attrs["tuning_steps"] = tune
except (AttributeError, KeyError):
logger.warning("Could not set tuning_steps attribute on InferenceData")

# Skip compute_convergence_checks for now
# NutPie's InferenceData structure is different from PyMC's expectations

return idata
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think all of this is handled by nutpie

Enum indicating competence level for this variable
"""
if var.dtype in continuous_types and has_grad:
return Competence.IDEAL
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It shouldn't be IDEAL



@pytest.mark.skipif(not NUTPIE_AVAILABLE, reason="NutPie not installed")
def test_nutpie_jax_backend():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reuse existing tests instead of defining new ones

@ricardoV94
Copy link
Member

The new test file would have to be added to the yaml file (see failing pre-commit), and I'm sure they are failing looking at it

# for free to join this conversation on GitHub. Already have an account? # to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants