Skip to content

Commit 91bcd37

Browse files
committed
[fix] [test] Adapt the modification of targets to scipy.sparse.xxx_matrix
1 parent e8d7685 commit 91bcd37

File tree

5 files changed

+48
-59
lines changed

5 files changed

+48
-59
lines changed

autoPyTorch/data/base_feature_validator.py

+2-14
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,13 @@
55

66
import pandas as pd
77

8-
import scipy.sparse
9-
108
from sklearn.base import BaseEstimator
119

10+
from autoPyTorch.utils.common import SparseMatrixType
1211
from autoPyTorch.utils.logging_ import PicklableClientLogger
1312

1413

15-
SupportedFeatTypes = Union[
16-
List,
17-
pd.DataFrame,
18-
np.ndarray,
19-
scipy.sparse.bsr_matrix,
20-
scipy.sparse.coo_matrix,
21-
scipy.sparse.csc_matrix,
22-
scipy.sparse.csr_matrix,
23-
scipy.sparse.dia_matrix,
24-
scipy.sparse.dok_matrix,
25-
scipy.sparse.lil_matrix,
26-
]
14+
SupportedFeatTypes = Union[List, pd.DataFrame, np.ndarray, SparseMatrixType]
2715

2816

2917
class BaseFeatureValidator(BaseEstimator):

autoPyTorch/data/base_target_validator.py

+2-15
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,13 @@
55

66
import pandas as pd
77

8-
import scipy.sparse
9-
108
from sklearn.base import BaseEstimator
119

10+
from autoPyTorch.utils.common import SparseMatrixType
1211
from autoPyTorch.utils.logging_ import PicklableClientLogger
1312

1413

15-
SupportedTargetTypes = Union[
16-
List,
17-
pd.Series,
18-
pd.DataFrame,
19-
np.ndarray,
20-
scipy.sparse.bsr_matrix,
21-
scipy.sparse.coo_matrix,
22-
scipy.sparse.csc_matrix,
23-
scipy.sparse.csr_matrix,
24-
scipy.sparse.dia_matrix,
25-
scipy.sparse.dok_matrix,
26-
scipy.sparse.lil_matrix,
27-
]
14+
SupportedTargetTypes = Union[List, pd.Series, pd.DataFrame, np.ndarray, SparseMatrixType]
2815

2916

3017
class BaseTargetValidator(BaseEstimator):

autoPyTorch/data/tabular_target_validator.py

+32-27
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Optional, cast
1+
from typing import List, Optional, Union, cast
22

33
import numpy as np
44

@@ -14,13 +14,37 @@
1414
from sklearn.utils.multiclass import type_of_target
1515

1616
from autoPyTorch.data.base_target_validator import BaseTargetValidator, SupportedTargetTypes
17+
from autoPyTorch.utils.common import SparseMatrixType
1718

1819

19-
def _check_and_to_numpy(y: SupportedTargetTypes) -> np.ndarray:
20+
ArrayType = Union[np.ndarray, SparseMatrixType]
21+
22+
23+
def _check_and_to_array(y: SupportedTargetTypes) -> ArrayType:
2024
""" sklearn check array will make sure we have the correct numerical features for the array """
2125
return sklearn.utils.check_array(y, force_all_finite=True, accept_sparse='csr', ensure_2d=False)
2226

2327

