Implement model transform to remove minibatching operations from graph #7521

Closed
aphc14 opened this issue Oct 2, 2024 · 4 comments · Fixed by #7746
aphc14 commented Oct 2, 2024

Description

When a model uses pm.Minibatch, pm.sample_posterior_predictive returns predictions with the size of the minibatch rather than the size of the full dataset. Making predictions on the full dataset requires passing the existing trace into a new model with a near-identical setup. For complicated models, this means duplicating many lines of code to define a model that is almost identical to the original.

This enhancement would make it easier to perform posterior predictive checks when using minibatch.

relates to: https://discourse.pymc.io/t/minibatch-not-working/14061/10

Example scenario:

import numpy as np
import pymc as pm
import arviz as az
import pytensor.tensor as pt

# generate data
N = 10000
P = 3
rng = np.random.default_rng(88)
X_full = rng.uniform(2, 10, size=(N, P))
beta = np.array([1.5, 0.2, -0.9])
y_full = np.matmul(X_full, beta) + rng.normal(0, 1, size=(N,))

Before:

# minibatch
X_mb, y_mb = pm.Minibatch(X_full, y_full, batch_size=100)

# original minibatch model
with pm.Model() as model_mb:
    b = pm.Normal("b", mu=0, sigma=3, shape=(P,))
    sigma = pm.HalfCauchy("sigma", 1)
    mu = pm.Deterministic("mu", pt.matmul(X_mb, b))
    likelihood = pm.Normal(
        "likelihood", mu=mu, sigma=sigma, observed=y_mb, total_size=N
    )

    fit_mb = pm.fit(
        n=100000,
        method="advi",
        progressbar=True,
        callbacks=[pm.callbacks.CheckParametersConvergence()],
        random_seed=88,
    )
    idata_mb = fit_mb.sample(500)

    pm.sample_posterior_predictive(idata_mb, extend_inferencedata=True)
    idata_mb.posterior = pm.compute_deterministics(
        idata_mb.posterior, merge_dataset=True
    )

# new but similar model to the original
with pm.Model() as model_preds:
    X = pm.Data("X", X_full)
    y = pm.Data("y", y_full)

    b = pm.Normal("b", mu=0, sigma=3, shape=(P,))
    sigma = pm.HalfCauchy("sigma", 1)
    mu = pm.Deterministic("mu", pt.matmul(X, b))
    likelihood = pm.Normal("likelihood", mu=mu, sigma=sigma, observed=y)
    
with model_preds:    
    pm.set_data({"X": X_full})
    ypreds = pm.sample_posterior_predictive(idata_mb)

print(f"Minibatch: {idata_mb.posterior_predictive.likelihood.sizes}")
print(f"Full Data: {ypreds.posterior_predictive.likelihood.sizes}")

# output
Minibatch: Frozen({'chain': 1, 'draw': 500, 'likelihood_dim_2': 100})
Full Data: Frozen({'chain': 1, 'draw': 500, 'likelihood_dim_2': 10000})
@aphc14 aphc14 added the bug label Oct 2, 2024

welcome bot commented Oct 2, 2024

🎉 Welcome to PyMC! 🎉 We're really excited to have your input into the project! 💖

If you haven't done so already, please make sure you check out our Contributing Guidelines and Code of Conduct.

Member

ricardoV94 commented Oct 2, 2024

This is the correct behavior. Minibatch is defined as a stochastic slice of the random variable. You can define a model without minibatch and use the old trace to sample a full-size dataset with posterior predictive.
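To illustrate the point about a stochastic slice, here is a minimal NumPy-only sketch (it does not use PyMC). The posterior draws below are made-up stand-ins rather than real ADVI output, and `predict` is a hypothetical helper introduced only for this example. It shows that the same draws give predictions whose size follows whichever data they are paired with: the minibatch slice or the full dataset.

```python
import numpy as np

rng = np.random.default_rng(88)
N, P, batch_size, draws = 10_000, 3, 100, 500

X_full = rng.uniform(2, 10, size=(N, P))

# Stand-ins for posterior draws of "b" and "sigma" (assumed values, not a real fit).
b_draws = rng.normal([1.5, 0.2, -0.9], 0.05, size=(draws, P))
sigma_draws = np.abs(rng.normal(1.0, 0.05, size=draws))

# A minibatch is a random slice of the rows of the full dataset.
idx = rng.integers(0, N, size=batch_size)
X_mb = X_full[idx]


def predict(X, b_draws, sigma_draws, rng):
    """Hypothetical helper: one predictive sample per posterior draw and data row."""
    mu = b_draws @ X.T  # shape (draws, n_rows)
    return rng.normal(mu, sigma_draws[:, None])


pred_mb = predict(X_mb, b_draws, sigma_draws, rng)
pred_full = predict(X_full, b_draws, sigma_draws, rng)

print(pred_mb.shape, pred_full.shape)  # (500, 100) (500, 10000)
```

The draws of `b` and `sigma` do not depend on the batch size, which is why an existing trace can be reused with a model built on the full data.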

@aphc14 aphc14 changed the title BUG: Minibatch posterior predictive sampling returns incorrect data size ENH: Minibatch posterior predictive sampling to return predictions of original data Oct 3, 2024
Author

aphc14 commented Oct 4, 2024

I have edited the original post to describe the issue as an enhancement/feature request rather than a bug. I'm not able to modify labels on my end though.

@ricardoV94 ricardoV94 changed the title ENH: Minibatch posterior predictive sampling to return predictions of original data Implement model transform to remove minibatching operations from graph Oct 4, 2024
@ricardoV94
Member

> I have edited the original post to describe the issue as an enhancement/feature request rather than a bug. I'm not able to modify labels on my end though.

Thanks, I did that. Note that the model transform I'm suggesting wouldn't do anything automatically; you would apply it explicitly. The API would be something like:

with pm.Model() as minibatch_m:
    ...  # define model with minibatch
    idata = pm.sample()

with remove_minibatch(minibatch_m) as m:
    pm.sample_posterior_predictive(idata)
