diff --git a/sklearn_questions.py b/sklearn_questions.py
index fa02e0d..90bb67e 100644
--- a/sklearn_questions.py
+++ b/sklearn_questions.py
@@ -55,13 +55,12 @@
 
 from sklearn.model_selection import BaseCrossValidator
-from sklearn.utils.validation import check_X_y, check_is_fitted
-from sklearn.utils.validation import check_array
-from sklearn.utils.multiclass import check_classification_targets
+from sklearn.utils.validation import check_is_fitted, validate_data
+from sklearn.utils.multiclass import unique_labels
 from sklearn.metrics.pairwise import pairwise_distances
 
 
-class KNearestNeighbors(BaseEstimator, ClassifierMixin):
+class KNearestNeighbors(ClassifierMixin, BaseEstimator):
     """KNearestNeighbors classifier."""
 
     def __init__(self, n_neighbors=1):  # noqa: D107
@@ -82,6 +81,9 @@ def fit(self, X, y):
         self : instance of KNearestNeighbors
             The current instance of the classifier
         """
+        self.X_train_, self.y_train_ = validate_data(self, X, y)
+        self.classes_ = unique_labels(y)
+        self.is_fitted_ = True
         return self
 
     def predict(self, X):
@@ -97,7 +99,18 @@ def predict(self, X):
         y : ndarray, shape (n_test_samples,)
             Predicted class labels for each test data sample.
         """
-        y_pred = np.zeros(X.shape[0])
+        from collections import Counter
+
+        check_is_fitted(self, ['X_train_', 'y_train_'])
+        X = validate_data(self, X, reset=False)
+
+        dist = pairwise_distances(self.X_train_, X, metric="euclidean")
+        y_pred = np.empty(X.shape[0], dtype=self.y_train_.dtype)
+        for i in range(len(X)):
+            idx_nearest = np.argsort(dist[:, i])[:self.n_neighbors]
+            labels = self.y_train_[idx_nearest]
+            most_common_label = Counter(labels).most_common(1)[0][0]
+            y_pred[i] = most_common_label
         return y_pred
 
     def score(self, X, y):
@@ -115,7 +128,10 @@ def score(self, X, y):
         score : float
             Accuracy of the model computed for the (X, y) pairs.
         """
-        return 0.
+        from sklearn.metrics import accuracy_score
+
+        y_pred = self.predict(X)
+        return accuracy_score(y, y_pred)
 
 
 class MonthlySplit(BaseCrossValidator):
@@ -155,7 +171,14 @@ def get_n_splits(self, X, y=None, groups=None):
         n_splits : int
             The number of splits.
         """
-        return 0
+        if self.time_col != 'index':
+            X = X.set_index(self.time_col)
+
+        if not isinstance(X.index, pd.DatetimeIndex):
+            X.index = pd.to_datetime(X.index)
+
+        groups = X.groupby(by=[X.index.year, X.index.month])
+        return len(groups.groups.keys())-1
 
     def split(self, X, y, groups=None):
         """Generate indices to split data into training and test set.
@@ -177,12 +200,24 @@ def split(self, X, y, groups=None):
         idx_test : ndarray
             The testing set indices for that split.
         """
+        if self.time_col != 'index':
+            X = X.set_index(self.time_col)
 
-        n_samples = X.shape[0]
-        n_splits = self.get_n_splits(X, y, groups)
-        for i in range(n_splits):
-            idx_train = range(n_samples)
-            idx_test = range(n_samples)
-            yield (
-                idx_train, idx_test
+        if not isinstance(X.index, pd.DatetimeIndex):
+            X.index = pd.to_datetime(X.index)
+
+        groups = X.groupby(by=[X.index.year, X.index.month])
+        n_splits = len(groups.groups.keys()) - 1
+
+        if n_splits < 1:
+            raise ValueError(
+                "Insufficient data to create splits based on datetime column"
             )
+
+        for i in range(n_splits):
+            idx_tr = X.index.get_indexer_for(
+                groups.groups[list(groups.groups.keys())[i]])
+            idx_te = X.index.get_indexer_for(
+                groups.groups[list(groups.groups.keys())[i + 1]])
+
+            yield (idx_tr, idx_te)