"""K-nearest-neighbours classifier and month-based cross-validation splitter."""
# NOTE(review): the numpy, pandas and sklearn.base imports are inferred from
# the names used below (np, pd, BaseEstimator, ClassifierMixin) -- confirm
# against the complete file.
import numpy as np
import pandas as pd

from sklearn.base import BaseEstimator
from sklearn.base import ClassifierMixin
from sklearn.model_selection import BaseCrossValidator
from sklearn.utils.validation import validate_data, check_is_fitted
from sklearn.utils.validation import check_array
from sklearn.utils.multiclass import check_classification_targets
from sklearn.metrics.pairwise import pairwise_distances


class KNearestNeighbors(ClassifierMixin, BaseEstimator):
    """KNearestNeighbors classifier.

    Parameters
    ----------
    n_neighbors : int, default=1
        Number of closest training samples that vote for the predicted label.
    """

    def __init__(self, n_neighbors=1):  # noqa: D107
        # NOTE(review): assignment inferred from `self.n_neighbors` being read
        # in `predict` -- confirm against the complete file.
        self.n_neighbors = n_neighbors

    def fit(self, X, y):
        """Store the training data.

        Parameters
        ----------
        X : ndarray, shape (n_samples, n_features)
            Data to train the model.
        y : ndarray, shape (n_samples,)
            Labels associated with the training data.

        Returns
        -------
        self : instance of KNearestNeighbors
            The current instance of the classifier
        """
        X, y = validate_data(self, X, y)
        check_classification_targets(y)
        self.X_train_ = X
        self.y_train_ = y
        self.classes_ = np.unique(y)
        return self

    def predict(self, X):
        """Predict the label of each sample by majority vote of its neighbors.

        Parameters
        ----------
        X : ndarray, shape (n_test_samples, n_features)
            Data to predict on.

        Returns
        -------
        y : ndarray, shape (n_test_samples,)
            Predicted class labels for each test data sample.
        """
        check_is_fitted(self, ["X_train_", "y_train_", "classes_"])
        X = validate_data(self, X, reset=False)
        # One call for the whole test set instead of one call per test row.
        distances = pairwise_distances(X, self.X_train_)
        # A stable sort makes equidistant training points resolve in training
        # order, so predictions are reproducible.
        nearest = np.argsort(distances, axis=1, kind="stable")
        nearest = nearest[:, :self.n_neighbors]
        neighbor_labels = self.y_train_[nearest]
        return np.array(
            [self._majority_label(labels) for labels in neighbor_labels]
        )

    @staticmethod
    def _majority_label(labels):
        """Return the most frequent label; ties go to the smallest label."""
        values, counts = np.unique(labels, return_counts=True)
        return values[np.argmax(counts)]

    def score(self, X, y):
        """Calculate the accuracy of the predictions on (X, y).

        Parameters
        ----------
        X : ndarray, shape (n_samples, n_features)
            Data to score on.
        y : ndarray, shape (n_samples,)
            Target values.

        Returns
        -------
        score : float
            Accuracy of the model computed for the (X, y) pairs.
        """
        # `predict` already checks that the estimator is fitted and validates
        # X, so neither is repeated here.
        y_pred = self.predict(X)
        return np.mean(y_pred == np.asarray(y))


class MonthlySplit(BaseCrossValidator):
    """Cross-validator that trains on one month and tests on the next.

    Months that contain no rows are skipped, so each split pairs two
    consecutive *observed* months.

    Parameters
    ----------
    time_col : str, defaults to 'index'
        Column of the input DataFrame that will be used to split the data.
        This column should be of type datetime. If split is called with a
        DataFrame for which this column is not a datetime, it will raise a
        ValueError. To use the index as column just set `time_col` to
        `'index'`.
    """

    def __init__(self, time_col='index'):  # noqa: D107
        self.time_col = time_col

    def _timestamps(self, X):
        """Return the non-null timestamps of X, labelled by row position.

        Raises
        ------
        ValueError
            If the selected column (or the index) is not of datetime dtype.
        """
        if self.time_col == 'index':
            # Read the index directly: `reset_index()` would name the new
            # column after the index, which is 'index' only when unnamed.
            times = pd.Series(X.index)
        else:
            times = X[self.time_col].reset_index(drop=True)
        # Accept any datetime64 flavour (tz-aware, non-nanosecond units).
        if not pd.api.types.is_datetime64_any_dtype(times):
            raise ValueError(f"Column '{self.time_col}' is not a datetime.")
        # Rows without a timestamp belong to no month.
        return times.dropna()

    @staticmethod
    def _month_ids(times):
        """Map each timestamp to an integer that identifies its month."""
        return times.dt.year * 12 + times.dt.month

    def get_n_splits(self, X, y=None, groups=None):
        """Return the number of splitting iterations in the cross-validator.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.
        y : array-like of shape (n_samples,)
            Always ignored, exists for compatibility.
        groups : array-like of shape (n_samples,)
            Always ignored, exists for compatibility.

        Returns
        -------
        n_splits : int
            The number of splits.

        Raises
        ------
        ValueError
            If the time column is not of datetime dtype.
        """
        n_months = self._month_ids(self._timestamps(X)).nunique()
        # Fewer than two observed months cannot form a (train, test) pair.
        return max(n_months - 1, 0)

    def split(self, X, y, groups=None):
        """Generate indices to split data into training and test set.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.
        y : array-like of shape (n_samples,)
            Always ignored, exists for compatibility.
        groups : array-like of shape (n_samples,)
            Always ignored, exists for compatibility.

        Yields
        ------
        idx_train : ndarray
            The training set indices for that split.
        idx_test : ndarray
            The testing set indices for that split.

        Raises
        ------
        ValueError
            If the time column is not of datetime dtype.
        """
        # Sorting keeps the positional labels, so rows inside a month come
        # out in chronological order.
        times = self._timestamps(X).sort_values(kind="stable")
        month_ids = self._month_ids(times)
        # Grouping on the same month ids that `get_n_splits` counts keeps the
        # number of yielded folds equal to the reported number of splits.
        rows_per_month = [
            rows.index.to_numpy()
            for _, rows in month_ids.groupby(month_ids)
        ]
        for idx_train, idx_test in zip(rows_per_month, rows_per_month[1:]):
            yield idx_train, idx_test