Skip to content

Commit

Permalink
Added a non_categorical_mode argument to qd_screen and to `get_ca…
Browse files Browse the repository at this point in the history
…tegorical_features` and added input validators. Fixed #37
  • Loading branch information
Sylvain MARIE committed Mar 16, 2023
1 parent eb09583 commit 996f1f7
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 14 deletions.
70 changes: 59 additions & 11 deletions qdscreen/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,10 +544,30 @@ def plot_increasing_entropies(self):
self.stats.plot_increasing_entropies()


def qd_screen(X, # type: Union[pd.DataFrame, np.ndarray]
def assert_df_or_2D_array(df_or_array # type: Union[pd.DataFrame, np.ndarray]
):
"""
Raises a ValueError if `df_or_array` is
:param df_or_array:
:return:
"""
if isinstance(df_or_array, pd.DataFrame):
pass
elif isinstance(df_or_array, np.ndarray):
# see https://numpy.org/doc/stable/user/basics.rec.html#manipulating-and-displaying-structured-datatypes
if len(df_or_array.shape) != 2:
raise ValueError("Provided data is not a 2D array, the number of dimensions is %s" % len(df_or_array.shape))
else:
# Raise error
raise TypeError("Provided data is neither a `pd.DataFrame` nor a `np.ndarray`")


def qd_screen(X, # type: Union[pd.DataFrame, np.ndarray]
absolute_eps=None, # type: float
relative_eps=None, # type: float
keep_stats=False # type: bool
keep_stats=False, # type: bool
non_categorical_mode='strict',
):
# type: (...) -> QDForest
"""
Expand Down Expand Up @@ -575,12 +595,18 @@ def qd_screen(X, # type: Union[pd.DataFrame, np.ndarray]
memory in the resulting forest object (`<QDForest>.stats`), for further analysis. By default this is `False`.
:return:
"""
# only work on the categorical features
X = get_categorical_features(X)
# Make sure this is a 2D table
assert_df_or_2D_array(X)

# sanity check
# Sanity check: are there rows in here ?
if len(X) == 0:
raise ValueError("Empty dataset provided")
raise ValueError("Provided dataset does not contain any row")

# Only work on the categorical features
X = get_categorical_features(X, non_categorical_mode=non_categorical_mode)

# Sanity check concerning the number of columns
assert X.shape[1] > 0, "Internal error: no columns remain in dataset after preprocessing."

# parameters check and defaults
if absolute_eps is None:
Expand Down Expand Up @@ -1144,28 +1170,49 @@ def get_arcs_from_adjmat(A, # type: Union[np.ndarray, pd.DataFra
return ((cols[i], cols[j]) for i, j in zip(*res_ar))


def get_categorical_features(df_or_array # type: Union[np.ndarray, pd.DataFrame]
def get_categorical_features(df_or_array, # type: Union[np.ndarray, pd.DataFrame]
non_categorical_mode="strict" # type: str
):
# type: (...) -> Union[np.ndarray, pd.DataFrame]
"""
:param df_or_array:
:param non_categorical_mode:
:return: a dataframe or array with the categorical features
"""
assert_df_or_2D_array(df_or_array)

if non_categorical_mode == "strict":
strict_mode = True
elif non_categorical_mode == "remove":
strict_mode = False
else:
raise ValueError("Unsupported value for `non_categorical_mode`: %r" % non_categorical_mode)

if isinstance(df_or_array, pd.DataFrame):
is_categorical_dtype = df_or_array.dtypes.astype(str).isin(["object", "categorical"])
if not is_categorical_dtype.any():
raise TypeError("Provided dataframe columns do not contain any categorical datatype (dtype in 'object' or "
if strict_mode and not is_categorical_dtype.all():
raise ValueError("Provided dataframe columns contains non-categorical datatypes (dtype in 'object' or "
"'categorical'): found dtypes %r. This is not supported when `non_categorical_mode` is set to "
"`'strict'`" % df_or_array.dtypes[~is_categorical_dtype].to_dict())
elif not is_categorical_dtype.any():
raise ValueError("Provided dataframe columns do not contain any categorical datatype (dtype in 'object' or "
"'categorical'): found dtypes %r" % df_or_array.dtypes[~is_categorical_dtype].to_dict())
return df_or_array.loc[:, is_categorical_dtype]

elif isinstance(df_or_array, np.ndarray):
# see https://numpy.org/doc/stable/user/basics.rec.html#manipulating-and-displaying-structured-datatypes
if df_or_array.dtype.names is not None:
# structured array
is_categorical_dtype = np.array([str(df_or_array.dtype.fields[n][0]) == "object"
for n in df_or_array.dtype.names])
if not is_categorical_dtype.any():
raise TypeError(
if strict_mode and not is_categorical_dtype.all():
invalid_dtypes = df_or_array.dtype[~is_categorical_dtype].asdict()
raise ValueError("Provided numpy array columns contains non-categorical datatypes ('object' dtype): "
"found dtypes %r. This is not supported when `non_categorical_mode` is set to "
"`'strict'`" % invalid_dtypes)
elif not is_categorical_dtype.any():
raise ValueError(
"Provided dataframe columns do not contain any categorical datatype (dtype in 'object' or "
"'categorical'): found dtypes %r" % df_or_array.dtype.fields)
categorical_names = np.array(df_or_array.dtype.names)[is_categorical_dtype]
Expand All @@ -1177,6 +1224,7 @@ def get_categorical_features(df_or_array # type: Union[np.ndarray, pd.DataFrame
% df_or_array.dtype)
return df_or_array
else:
# Should not happen since `assert_df_or_2D_array` is called upfront now.
raise TypeError("Provided data is neither a pd.DataFrame nor a np.ndarray")


Expand Down
47 changes: 44 additions & 3 deletions qdscreen/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
from .main import QDForest


class InvalidDataInputError(ValueError):
"""Raised when input data is invalid"""


def _get_most_common_value(x):
# From https://stackoverflow.com/a/47778607/7262247
# `scipy_mode` is the most robust to the various pitfalls (nans, ...)
Expand All @@ -36,12 +40,47 @@ def __init__(self,
self.forest = qd_forest
self._maps = None # type: Optional[Dict[Any, Dict[Any, Dict]]]

def fit(self,
X # type: Union[np.ndarray, pd.DataFrame]
):
def assert_valid_input(
self,
X, # type: Union[np.ndarray, pd.DataFrame]
df_extras_allowed=False # type: bool
):
"""Raises an InvalidDataInputError if X does not match the expectation"""

if self.forest.is_nparray:
if not isinstance(X, np.ndarray):
raise InvalidDataInputError(
"Input data must be an numpy array. Found: %s" % type(X))

if X.shape[1] != self.forest.nb_vars: # or X.shape[0] != X.shape[1]:
raise InvalidDataInputError(
"Input numpy array must have %s columns. Found %s columns" % (self.forest.nb_vars, X.shape[1]))
else:
if not isinstance(X, pd.DataFrame):
raise InvalidDataInputError(
"Input data must be a pandas DataFrame. Found: %s" % type(X))

actual = set(X.columns)
expected = set(self.forest.varnames)
if actual != expected:
missing = expected - actual
if missing or not df_extras_allowed:
extra = actual - expected
raise InvalidDataInputError(
"Input pandas DataFrame must have column names matching the ones in the model. "
"Missing: %s. Extra: %s " % (missing, extra)
)

def fit(
self,
X # type: Union[np.ndarray, pd.DataFrame]
):
"""Fits the maps able to predict determined features from others"""
forest = self.forest

# Validate the input
self.assert_valid_input(X, df_extras_allowed=False)

# we will create a sparse coordinate representation of maps
n = forest.nb_vars

Expand Down Expand Up @@ -118,6 +157,8 @@ def remove_qd(self,
"""
forest = self.forest

self.assert_valid_input(X, df_extras_allowed=True)

is_x_nparray = isinstance(X, np.ndarray)
assert is_x_nparray == forest.is_nparray

Expand Down

0 comments on commit 996f1f7

Please # to comment.