diff --git a/sklearn_questions.py b/sklearn_questions.py index fa02e0d..9759b09 100644 --- a/sklearn_questions.py +++ b/sklearn_questions.py @@ -47,6 +47,7 @@ to compute distances between 2 sets of samples. """ + import numpy as np import pandas as pd @@ -54,17 +55,20 @@ from sklearn.base import ClassifierMixin from sklearn.model_selection import BaseCrossValidator +from sklearn.preprocessing import LabelEncoder + from sklearn.utils.validation import check_X_y, check_is_fitted -from sklearn.utils.validation import check_array -from sklearn.utils.multiclass import check_classification_targets +from sklearn.utils.validation import validate_data +from sklearn.utils.multiclass import unique_labels from sklearn.metrics.pairwise import pairwise_distances -class KNearestNeighbors(BaseEstimator, ClassifierMixin): +class KNearestNeighbors(ClassifierMixin, BaseEstimator): """KNearestNeighbors classifier.""" - def __init__(self, n_neighbors=1): # noqa: D107 + def __init__(self, n_neighbors=1): + """Fitting function.Dummy.""" self.n_neighbors = n_neighbors def fit(self, X, y): @@ -82,6 +86,15 @@ def fit(self, X, y): self : instance of KNearestNeighbors The current instance of the classifier """ + self.classes_ = unique_labels(y) + X, y = validate_data(self, X, y, reset=True) + + self.label_encoder_ = LabelEncoder() + self.X_ = X + + self.y_ = self.label_encoder_.fit_transform(y) + self.is_fitted_ = True + return self def predict(self, X): @@ -97,7 +110,22 @@ def predict(self, X): y : ndarray, shape (n_test_samples,) Predicted class labels for each test data sample. """ - y_pred = np.zeros(X.shape[0]) + check_is_fitted(self) + X = validate_data(self, X, reset=False, dtype=float) + + y_pred = np.zeros(X.shape[0], dtype=int) + + distance_mat = pairwise_distances(X, self.X_).argsort(axis=1) + + index_min_dist = distance_mat[:, : self.n_neighbors] + + for ind, row in enumerate(index_min_dist): + val = self.y_[row] + nearest_neigh = np.bincount(val).argmax() + y_pred[ind] = nearest_neigh + + y_pred = self.label_encoder_.inverse_transform(y_pred) + return y_pred def score(self, X, y): @@ -115,7 +143,11 @@ def score(self, X, y): score : float Accuracy of the model computed for the (X, y) pairs. """ - return 0. + X, y = check_X_y(X, y) + + y_pred = self.predict(X) + acc = (y_pred == y).sum() / len(y) + return acc class MonthlySplit(BaseCrossValidator): @@ -134,7 +166,8 @@ class MonthlySplit(BaseCrossValidator): To use the index as column just set `time_col` to `'index'`. """ - def __init__(self, time_col='index'): # noqa: D107 + def __init__(self, time_col="index"): + """Fitting function.Dummy.""" self.time_col = time_col def get_n_splits(self, X, y=None, groups=None): @@ -155,7 +188,17 @@ def get_n_splits(self, X, y=None, groups=None): n_splits : int The number of splits. """ - return 0 + if not self.time_col == "index": + if np.dtype(X[self.time_col]) != np.dtype("datetime64[ns]"): + raise ValueError("Time column should be a datetime object") + X_mem = X.set_index(self.time_col).copy() + else: + X_mem = X.copy() + if X_mem.index.dtype != np.dtype("datetime64[ns]"): + raise ValueError("Time column should be a datetime object") + + n_split = len(X_mem.resample("ME")) - 1 + return n_split def split(self, X, y, groups=None): """Generate indices to split data into training and test set. @@ -177,12 +220,41 @@ def split(self, X, y, groups=None): idx_test : ndarray The testing set indices for that split. """ + if isinstance(X, pd.Series): + X = pd.DataFrame(X) - n_samples = X.shape[0] n_splits = self.get_n_splits(X, y, groups) + + if not self.time_col == "index": + if np.dtype(X[self.time_col]) != np.dtype("datetime64[ns]"): + raise ValueError("Time column should be a datetime object") + X_ = X.set_index(self.time_col).copy() + else: + X_ = X.copy() + if X_.index.dtype != np.dtype("datetime64[ns]"): + raise ValueError("Time column should be a datetime object") + + month_split = pd.unique(X_.to_period("M").index) + month_split = pd.Series(month_split) + + month_split = month_split.apply( + lambda x: "{}-{}".format(x.year, str(x.month).zfill(2)) + ) + + month_split.sort_values(inplace=True, ignore_index=True) + + X_mem = X_.copy().sort_index() + + X_.reset_index(names="date", inplace=True) + for i in range(n_splits): - idx_train = range(n_samples) - idx_test = range(n_samples) - yield ( - idx_train, idx_test - ) + mem_id_train = X_mem[: month_split[i]].index + + X_mem.drop(mem_id_train, inplace=True) + + mem_id_test = X_mem[: month_split[i + 1]].index + + idx_train = X_.index[(X_["date"].isin(mem_id_train))].to_list() + idx_test = X_.index[(X_["date"].isin(mem_id_test))].to_list() + + yield (idx_train, idx_test)