diff --git a/tests/test_built_models.py b/tests/test_built_models.py index ef8e48239..9752f8f34 100644 --- a/tests/test_built_models.py +++ b/tests/test_built_models.py @@ -6,8 +6,10 @@ import numpy as np import pandas as pd +import pymc as pm from bambi import math +from bambi.families import Family, Likelihood, Link from bambi.models import Model from bambi.priors import Prior from bambi.terms import GroupSpecificTerm @@ -769,3 +771,31 @@ def test_group_specific_splines(): model = Model("y ~ (bs(x, knots=knots, intercept=False, degree=1)|day)", data=x_check) model.build() + + +def test_2d_response_no_shape(): + """ + This tests whether a model where there's a single linear predictor and a response with + response.ndim > 1 works well, without Bambi causing any shape problems. + See https://github.com/bambinos/bambi/pull/629 + """ + + def fn(name, p, observed, **kwargs): + y = observed[:, 0].flatten() + n = observed[:, 1].flatten() + return pm.Binomial(name, p=p, n=n, observed=y, **kwargs) + + likelihood = Likelihood("CustomBinomial", params=["p"], parent="p", dist=fn) + link = Link("logit") + family = Family("custom-binomial", likelihood, link) + + data = pd.DataFrame( + { + "x": np.array([1.6907, 1.7242, 1.7552, 1.7842, 1.8113, 1.8369, 1.8610, 1.8839]), + "n": np.array([59, 60, 62, 56, 63, 59, 62, 60]), + "y": np.array([6, 13, 18, 28, 52, 53, 61, 60]), + } + ) + + model = Model("prop(y, n) ~ x", data, family=family) + model.fit(draws=10, tune=10) diff --git a/tests/test_plots.py b/tests/test_plots.py index f50b3656c..4a10dbf16 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -162,11 +162,10 @@ def test_multiple_outputs(): y = rng.gamma(shape, np.exp(a + b * x) / shape, N) data_gamma = pd.DataFrame({"x": x, "y": y}) - formula = Formula("y ~ x", "alpha ~ x") model = Model(formula, data_gamma, family="gamma") idata = model.fit(tune=100, draws=100, random_seed=1234) - # Test default target + # Test default target plot_cap(model, idata, "x") # Test user supplied target argument - plot_cap(model, idata, "x", "alpha") \ No newline at end of file + plot_cap(model, idata, "x", "alpha")