Merge pull request #103 from Kurorororo/visualize-multiple-policy-values
Add a method to visualize the policy values of multiple policies
usaito authored Jun 6, 2021
2 parents 7f193e4 + 30dae82 commit 21d3edf
Showing 1 changed file with 104 additions and 0 deletions: obp/ope/meta.py
@@ -524,3 +524,107 @@ def summarize_estimators_comparison(
index=[metric],
)
return eval_metric_ope_df.T

def visualize_off_policy_estimates_of_multiple_policies(
self,
policy_name_list: List[str],
action_dist_list: List[np.ndarray],
estimated_rewards_by_reg_model: Optional[
Union[np.ndarray, Dict[str, np.ndarray]]
] = None,
alpha: float = 0.05,
is_relative: bool = False,
n_bootstrap_samples: int = 100,
random_state: Optional[int] = None,
fig_dir: Optional[Path] = None,
fig_name: str = "estimated_policy_value.png",
) -> None:
"""Visualize policy values estimated by OPE estimators.
Parameters
----------
policy_name_list: List[str]
        List of the names of the evaluation policies.
action_dist_list: List[array-like, shape (n_rounds, n_actions, len_list)]
List of action choice probabilities by the evaluation policies (can be deterministic), i.e., :math:`\\pi_e(a_t|x_t)`.
estimated_rewards_by_reg_model: array-like, shape (n_rounds, n_actions, len_list) or Dict[str, array-like], default=None
Expected rewards for each round, action, and position estimated by a regression model, i.e., :math:`\\hat{q}(x_t,a_t)`.
When an array-like is given, all OPE estimators use it.
When a dict is given, if the dict has the name of an estimator as a key, the corresponding value is used.
When it is not given, model-dependent estimators such as DM and DR cannot be used.
    alpha: float, default=0.05
        Significance level of the confidence intervals.
    is_relative: bool, default=False
        If True, the estimated policy values of the evaluation policies are visualized
        relative to the ground-truth policy value of the behavior policy (i.e., the mean observed reward).
    n_bootstrap_samples: int, default=100
        Number of resamplings performed in the bootstrap procedure.
    random_state: int, default=None
        Controls the random seed in bootstrap sampling.
fig_dir: Path, default=None
Path to store the bar figure.
If 'None' is given, the figure will not be saved.
fig_name: str, default="estimated_policy_value.png"
Name of the bar figure.
"""
if len(policy_name_list) != len(action_dist_list):
raise ValueError(
"the length of policy_name_list must be the same as action_dist_list"
)
if fig_dir is not None:
assert isinstance(fig_dir, Path), "fig_dir must be a Path"
if fig_name is not None:
        assert isinstance(fig_name, str), "fig_name must be a string"

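    # nested mapping: estimator name -> policy name -> round-wise reward estimates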
estimated_round_rewards_dict = {
estimator_name: {} for estimator_name in self.ope_estimators_
}

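    # compute the round-wise reward estimates of every (policy, estimator) pair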
for policy_name, action_dist in zip(policy_name_list, action_dist_list):
estimator_inputs = self._create_estimator_inputs(
action_dist=action_dist,
estimated_rewards_by_reg_model=estimated_rewards_by_reg_model,
)
for estimator_name, estimator in self.ope_estimators_.items():
estimated_round_rewards_dict[estimator_name][
policy_name
] = estimator._estimate_round_rewards(
**estimator_inputs[estimator_name]
)

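    # draw one bar chart per estimator; bars are policies, with bootstrap confidence intervals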
plt.style.use("ggplot")
fig = plt.figure(figsize=(8, 6.2 * len(self.ope_estimators_)))

for i, estimator_name in enumerate(self.ope_estimators_):
estimated_round_rewards_df = DataFrame(
estimated_round_rewards_dict[estimator_name]
)
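        # express the estimates relative to the mean observed reward of the behavior policy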
if is_relative:
estimated_round_rewards_df /= self.bandit_feedback["reward"].mean()

        ax = fig.add_subplot(len(self.ope_estimators_), 1, i + 1)  # one subplot per estimator
sns.barplot(
data=estimated_round_rewards_df,
ax=ax,
ci=100 * (1 - alpha),
n_boot=n_bootstrap_samples,
seed=random_state,
)
ax.set_title(estimator_name.upper(), fontsize=20)
        ax.set_ylabel(
            f"Estimated Policy Value (± {int(100 * (1 - alpha))}% CI)", fontsize=20
        )
plt.yticks(fontsize=15)
plt.xticks(fontsize=25 - 2 * len(policy_name_list))

if fig_dir:
fig.savefig(str(fig_dir / fig_name))

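For reference, here is a minimal usage sketch of the new method against a synthetic dataset (the dataset setup, the two toy policies, and all variable names below are illustrative assumptions, not part of this commit):

import numpy as np
from obp.dataset import SyntheticBanditDataset
from obp.ope import InverseProbabilityWeighting, OffPolicyEvaluation

# simulate logged bandit feedback from a behavior policy
dataset = SyntheticBanditDataset(n_actions=10, random_state=12345)
bandit_feedback = dataset.obtain_batch_bandit_feedback(n_rounds=10000)

# two toy evaluation policies: uniform-random and "always play action 0"
n_rounds, n_actions = bandit_feedback["n_rounds"], dataset.n_actions
uniform = np.full((n_rounds, n_actions, 1), 1.0 / n_actions)
greedy = np.zeros((n_rounds, n_actions, 1))
greedy[:, 0, :] = 1.0

ope = OffPolicyEvaluation(
    bandit_feedback=bandit_feedback,
    ope_estimators=[InverseProbabilityWeighting()],
)
ope.visualize_off_policy_estimates_of_multiple_policies(
    policy_name_list=["uniform", "always-0"],
    action_dist_list=[uniform, greedy],
    is_relative=True,
    random_state=12345,
)

With is_relative=True, each bar shows the estimated policy value divided by the mean observed reward of the behavior policy, so values above 1 indicate an estimated improvement over the logging policy.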