Skip to content

Commit

Permalink
[MMM] Model events as gaussian bumps (#1465)
Browse files Browse the repository at this point in the history
* events effects gaussian bumps

* cleanup

* cleanup2

* cleanup3

* some tests

* add example doc

* change the docs

---------

Co-authored-by: Will Dean <wd60622@gmail.com>
Co-authored-by: Will Dean <57733339+wd60622@users.noreply.github.com>
  • Loading branch information
3 people authored Feb 4, 2025
1 parent eb3b6b6 commit 542a85b
Show file tree
Hide file tree
Showing 2 changed files with 639 additions and 0 deletions.
231 changes: 231 additions & 0 deletions pymc_marketing/mmm/events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
# Copyright 2022 - 2025 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Event transformations.
This module provides event transformations for use in Marketing Mix Models.
.. plot::
:context: close-figs
import numpy as np
import pandas as pd
import pymc as pm
import matplotlib.pyplot as plt
from pymc_marketing.mmm.events import EventEffect, GaussianBasis
from pymc_marketing.plot import plot_curve
from pymc_marketing.prior import Prior
seed = sum(map(ord, "Events"))
rng = np.random.default_rng(seed)
df_events = pd.DataFrame(
{
"event": ["single day", "multi day"],
"start_date": pd.to_datetime(["2025-01-01", "2025-01-20"]),
"end_date": pd.to_datetime(["2025-01-02", "2025-01-25"]),
}
)
def difference_in_days(model_dates, event_dates):
if hasattr(model_dates, "to_numpy"):
model_dates = model_dates.to_numpy()
if hasattr(event_dates, "to_numpy"):
event_dates = event_dates.to_numpy()
one_day = np.timedelta64(1, "D")
return (model_dates[:, None] - event_dates) / one_day
def create_basis_matrix(df_events: pd.DataFrame, model_dates: np.ndarray):
start_dates = df_events["start_date"]
end_dates = df_events["end_date"]
s_ref = difference_in_days(model_dates, start_dates)
e_ref = difference_in_days(model_dates, end_dates)
return np.where(
(s_ref >= 0) & (e_ref <= 0),
0,
np.where(np.abs(s_ref) < np.abs(e_ref), s_ref, e_ref),
)
gaussian = GaussianBasis(
priors={
"sigma": Prior("Gamma", mu=7, sigma=1, dims="event"),
}
)
effect_size = Prior("Normal", mu=1, sigma=1, dims="event")
effect = EventEffect(basis=gaussian, effect_size=effect_size, dims=("event",))
dates = pd.date_range("2024-12-01", periods=3 * 31, freq="D")
X = create_basis_matrix(df_events, model_dates=dates)
coords = {"date": dates, "event": df_events["event"].to_numpy()}
with pm.Model(coords=coords) as model:
pm.Deterministic("effect", effect.apply(X), dims=("date", "event"))
idata = pm.sample_prior_predictive(random_seed=rng)
fig, axes = idata.prior.effect.pipe(
plot_curve,
{"date"},
subplot_kwargs={"ncols": 1},
sample_kwargs={"rng": rng},
)
fig.suptitle("Gaussian Event Effect")
plt.show()
"""

import numpy as np
import pymc as pm
import pytensor.tensor as pt
import xarray as xr
from pydantic import BaseModel, Field, InstanceOf, validate_call
from pytensor.tensor.variable import TensorVariable

from pymc_marketing.deserialize import deserialize, register_deserialization
from pymc_marketing.mmm.components.base import Transformation, create_registration_meta
from pymc_marketing.prior import Prior, create_dim_handler

BASIS_TRANSFORMATIONS: dict = {}
BasisMeta = create_registration_meta(BASIS_TRANSFORMATIONS)


class Basis(Transformation, metaclass=BasisMeta): # type: ignore[misc]
"""Basis transformation associated with an event model."""

prefix: str = "basis"
lookup_name: str

@validate_call
def sample_curve(
self,
parameters: InstanceOf[xr.Dataset] = Field(
..., description="Parameters of the saturation transformation."
),
days: int = Field(0, ge=0, description="Number of days around basis."),
) -> xr.DataArray:
"""Sample the curve of the saturation transformation given parameters.
Parameters
----------
parameters : xr.Dataset
Dataset with the parameters of the saturation transformation.
days : int
Number of days around basis.
Returns
-------
xr.DataArray
Curve of the saturation transformation.
"""
x = np.linspace(-days, days, 100)

coords = {"x": x}

return self._sample_curve(
var_name="saturation",
parameters=parameters,
x=x,
coords=coords,
)


def basis_from_dict(data: dict) -> Basis:
"""Create a basis transformation from a dictionary."""
data = data.copy()
lookup_name = data.pop("lookup_name")
cls = BASIS_TRANSFORMATIONS[lookup_name]

if "priors" in data:
data["priors"] = {k: deserialize(v) for k, v in data["priors"].items()}

return cls(**data)


def _is_basis(data):
return "lookup_name" in data and data["lookup_name"] in BASIS_TRANSFORMATIONS


register_deserialization(
is_type=_is_basis,
deserialize=basis_from_dict,
)


class EventEffect(BaseModel):
"""Event effect associated with an event model."""

basis: InstanceOf[Basis]
effect_size: InstanceOf[Prior]
dims: tuple[str, ...]

def apply(self, X: pt.TensorLike, name: str = "event") -> TensorVariable:
"""Apply the event effect to the data."""
dim_handler = create_dim_handler(("x", *self.dims))
return self.basis.apply(X, dims=self.dims) * dim_handler(
self.effect_size.create_variable(f"{name}_effect_size"),
self.effect_size.dims,
)

def to_dict(self) -> dict:
"""Convert the event effect to a dictionary."""
return {
"class": "EventEffect",
"data": {
"basis": self.basis.to_dict(),
"effect_size": self.effect_size.to_dict(),
"dims": self.dims,
},
}

@classmethod
def from_dict(cls, data: dict) -> "EventEffect":
"""Create an event effect from a dictionary."""
return cls(
basis=deserialize(data["basis"]),
effect_size=deserialize(data["effect_size"]),
dims=data["dims"],
)


def _is_event_effect(data: dict) -> bool:
"""Check if the data is an event effect."""
return data["class"] == "EventEffect"


register_deserialization(
is_type=_is_event_effect,
deserialize=lambda data: EventEffect.from_dict(data["data"]),
)


class GaussianBasis(Basis):
"""Gaussian basis transformation."""

lookup_name = "gaussian"

def function(self, x: pt.TensorLike, sigma: pt.TensorLike) -> TensorVariable:
"""Gaussian bump function."""
return pm.math.exp(-0.5 * (x / sigma) ** 2)

default_priors = {
"sigma": Prior("Gamma", mu=7, sigma=1),
}
Loading

0 comments on commit 542a85b

Please # to comment.