Skip to content

Commit

Permalink
load bug with y name different than output_var (#1468)
Browse files Browse the repository at this point in the history
* change the name of the series / array

* modifications to tests to use prior_predictive mocking

* link up y in test

* add the prior to the idata

* add different name in xarray test
  • Loading branch information
wd60622 authored Feb 4, 2025
1 parent 0eb7841 commit eb3b6b6
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 124 deletions.
3 changes: 1 addition & 2 deletions pymc_marketing/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,8 +632,7 @@ def create_fit_data(
if isinstance(y, np.ndarray):
y = pd.Series(y, index=X.index, name=self.output_var)

if y.name is None:
y.name = self.output_var
y.name = self.output_var

if isinstance(X, pd.DataFrame):
X = X.to_xarray()
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def mock_sample(*args, **kwargs):
return idata


@pytest.fixture
@pytest.fixture(scope="module")
def mock_pymc_sample():
original_sample = pm.sample
pm.sample = mock_sample
Expand Down
52 changes: 8 additions & 44 deletions tests/customer_choice/test_mv_its.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import re
import warnings

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -84,38 +83,10 @@ def test_plot_data(saturated_data):
plt.close()


def mock_fit(self, X, y, **kwargs):
self.idata.add_groups(
{
"posterior": self.idata.prior,
}
)

combined_data = pd.concat([X, y.rename(self.output_var)], axis=1)

if "fit_data" in self.idata:
del self.idata.fit_data

with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
category=UserWarning,
message="The group fit_data is not defined in the InferenceData scheme",
)
self.idata.add_groups(fit_data=combined_data.to_xarray()) # type: ignore

return self


@pytest.fixture(scope="module")
def fit_model(module_mocker, saturated_data):
def fit_model(saturated_data, mock_pymc_sample):
model = MVITS(existing_sales=["competitor", "own"], saturated_market=True)

module_mocker.patch(
"pymc_marketing.customer_choice.mv_its.MVITS.fit",
mock_fit,
)

model.sample(
saturated_data.loc[:, ["competitor", "own"]],
saturated_data["new"],
Expand Down Expand Up @@ -151,14 +122,9 @@ def unsaturated_data_good():


@pytest.fixture(scope="module")
def unsaturated_model_bad(module_mocker, unsaturated_data_bad):
def unsaturated_model_bad(unsaturated_data_bad, mock_pymc_sample):
model = MVITS(existing_sales=["competitor", "own"], saturated_market=False)

module_mocker.patch(
"pymc_marketing.customer_choice.mv_its.MVITS.fit",
mock_fit,
)

model.sample(
unsaturated_data_bad.loc[:, ["competitor", "own"]],
unsaturated_data_bad["new"],
Expand All @@ -169,14 +135,9 @@ def unsaturated_model_bad(module_mocker, unsaturated_data_bad):


@pytest.fixture(scope="module")
def unsaturated_model_good(module_mocker, unsaturated_data_good):
def unsaturated_model_good(unsaturated_data_good, mock_pymc_sample):
model = MVITS(existing_sales=["competitor", "own"], saturated_market=False)

module_mocker.patch(
"pymc_marketing.customer_choice.mv_its.MVITS.fit",
mock_fit,
)

model.sample(
unsaturated_data_good.loc[:, ["competitor", "own"]],
unsaturated_data_good["new"],
Expand Down Expand Up @@ -220,8 +181,11 @@ def test_save_load(fit_model, saturated_data) -> None:
assert loaded.saturated_market == fit_model.saturated_market
assert loaded.X.columns.name is None
pd.testing.assert_frame_equal(loaded.X, fit_model.X, check_names=False)
assert loaded.y.name == fit_model.output_var
pd.testing.assert_series_equal(loaded.y.rename("new"), saturated_data["new"])
assert loaded.y.name == fit_model.y.name
pd.testing.assert_series_equal(
loaded.y,
saturated_data["new"].rename(fit_model.output_var),
)


@pytest.mark.parametrize("variable", ["y", "mu"])
Expand Down
7 changes: 6 additions & 1 deletion tests/mmm/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,12 @@ def test_calling_prior_predictive_before_fit_raises_error(test_mmm, toy_X, toy_y
test_mmm.prior_predictive


def test_calling_fit_result_before_fit_raises_error(test_mmm, toy_X, toy_y):
def test_calling_fit_result_before_fit_raises_error(
test_mmm,
toy_X,
toy_y,
mock_pymc_sample,
):
# Arrange
test_mmm.idata = None
with pytest.raises(
Expand Down
Loading

0 comments on commit eb3b6b6

Please # to comment.