28+
def _modify_regression_target(y: ArrayType) -> ArrayType:
29+
# Regression targets must have numbers after a decimal point.
30+
# Ref: https://github.com/scikit-learn/scikit-learn/issues/8952
31+
y_min = np.abs(y).min()
32+
offset = y_min * 1e-16 # Sufficiently small number
33+
if y_min > 1e15:
34+
raise ValueError(
35+
"The minimum value for the target labels of regression tasks must be smaller than "
36+
f"1e15 to avoid errors caused by an overflow, but got {y_min}"
37+
)
38+
39+
# Since it is all integer, we can just add a random small number
40+
if isinstance(y, np.ndarray):
41+
y = y.astype(dtype=np.float64) + offset
42+
else:
43+
y.data = y.data.astype(dtype=np.float64) + offset
44+
45+
return y
46+
47+
2448
class TabularTargetValidator(BaseTargetValidator):
2549
def _fit(
2650
self,
@@ -101,7 +125,7 @@ def _fit(
101125

102126
def _transform_by_encoder(self, y: SupportedTargetTypes) -> np.ndarray:
103127
if self.encoder is None:
104-
return _check_and_to_numpy(y)
128+
return _check_and_to_array(y)
105129

106130
# remove ravel warning from pandas Series
107131
shape = np.shape(y)
@@ -115,12 +139,9 @@ def _transform_by_encoder(self, y: SupportedTargetTypes) -> np.ndarray:
115139
else:
116140
y = self.encoder.transform(np.array(y).reshape(-1, 1)).reshape(-1)
117141

118-
return _check_and_to_numpy(y)
142+
return _check_and_to_array(y)
119143

120-
def transform(
121-
self,
122-
y: SupportedTargetTypes,
123-
) -> np.ndarray:
144+
def transform(self, y: SupportedTargetTypes) -> np.ndarray:
124145
"""
125146
Validates and fit a categorical encoder (if needed) to the features.
126147
The supported data types are List, numpy arrays and pandas DataFrames.
@@ -146,24 +167,11 @@ def transform(
146167
y = np.ravel(y)
147168

148169
if not self.is_classification and "continuous" not in type_of_target(y):
149-
# Regression targets must have numbers after a decimal point.
150-
# Ref: https://github.com/scikit-learn/scikit-learn/issues/8952
151-
y_min = np.abs(y).min()
152-
offset = y_min * 1e-16 # Sufficiently small number
153-
if y_min > 1e15:
154-
raise ValueError(
155-
"The minimum value for the target labels of regression tasks must be smaller than "
156-
f"1e15 to avoid errors caused by an overflow, but got {y_min}"
157-
)
158-
159-
y = y.astype(dtype=np.float64) + offset # Since it is all integer, we can just add a random small number
170+
y = _modify_regression_target(y)
160171

161172
return y
162173

163-
def inverse_transform(
164-
self,
165-
y: SupportedTargetTypes,
166-
) -> np.ndarray:
174+
def inverse_transform(self, y: SupportedTargetTypes) -> np.ndarray:
167175
"""
168176
Revert any encoding transformation done on a target array
169177
@@ -197,10 +205,7 @@ def inverse_transform(
197205
y = y.astype(self.dtype)
198206
return y
199207

200-
def _check_data(
201-
self,
202-
y: SupportedTargetTypes,
203-
) -> None:
208+
def _check_data(self, y: SupportedTargetTypes) -> None:
204209
"""
205210
Perform dimensionality and data type checks on the targets
206211

autoPyTorch/utils/common.py

+9
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,15 @@
2020
from torch.utils.data.dataloader import default_collate
2121

2222
HyperparameterValueType = Union[int, str, float]
23+
SparseMatrixType = Union[
24+
scipy.sparse.bsr_matrix,
25+
scipy.sparse.coo_matrix,
26+
scipy.sparse.csc_matrix,
27+
scipy.sparse.csr_matrix,
28+
scipy.sparse.dia_matrix,
29+
scipy.sparse.dok_matrix,
30+
scipy.sparse.lil_matrix,
31+
]
2332

2433

2534
class FitRequirement(NamedTuple):

test/test_data/test_target_validator.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -150,17 +150,17 @@ def test_targetvalidator_supported_types_noclassification(input_data_targettest)
150150
assert validator.encoder is None
151151

152152
if hasattr(input_data_targettest, "iloc"):
153-
np.testing.assert_array_equal(
153+
assert np.allclose(
154154
np.ravel(input_data_targettest.to_numpy()),
155155
np.ravel(transformed_y)
156156
)
157157
elif sparse.issparse(input_data_targettest):
158-
np.testing.assert_array_equal(
158+
assert np.allclose(
159159
np.ravel(input_data_targettest.todense()),
160160
np.ravel(transformed_y.todense())
161161
)
162162
else:
163-
np.testing.assert_array_equal(
163+
assert np.allclose(
164164
np.ravel(np.array(input_data_targettest)),
165165
np.ravel(transformed_y)
166166
)

0 commit comments

Comments
 (0)