Skip to content

Commit

Permalink
json.loads with python types bug (pymc-labs#881)
Browse files Browse the repository at this point in the history
* loads doesnt support boolean

* defaults for the media transformation

* test for the time_varyign
  • Loading branch information
wd60622 authored and radiokosmos committed Sep 1, 2024
1 parent 395b285 commit 3304be4
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 5 deletions.
10 changes: 5 additions & 5 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,14 +632,14 @@ def attrs_to_init_kwargs(cls, attrs) -> dict[str, Any]:
"control_columns": json.loads(attrs["control_columns"]),
"channel_columns": json.loads(attrs["channel_columns"]),
"adstock_max_lag": json.loads(attrs["adstock_max_lag"]),
"adstock": json.loads(attrs.get("adstock", "geometric")),
"saturation": json.loads(attrs.get("saturation", "logistic")),
"adstock_first": json.loads(attrs.get("adstock_first", True)),
"adstock": json.loads(attrs.get("adstock", '"geometric"')),
"saturation": json.loads(attrs.get("saturation", '"logistic"')),
"adstock_first": json.loads(attrs.get("adstock_first", "true")),
"yearly_seasonality": json.loads(attrs["yearly_seasonality"]),
"time_varying_intercept": json.loads(
attrs.get("time_varying_intercept", False)
attrs.get("time_varying_intercept", "false")
),
"time_varying_media": json.loads(attrs.get("time_varying_media", False)),
"time_varying_media": json.loads(attrs.get("time_varying_media", "false")),
"validate_data": json.loads(attrs["validate_data"]),
"sampler_config": json.loads(attrs["sampler_config"]),
}
Expand Down
44 changes: 44 additions & 0 deletions tests/mmm/test_delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1115,3 +1115,47 @@ def test_save_load_with_tvp(

# clean up
os.remove(file)


def test_missing_attrs_to_defaults(toy_X, toy_y) -> None:
mmm = MMM(
date_column="date",
channel_columns=["channel_1", "channel_2"],
control_columns=["control_1", "control_2"],
adstock=GeometricAdstock(l_max=4),
saturation=LogisticSaturation(),
adstock_first=False,
time_varying_intercept=False,
time_varying_media=False,
)
mmm = mock_fit(mmm, toy_X, toy_y)
mmm.idata.attrs.pop("adstock")
mmm.idata.attrs.pop("saturation")
mmm.idata.attrs.pop("adstock_first")
mmm.idata.attrs.pop("time_varying_intercept")
mmm.idata.attrs.pop("time_varying_media")

file = "tmp-model"
mmm.save(file)

loaded_mmm = MMM.load(file)

attrs = loaded_mmm.idata.attrs
for key in [
"adstock",
"saturation",
"adstock_first",
"time_varying_intercept",
"time_varying_media",
]:
assert key not in attrs

assert loaded_mmm.adstock.lookup_name == "geometric"
assert loaded_mmm.saturation.lookup_name == "logistic"
assert not loaded_mmm.time_varying_intercept
assert not loaded_mmm.time_varying_media
# Falsely loaded
assert loaded_mmm.adstock_first

# clean up
os.remove(file)

0 comments on commit 3304be4

Please # to comment.