Skip to content

Commit

Permalink
FIX: support "passthrough" column transformer (#140)
Browse files (browse the repository at this point in the history)
  • Loading branch information
j-ittner authored Mar 3, 2021
1 parent 418b1bd commit df7e597
Showing 1 changed file with 24 additions and 16 deletions.
40 changes: 24 additions & 16 deletions src/sklearndf/transformation/wrapper/_wrapper.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
from abc import ABCMeta, abstractmethod
from functools import reduce
from typing import Any, Generic, Iterable, List, Optional, TypeVar, Union
from typing import Any, Generic, List, Optional, TypeVar, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -251,10 +251,15 @@ class ColumnTransformerWrapperDF(
:class:`.TransformerDF`.
"""

__DROP = "drop"
__PASSTHROUGH = "passthrough"

__SPECIAL_TRANSFORMERS = (__DROP, __PASSTHROUGH)

def _validate_delegate_estimator(self) -> None:
column_transformer: ColumnTransformer = self.native_estimator

if column_transformer.remainder != "drop":
if column_transformer.remainder != ColumnTransformerWrapperDF.__DROP:
raise ValueError(
f"unsupported value for arg remainder: ({column_transformer.remainder})"
)
Expand All @@ -263,16 +268,18 @@ def _validate_delegate_estimator(self) -> None:
type(transformer).__name__
for _, transformer, _ in column_transformer.transformers
if not (
isinstance(transformer, str) or isinstance(transformer, TransformerDF)
isinstance(transformer, TransformerDF)
or transformer in ColumnTransformerWrapperDF.__SPECIAL_TRANSFORMERS
)
]
if non_compliant_transformers:
from .. import ColumnTransformerDF

raise ValueError(
f"{ColumnTransformerDF.__name__} only accepts strings or "
f"instances of "
f"{TransformerDF.__name__} as valid transformers, but "
f"{ColumnTransformerDF.__name__} only accepts instances of "
f"{TransformerDF.__name__} or special values "
f'"{" and ".join(ColumnTransformerWrapperDF.__SPECIAL_TRANSFORMERS)}" '
"as valid transformers, but "
f'also got: {", ".join(non_compliant_transformers)}'
)

Expand All @@ -283,22 +290,23 @@ def _get_features_original(self) -> pd.Series:
:return: the series with index the column names of the output dataframe and
values the corresponding input column names.
"""

return reduce(
lambda x, y: x.append(y),
(
df_transformer.feature_names_original_
for df_transformer in self._inner_transformers()
(
pd.Series(index=columns, data=columns)
if df_transformer == ColumnTransformerWrapperDF.__PASSTHROUGH
else df_transformer.feature_names_original_
)
for _, df_transformer, columns in self.native_estimator.transformers_
if (
len(columns) > 0
and df_transformer != ColumnTransformerWrapperDF.__DROP
)
),
)

def _inner_transformers(self) -> Iterable[TransformerWrapperDF]:
return (
df_transformer
for _, df_transformer, columns in self.native_estimator.transformers_
if len(columns) > 0
if df_transformer != "drop"
)


class ImputerWrapperDF(TransformerWrapperDF[T_Imputer], metaclass=ABCMeta):
"""
Expand Down

0 comments on commit df7e597

Please sign in to comment.