Fix edge case in AnchorTabular where no samples satisfying the anchor… #742

Merged
13 changes: 12 additions & 1 deletion alibi/explainers/anchors/anchor_tabular.py
@@ -290,6 +290,10 @@ def perturbation(self, anchor: tuple, num_samples: int) -> Tuple[np.ndarray, np.
             [allowed_rows[feat] for feat in uniq_feat_ids],
             np.intersect1d),
         )
+        if partial_anchor_rows == []:
+            # edge case - if there are no rows at all then `partial_anchor_rows` is the empty list, but it should
+            # be a list of an empty array to not cause an error in calculating coverage (which will be 0)
+            partial_anchor_rows = [np.array([], dtype=int)]
         nb_partial_anchors = np.array([len(n_records) for n_records in
                                        reversed(partial_anchor_rows)])  # reverse required for np.searchsorted later
         coverage = nb_partial_anchors[0] / self.n_records  # since we sorted, the correct coverage is first not last
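
The hunk above can be reproduced outside the library. The following is a minimal sketch, assuming the accumulated list comes back empty (for example when no feature ids are passed in); `coverage_from_rows` is a hypothetical helper written for illustration and is not part of alibi's API.

```python
from itertools import accumulate

import numpy as np


def coverage_from_rows(rows_per_feature, n_records):
    """Hypothetical stand-in for the patched lines in `perturbation`."""
    # Running intersection of the row indices allowed by each feature.
    partial_anchor_rows = list(accumulate(rows_per_feature, np.intersect1d))
    if partial_anchor_rows == []:
        # Same guard as in the patch: keep one empty array so that the
        # count array below has a first element to index.
        partial_anchor_rows = [np.array([], dtype=int)]
    nb_partial_anchors = np.array(
        [len(rows) for rows in reversed(partial_anchor_rows)]
    )
    # After reversing, the count for the full anchor comes first.
    return nb_partial_anchors[0] / n_records


# No rows at all: coverage is 0 instead of an IndexError.
empty_coverage = coverage_from_rows([], n_records=10)

# Two features whose allowed rows overlap in rows 2 and 3.
overlap_coverage = coverage_from_rows(
    [np.array([0, 1, 2, 3]), np.array([2, 3, 4])], n_records=10
)
```

Without the guard, `np.array([])` has no element at index 0, so indexing it for the coverage raises `IndexError`.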
@@ -383,7 +387,14 @@ def replace_features(self, samples: np.ndarray, allowed_rows: Dict[int, Any], un
         requested_samples = num_samples
         start, n_anchor_feats = 0, len(partial_anchor_rows)
         uniq_feat_ids = list(reversed(uniq_feat_ids))
-        start_idx = np.nonzero(nb_partial_anchors)[0][0]  # skip anchors with no samples in the database
+
+        try:
+            start_idx = np.nonzero(nb_partial_anchors)[0][0]  # skip anchors with no samples in the database
+        except IndexError:
+            # there are no samples in the database, need to break out of the function
+            # and go straight to treating unknown features
+            return
+
         end_idx = np.searchsorted(np.cumsum(nb_partial_anchors), num_samples)
 
         # replace partial anchors with partial anchors drawn from the training dataset
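
The `try`/`except` in `replace_features` relies on how `np.nonzero` behaves when every count is zero: it returns a tuple holding an empty index array, and taking `[0][0]` of that raises `IndexError`. A minimal sketch of that behaviour follows; `first_populated_index` is a hypothetical helper used only for illustration, and it returns `None` where the patched method returns early.

```python
import numpy as np


def first_populated_index(nb_partial_anchors):
    """Hypothetical helper: index of the first partial anchor with samples."""
    try:
        # np.nonzero returns a tuple of index arrays, one per dimension.
        return int(np.nonzero(nb_partial_anchors)[0][0])
    except IndexError:
        # Every count is zero, so there is nothing to draw from.
        return None


some_samples = first_populated_index(np.array([0, 0, 3, 5]))
no_samples = first_populated_index(np.array([0]))
```

Here `some_samples` is the position of the first non-zero count, while `no_samples` is `None`, which corresponds to the case the patch handles by returning before `end_idx` is computed.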