Add a method to visualize the policy values of multiple policies #103

Merged · 3 commits · Jun 6, 2021
obp/ope/meta.py: 104 additions, 0 deletions
@@ -524,3 +524,107 @@ def summarize_estimators_comparison(
index=[metric],
)
return eval_metric_ope_df.T

def visualize_off_policy_estimates_of_multiple_policies(
self,
policy_name_list: List[str],
action_dist_list: List[np.ndarray],
estimated_rewards_by_reg_model: Optional[
Union[np.ndarray, Dict[str, np.ndarray]]
] = None,
alpha: float = 0.05,
is_relative: bool = False,
n_bootstrap_samples: int = 100,
random_state: Optional[int] = None,
fig_dir: Optional[Path] = None,
fig_name: str = "estimated_policy_value.png",
) -> None:
"""Visualize policy values estimated by OPE estimators.

Parameters
----------
policy_name_list: List[str]
List of the names of policies.

action_dist_list: List[array-like, shape (n_rounds, n_actions, len_list)]
List of action choice probabilities by the evaluation policies (can be deterministic), i.e., :math:`\\pi_e(a_t|x_t)`.

estimated_rewards_by_reg_model: array-like, shape (n_rounds, n_actions, len_list) or Dict[str, array-like], default=None
Expected rewards for each round, action, and position estimated by a regression model, i.e., :math:`\\hat{q}(x_t,a_t)`.
When an array-like is given, all OPE estimators use it.
When a dict is given, if the dict has the name of an estimator as a key, the corresponding value is used.
When it is not given, model-dependent estimators such as DM and DR cannot be used.

alpha: float, default=0.05
            Significance level of confidence intervals.

n_bootstrap_samples: int, default=100
            Number of resampling iterations performed in the bootstrap procedure.

random_state: int, default=None
Controls the random seed in bootstrap sampling.

        is_relative: bool, default=False
            If True, the method visualizes the estimated policy values of the evaluation policies
            relative to the ground-truth policy value of the behavior policy.

fig_dir: Path, default=None
Path to store the bar figure.
If 'None' is given, the figure will not be saved.

fig_name: str, default="estimated_policy_value.png"
Name of the bar figure.

"""
if len(policy_name_list) != len(action_dist_list):
raise ValueError(
"the length of policy_name_list must be the same as action_dist_list"
)
if fig_dir is not None:
assert isinstance(fig_dir, Path), "fig_dir must be a Path"
if fig_name is not None:
            assert isinstance(fig_name, str), "fig_name must be a string"

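        # round-wise reward estimates: one inner dict per OPE estimator, keyed by policy name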
estimated_round_rewards_dict = {
estimator_name: {} for estimator_name in self.ope_estimators_
}

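        # compute round-wise reward estimates for every (policy, estimator) pair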
for policy_name, action_dist in zip(policy_name_list, action_dist_list):
estimator_inputs = self._create_estimator_inputs(
action_dist=action_dist,
estimated_rewards_by_reg_model=estimated_rewards_by_reg_model,
)
for estimator_name, estimator in self.ope_estimators_.items():
estimated_round_rewards_dict[estimator_name][
policy_name
] = estimator._estimate_round_rewards(
**estimator_inputs[estimator_name]
)

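        # draw one barplot per estimator; bar heights are mean estimates, error bars are bootstrap CIs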
plt.style.use("ggplot")
fig = plt.figure(figsize=(8, 6.2 * len(self.ope_estimators_)))

for i, estimator_name in enumerate(self.ope_estimators_):
estimated_round_rewards_df = DataFrame(
estimated_round_rewards_dict[estimator_name]
)
if is_relative:
estimated_round_rewards_df /= self.bandit_feedback["reward"].mean()

            ax = fig.add_subplot(len(self.ope_estimators_), 1, i + 1)  # one panel per estimator
sns.barplot(
data=estimated_round_rewards_df,
ax=ax,
ci=100 * (1 - alpha),
n_boot=n_bootstrap_samples,
seed=random_state,
)
ax.set_title(estimator_name.upper(), fontsize=20)
ax.set_ylabel(
f"Estimated Policy Value (± {np.int(100*(1 - alpha))}% CI)", fontsize=20
)
plt.yticks(fontsize=15)
plt.xticks(fontsize=25 - 2 * len(policy_name_list))

if fig_dir:
fig.savefig(str(fig_dir / fig_name))
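
For context, here is a minimal usage sketch of the new method. It is not taken from this PR: the synthetic setup via obp's `SyntheticBanditDataset`, `OffPolicyEvaluation`, and `InverseProbabilityWeighting` is assumed, and the two hand-built action distributions (`uniform` and `always_first`) are illustrative stand-ins for real evaluation policies.

```python
import numpy as np

from obp.dataset import SyntheticBanditDataset
from obp.ope import InverseProbabilityWeighting, OffPolicyEvaluation

# synthetic logged bandit feedback standing in for real data
dataset = SyntheticBanditDataset(n_actions=10, random_state=12345)
bandit_feedback = dataset.obtain_batch_bandit_feedback(n_rounds=1000)

n_rounds, n_actions = bandit_feedback["n_rounds"], bandit_feedback["n_actions"]
# evaluation policy 1: uniformly random over actions
uniform = np.full((n_rounds, n_actions, 1), 1.0 / n_actions)
# evaluation policy 2: deterministically plays action 0
always_first = np.zeros((n_rounds, n_actions, 1))
always_first[:, 0, :] = 1.0

ope = OffPolicyEvaluation(
    bandit_feedback=bandit_feedback,
    ope_estimators=[InverseProbabilityWeighting()],
)
ope.visualize_off_policy_estimates_of_multiple_policies(
    policy_name_list=["uniform", "always_first"],
    action_dist_list=[uniform, always_first],
    random_state=12345,
)
```

Passing `is_relative=True` would instead plot each estimate divided by the mean observed reward of the behavior policy, and `fig_dir`/`fig_name` control whether and where the figure is saved.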