Skip to content

Commit 56f04e0

Browse files
authored
fix types multidim scaling (#646)
* fix types multidim scaling * fix types * add error type
1 parent a3d9023 commit 56f04e0

File tree

2 files changed

+11
-9
lines changed

2 files changed

+11
-9
lines changed

alibi/explainers/cfproto.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __init__(self,
3838
shape: tuple,
3939
kappa: float = 0.,
4040
beta: float = .1,
41-
feature_range: tuple = (-1e10, 1e10),
41+
feature_range: Tuple[Union[float, np.ndarray], Union[float, np.ndarray]] = (-1e10, 1e10),
4242
gamma: float = 0.,
4343
ae_model: Optional[tf.keras.Model] = None,
4444
enc_model: Optional[tf.keras.Model] = None,
@@ -178,7 +178,9 @@ def __init__(self,
178178
self.max_iterations = max_iterations
179179
self.c_init = c_init
180180
self.c_steps = c_steps
181-
self.feature_range = feature_range
181+
self.feature_range = tuple([(np.ones(shape[1:]) * feature_range[_])[None, :]
182+
if isinstance(feature_range[_], float) else feature_range[_]
183+
for _ in range(2)])
182184
self.update_num_grad = update_num_grad
183185
self.eps = eps
184186
self.clip = clip
@@ -754,13 +756,13 @@ def fit(self,
754756

755757
# multidim scaled distances
756758
d_abs_abdm, _ = multidim_scaling(d_abdm, n_components=2, use_metric=True,
757-
feature_range=self.feature_range,
759+
feature_range=self.feature_range, # type: ignore[arg-type]
758760
standardize_cat_vars=standardize_cat_vars,
759761
smooth=smooth, center=center,
760762
update_feature_range=False)
761763

762764
d_abs_mvdm, _ = multidim_scaling(d_mvdm, n_components=2, use_metric=True,
763-
feature_range=self.feature_range,
765+
feature_range=self.feature_range, # type: ignore[arg-type]
764766
standardize_cat_vars=standardize_cat_vars,
765767
smooth=smooth, center=center,
766768
update_feature_range=False)
@@ -779,7 +781,7 @@ def fit(self,
779781
self.feature_range = new_feature_range
780782
else: # apply multidimensional scaling for the abdm or mvdm distances
781783
self.d_abs, self.feature_range = multidim_scaling(d_pair, n_components=2, use_metric=True,
782-
feature_range=self.feature_range,
784+
feature_range=self.feature_range, # type: ignore
783785
standardize_cat_vars=standardize_cat_vars,
784786
smooth=smooth, center=center,
785787
update_feature_range=update_feature_range)

alibi/utils/distance.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def abdm(X: np.ndarray,
162162

163163

164164
def multidim_scaling(d_pair: dict,
165-
feature_range: tuple,
165+
feature_range: Tuple[np.ndarray, np.ndarray],
166166
n_components: int = 2,
167167
use_metric: bool = True,
168168
standardize_cat_vars: bool = True,
@@ -178,8 +178,8 @@ def multidim_scaling(d_pair: dict,
178178
Dict with as keys the column index of the categorical variables and as values
179179
a pairwise distance matrix for the categories of the variable.
180180
feature_range
181-
Tuple with `min` and `max` ranges to allow for perturbed instances. `Min` and `max` ranges can be `float` or
182-
`numpy` arrays with dimension (`1 x nb of features`) for feature-wise ranges.
181+
Tuple with `min` and `max` ranges to allow for perturbed instances. `Min` and `max` ranges are
182+
`numpy` arrays with dimension (`1 x nb of features`).
183183
n_components
184184
Number of dimensions in which to immerse the dissimilarities.
185185
use_metric
@@ -240,6 +240,6 @@ def multidim_scaling(d_pair: dict,
240240
d_abs_scaled[k] = d_scaled # scaled distance from the origin for each category
241241

242242
if update_feature_range:
243-
feature_range = new_feature_range
243+
feature_range = new_feature_range # type: ignore
244244

245245
return d_abs_scaled, feature_range

0 commit comments

Comments
 (0)