From 4ef8348faaca3c65609da0fa1a4009e27de0f0f6 Mon Sep 17 00:00:00 2001 From: Janis Klaise Date: Mon, 5 Sep 2022 11:16:16 +0100 Subject: [PATCH] Fix edge case in AnchorTabular where no samples satisfying the anchor exist in the train data --- alibi/explainers/anchors/anchor_tabular.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/alibi/explainers/anchors/anchor_tabular.py b/alibi/explainers/anchors/anchor_tabular.py index 5405ab7b8..d6d4d7dd0 100644 --- a/alibi/explainers/anchors/anchor_tabular.py +++ b/alibi/explainers/anchors/anchor_tabular.py @@ -290,6 +290,10 @@ def perturbation(self, anchor: tuple, num_samples: int) -> Tuple[np.ndarray, np. [allowed_rows[feat] for feat in uniq_feat_ids], np.intersect1d), ) + if partial_anchor_rows == []: + # edge case - if there are no rows at all then `partial_anchor_rows` is the empty list, but it should + # be a list of an empty array to not cause an error in calculating coverage (which will be 0) + partial_anchor_rows = [np.array([], dtype=int)] nb_partial_anchors = np.array([len(n_records) for n_records in reversed(partial_anchor_rows)]) # reverse required for np.searchsorted later coverage = nb_partial_anchors[0] / self.n_records # since we sorted, the correct coverage is first not last @@ -383,7 +387,14 @@ def replace_features(self, samples: np.ndarray, allowed_rows: Dict[int, Any], un requested_samples = num_samples start, n_anchor_feats = 0, len(partial_anchor_rows) uniq_feat_ids = list(reversed(uniq_feat_ids)) - start_idx = np.nonzero(nb_partial_anchors)[0][0] # skip anchors with no samples in the database + + try: + start_idx = np.nonzero(nb_partial_anchors)[0][0] # skip anchors with no samples in the database + except IndexError: + # there are no samples in the database, need to break out of the function + # and go straight to treating unknown features + return + end_idx = np.searchsorted(np.cumsum(nb_partial_anchors), num_samples) # replace partial anchors with partial anchors drawn from the training dataset