Skip to content

Commit 466bc18

Browse files
[ADD] variance thresholding (#373)
* add variance thresholding * fix flake and mypy * Apply suggestions from code review Co-authored-by: nabenabe0928 <47781922+nabenabe0928@users.noreply.github.com> Co-authored-by: nabenabe0928 <47781922+nabenabe0928@users.noreply.github.com>
1 parent bd4fabf commit 466bc18

File tree

6 files changed

+102
-0
lines changed

6 files changed

+102
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from typing import Any, Dict, Optional, Union
2+
3+
import numpy as np
4+
5+
from sklearn.feature_selection import VarianceThreshold as SklearnVarianceThreshold
6+
7+
from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
8+
from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.base_tabular_preprocessing import \
9+
autoPyTorchTabularPreprocessingComponent
10+
11+
12+
class VarianceThreshold(autoPyTorchTabularPreprocessingComponent):
    """
    Removes features that have the same value in the training data.

    The component does not transform any data itself. ``fit`` creates an
    unfitted scikit-learn ``VarianceThreshold`` (``threshold=0.0``) for the
    numerical columns, and ``transform`` publishes the preprocessor mapping in
    the fit dictionary under the ``'variance_threshold'`` key so that a later
    step can fit and apply it.

    Args:
        random_state (Optional[np.random.RandomState]):
            Stored on the instance as ``random_state``. It is not used by the
            thresholding itself, which is deterministic.
    """
    def __init__(self, random_state: Optional[np.random.RandomState] = None):
        super().__init__()
        # Store the constructor argument instead of dropping it: scikit-learn
        # style estimators expose every __init__ keyword as an attribute of the
        # same name, and get_params()/clone() read it back from there.
        self.random_state = random_state

    def fit(self, X: Dict[str, Any], y: Optional[Any] = None) -> 'VarianceThreshold':
        """
        Create the scikit-learn selector used for the numerical columns.

        Args:
            X (Dict[str, Any]): fit dictionary; it is only passed to
                ``check_requirements`` here.
            y (Optional[Any]): unused apart from being forwarded to
                ``check_requirements``.

        Returns:
            VarianceThreshold: this instance, to allow call chaining.
        """
        # NOTE(review): check_requirements is inherited; presumably it validates
        # that X carries the keys this component needs -- confirm in the base class.
        self.check_requirements(X, y)

        # threshold=0.0 keeps every feature whose variance is non-zero, i.e. it
        # only discards constant columns.
        self.preprocessor['numerical'] = SklearnVarianceThreshold(
            threshold=0.0
        )
        return self

    def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
        """
        Add the preprocessor mapping to the fit dictionary.

        Args:
            X (Dict[str, Any]): fit dictionary, updated in place.

        Returns:
            Dict[str, Any]: the same dictionary with the ``'variance_threshold'``
                entry set to this component's ``preprocessor`` mapping.

        Raises:
            ValueError: if no numerical preprocessor has been created yet,
                i.e. ``fit`` was not called first.
        """
        if self.preprocessor['numerical'] is None:
            raise ValueError("cannot call transform on {} without fitting first."
                             .format(self.__class__.__name__))
        X.update({'variance_threshold': self.preprocessor})
        return X

    @staticmethod
    def get_properties(
        dataset_properties: Optional[Dict[str, BaseDatasetPropertiesType]] = None
    ) -> Dict[str, Union[str, bool]]:
        """
        Describe this component.

        Args:
            dataset_properties (Optional[Dict[str, BaseDatasetPropertiesType]]):
                accepted for interface compatibility; not used.

        Returns:
            Dict[str, Union[str, bool]]: short name, full name and the
                ``handles_sparse`` flag of the component.
        """
        return {
            'shortname': 'Variance Threshold',
            'name': 'Variance Threshold (constant feature removal)',
            'handles_sparse': True,
        }

autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/variance_thresholding/__init__.py

Whitespace-only changes.

autoPyTorch/pipeline/tabular_classification.py

+3
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
)
2828
from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.imputation.SimpleImputer import SimpleImputer
2929
from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.scaling import ScalerChoice
30+
from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.variance_thresholding. \
31+
VarianceThreshold import VarianceThreshold
3032
from autoPyTorch.pipeline.components.setup.early_preprocessor.EarlyPreprocessing import EarlyPreprocessing
3133
from autoPyTorch.pipeline.components.setup.lr_scheduler import SchedulerChoice
3234
from autoPyTorch.pipeline.components.setup.network.base_network import NetworkComponent
@@ -307,6 +309,7 @@ def _get_pipeline_steps(
307309

308310
steps.extend([
309311
("imputer", SimpleImputer(random_state=self.random_state)),
312+
("variance_threshold", VarianceThreshold(random_state=self.random_state)),
310313
("encoder", EncoderChoice(default_dataset_properties, random_state=self.random_state)),
311314
("scaler", ScalerChoice(default_dataset_properties, random_state=self.random_state)),
312315
("feature_preprocessor", FeatureProprocessorChoice(default_dataset_properties,

autoPyTorch/pipeline/tabular_regression.py

+3
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
)
2828
from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.imputation.SimpleImputer import SimpleImputer
2929
from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.scaling import ScalerChoice
30+
from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.variance_thresholding. \
31+
VarianceThreshold import VarianceThreshold
3032
from autoPyTorch.pipeline.components.setup.early_preprocessor.EarlyPreprocessing import EarlyPreprocessing
3133
from autoPyTorch.pipeline.components.setup.lr_scheduler import SchedulerChoice
3234
from autoPyTorch.pipeline.components.setup.network.base_network import NetworkComponent
@@ -257,6 +259,7 @@ def _get_pipeline_steps(
257259

258260
steps.extend([
259261
("imputer", SimpleImputer(random_state=self.random_state)),
262+
("variance_threshold", VarianceThreshold(random_state=self.random_state)),
260263
("encoder", EncoderChoice(default_dataset_properties, random_state=self.random_state)),
261264
("scaler", ScalerChoice(default_dataset_properties, random_state=self.random_state)),
262265
("feature_preprocessor", FeatureProprocessorChoice(default_dataset_properties,

test/test_pipeline/components/preprocessing/base.py

+3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.encoding import EncoderChoice
77
from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.imputation.SimpleImputer import SimpleImputer
88
from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.scaling import ScalerChoice
9+
from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.variance_thresholding. \
10+
VarianceThreshold import VarianceThreshold
911
from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline
1012

1113

@@ -28,6 +30,7 @@ def _get_pipeline_steps(self, dataset_properties: Optional[Dict[str, Any]],
2830

2931
steps.extend([
3032
("imputer", SimpleImputer()),
33+
("variance_threshold", VarianceThreshold()),
3134
("encoder", EncoderChoice(default_dataset_properties)),
3235
("scaler", ScalerChoice(default_dataset_properties)),
3336
("tabular_transformer", TabularColumnTransformer()),
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import numpy as np
2+
from numpy.testing import assert_array_equal
3+
4+
5+
from sklearn.base import BaseEstimator
6+
from sklearn.compose import make_column_transformer
7+
8+
from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.variance_thresholding. \
9+
VarianceThreshold import VarianceThreshold
10+
11+
12+
def test_variance_threshold():
    """
    Check that VarianceThreshold publishes a usable selector in the fit dictionary.

    The third column equals 1 in every training row, so a column transformer
    built from the returned selector is expected to drop that column from the
    held-out rows.
    """
    features = np.array([
        [1, 2, 1],
        [7, 8, 9],
        [4, 5, 1],
        [11, 12, 1],
        [17, 18, 19],
        [14, 15, 16],
    ])
    train_rows = np.array([0, 2, 3])
    held_out_rows = np.array([1, 4, 5])
    fit_dictionary = {
        'X_train': features[train_rows],
        'dataset_properties': {
            'categorical_columns': [],
            'numerical_columns': [0, 1, 2],
        },
    }

    fitted_component = VarianceThreshold().fit(fit_dictionary)
    fit_dictionary = fitted_component.transform(fit_dictionary)
    selector = fit_dictionary['variance_threshold']['numerical']

    # the fit dictionary must now carry the preprocessor mapping
    assert isinstance(fit_dictionary['variance_threshold'], dict)
    assert isinstance(selector, BaseEstimator)

    # fit the returned selector on the training rows through a column transformer
    column_transformer = make_column_transformer(
        (selector, fit_dictionary['dataset_properties']['numerical_columns']),
        remainder='passthrough',
    ).fit(fit_dictionary['X_train'])
    reduced = column_transformer.transform(features[held_out_rows])

    expected = np.array([[7, 8],
                         [17, 18],
                         [14, 15]])
    assert_array_equal(reduced, expected)

0 commit comments

Comments
 (0)