@@ -38,7 +38,7 @@ def __init__(self,
38
38
shape : tuple ,
39
39
kappa : float = 0. ,
40
40
beta : float = .1 ,
41
- feature_range : tuple = (- 1e10 , 1e10 ),
41
+ feature_range : Tuple [ Union [ float , np . ndarray ], Union [ float , np . ndarray ]] = (- 1e10 , 1e10 ),
42
42
gamma : float = 0. ,
43
43
ae_model : Optional [tf .keras .Model ] = None ,
44
44
enc_model : Optional [tf .keras .Model ] = None ,
@@ -178,7 +178,9 @@ def __init__(self,
178
178
self .max_iterations = max_iterations
179
179
self .c_init = c_init
180
180
self .c_steps = c_steps
181
- self .feature_range = feature_range
181
+ self .feature_range = tuple ([(np .ones (shape [1 :]) * feature_range [_ ])[None , :]
182
+ if isinstance (feature_range [_ ], float ) else feature_range [_ ]
183
+ for _ in range (2 )])
182
184
self .update_num_grad = update_num_grad
183
185
self .eps = eps
184
186
self .clip = clip
@@ -754,13 +756,13 @@ def fit(self,
754
756
755
757
# multidim scaled distances
756
758
d_abs_abdm , _ = multidim_scaling (d_abdm , n_components = 2 , use_metric = True ,
757
- feature_range = self .feature_range ,
759
+ feature_range = self .feature_range , # type: ignore[arg-type]
758
760
standardize_cat_vars = standardize_cat_vars ,
759
761
smooth = smooth , center = center ,
760
762
update_feature_range = False )
761
763
762
764
d_abs_mvdm , _ = multidim_scaling (d_mvdm , n_components = 2 , use_metric = True ,
763
- feature_range = self .feature_range ,
765
+ feature_range = self .feature_range , # type: ignore[arg-type]
764
766
standardize_cat_vars = standardize_cat_vars ,
765
767
smooth = smooth , center = center ,
766
768
update_feature_range = False )
@@ -779,7 +781,7 @@ def fit(self,
779
781
self .feature_range = new_feature_range
780
782
else : # apply multidimensional scaling for the abdm or mvdm distances
781
783
self .d_abs , self .feature_range = multidim_scaling (d_pair , n_components = 2 , use_metric = True ,
782
- feature_range = self .feature_range ,
784
+ feature_range = self .feature_range , # type: ignore
783
785
standardize_cat_vars = standardize_cat_vars ,
784
786
smooth = smooth , center = center ,
785
787
update_feature_range = update_feature_range )
0 commit comments