diff --git a/skore/src/skore/sklearn/_comparison/report.py b/skore/src/skore/sklearn/_comparison/report.py index ac4c3a30f..68dd58803 100644 --- a/skore/src/skore/sklearn/_comparison/report.py +++ b/skore/src/skore/sklearn/_comparison/report.py @@ -10,6 +10,7 @@ from skore.externals._pandas_accessors import DirNamesMixin from skore.sklearn._base import _BaseReport from skore.sklearn._estimator.report import EstimatorReport +from skore.utils._progress_bar import progress_decorator class ComparisonReport(_BaseReport, DirNamesMixin): @@ -144,6 +145,9 @@ def __init__( self.estimator_reports_ = reports + # used to know if a parent launches a progress bar manager + self._parent_progress = None + # NEEDED FOR METRICS ACCESSOR self.n_jobs = n_jobs self._rng = np.random.default_rng(time.time_ns()) @@ -153,6 +157,103 @@ def __init__( self._cache = {} self._ml_task = self.estimator_reports_[0]._ml_task + def clear_cache(self): + """Clear the cache. + + Examples + -------- + >>> from sklearn.datasets import make_classification + >>> from sklearn.linear_model import LogisticRegression + >>> from sklearn.model_selection import train_test_split + >>> from skore import ComparisonReport + >>> X, y = make_classification(random_state=42) + >>> X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) + >>> estimator_1 = LogisticRegression() + >>> estimator_report_1 = EstimatorReport( + ... estimator_1, + ... X_train=X_train, + ... y_train=y_train, + ... X_test=X_test, + ... y_test=y_test + ... ) + >>> estimator_2 = LogisticRegression(C=2) # Different regularization + >>> estimator_report_2 = EstimatorReport( + ... estimator_2, + ... X_train=X_train, + ... y_train=y_train, + ... X_test=X_test, + ... y_test=y_test + ... ) + >>> report = ComparisonReport([estimator_report_1, estimator_report_2]) + >>> report.cache_predictions() + >>> report.clear_cache() + >>> report._cache + {} + """ + for report in self.estimator_reports_: + report.clear_cache() + self._cache = {} + + @progress_decorator(description="Estimator predictions") + def cache_predictions(self, response_methods="auto", n_jobs=None): + """Cache the predictions for sub-estimators reports. + + Parameters + ---------- + response_methods : {"auto", "predict", "predict_proba", "decision_function"},\ + default="auto + The methods to use to compute the predictions. + + n_jobs : int, default=None + The number of jobs to run in parallel. If `None`, we use the `n_jobs` + parameter when initializing the report. + + Examples + -------- + >>> from sklearn.datasets import make_classification + >>> from sklearn.linear_model import LogisticRegression + >>> from sklearn.model_selection import train_test_split + >>> from skore import ComparisonReport + >>> X, y = make_classification(random_state=42) + >>> X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) + >>> estimator_1 = LogisticRegression() + >>> estimator_report_1 = EstimatorReport( + ... estimator_1, + ... X_train=X_train, + ... y_train=y_train, + ... X_test=X_test, + ... y_test=y_test + ... ) + >>> estimator_2 = LogisticRegression(C=2) # Different regularization + >>> estimator_report_2 = EstimatorReport( + ... estimator_2, + ... X_train=X_train, + ... y_train=y_train, + ... X_test=X_test, + ... y_test=y_test + ... ) + >>> report = ComparisonReport([estimator_report_1, estimator_report_2]) + >>> report.cache_predictions() + >>> report._cache + {...} + """ + if n_jobs is None: + n_jobs = self.n_jobs + + progress = self._progress_info["current_progress"] + main_task = self._progress_info["current_task"] + + total_estimators = len(self.estimator_reports_) + progress.update(main_task, total=total_estimators) + + for estimator_report in self.estimator_reports_: + # Pass the progress manager to child tasks + estimator_report._parent_progress = progress + estimator_report.cache_predictions( + response_methods=response_methods, n_jobs=n_jobs + ) + progress.update(main_task, advance=1, refresh=True) + #################################################################################### # Methods related to the help and repr ####################################################################################