diff --git a/causalpy/pymc_experiments.py b/causalpy/pymc_experiments.py index 502c0d41..c07bb1a6 100644 --- a/causalpy/pymc_experiments.py +++ b/causalpy/pymc_experiments.py @@ -26,7 +26,7 @@ from causalpy.custom_exceptions import BadIndexException # NOQA from causalpy.custom_exceptions import DataException, FormulaException from causalpy.plot_utils import plot_xY -from causalpy.utils import _is_variable_dummy_coded, _series_has_2_levels +from causalpy.utils import _is_variable_dummy_coded LEGEND_FONT_SIZE = 12 az.style.use("arviz-darkgrid") @@ -978,7 +978,8 @@ class PrePostNEGD(ExperimentalDesign): :param formula: A statistical model formula :param group_variable_name: - Name of the column in data for the group variable + Name of the column in data for the group variable, should be either + binary or boolean :param pretreatment_variable_name: Name of the column in data for the pretreatment variable :param model: @@ -1058,17 +1059,19 @@ def __init__( self.group_variable_name: np.zeros(self.pred_xi.shape), } ) - (new_x,) = build_design_matrices([self._x_design_info], x_pred_untreated) - self.pred_untreated = self.model.predict(X=np.asarray(new_x)) + (new_x_untreated,) = build_design_matrices( + [self._x_design_info], x_pred_untreated + ) + self.pred_untreated = self.model.predict(X=np.asarray(new_x_untreated)) # treated - x_pred_untreated = pd.DataFrame( + x_pred_treated = pd.DataFrame( { self.pretreatment_variable_name: self.pred_xi, self.group_variable_name: np.ones(self.pred_xi.shape), } ) - (new_x,) = build_design_matrices([self._x_design_info], x_pred_untreated) - self.pred_treated = self.model.predict(X=np.asarray(new_x)) + (new_x_treated,) = build_design_matrices([self._x_design_info], x_pred_treated) + self.pred_treated = self.model.predict(X=np.asarray(new_x_treated)) # Evaluate causal impact as equal to the trestment effect self.causal_impact = self.idata.posterior["beta"].sel( @@ -1079,7 +1082,7 @@ def __init__( def _input_validation(self) -> None: """Validate the input data and model formula for correctness""" - if not _series_has_2_levels(self.data[self.group_variable_name]): + if not _is_variable_dummy_coded(self.data[self.group_variable_name]): raise DataException( f""" There must be 2 levels of the grouping variable @@ -1165,7 +1168,7 @@ def _get_treatment_effect_coeff(self) -> str: then we want `C(group)[T.1]`. """ for label in self.labels: - if ("group" in label) & (":" not in label): + if (self.group_variable_name in label) & (":" not in label): return label raise NameError("Unable to find coefficient name for the treatment effect")