Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

feat: Add cache_predictions method to ComparisonReport #1352

Merged
merged 1 commit into from
Feb 20, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions skore/src/skore/sklearn/_comparison/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from skore.externals._pandas_accessors import DirNamesMixin
from skore.sklearn._base import _BaseReport
from skore.sklearn._estimator.report import EstimatorReport
from skore.utils._progress_bar import progress_decorator


class ComparisonReport(_BaseReport, DirNamesMixin):
Expand Down Expand Up @@ -144,6 +145,9 @@ def __init__(

self.estimator_reports_ = reports

# used to know if a parent launches a progress bar manager
self._parent_progress = None

# NEEDED FOR METRICS ACCESSOR
self.n_jobs = n_jobs
self._rng = np.random.default_rng(time.time_ns())
Expand All @@ -153,6 +157,103 @@ def __init__(
self._cache = {}
self._ml_task = self.estimator_reports_[0]._ml_task

def clear_cache(self):
"""Clear the cache.
Examples
--------
>>> from sklearn.datasets import make_classification
>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.model_selection import train_test_split
>>> from skore import ComparisonReport
>>> X, y = make_classification(random_state=42)
>>> X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
>>> estimator_1 = LogisticRegression()
>>> estimator_report_1 = EstimatorReport(
... estimator_1,
... X_train=X_train,
... y_train=y_train,
... X_test=X_test,
... y_test=y_test
... )
>>> estimator_2 = LogisticRegression(C=2) # Different regularization
>>> estimator_report_2 = EstimatorReport(
... estimator_2,
... X_train=X_train,
... y_train=y_train,
... X_test=X_test,
... y_test=y_test
... )
>>> report = ComparisonReport([estimator_report_1, estimator_report_2])
>>> report.cache_predictions()
>>> report.clear_cache()
>>> report._cache
{}
"""
for report in self.estimator_reports_:
report.clear_cache()
self._cache = {}

@progress_decorator(description="Estimator predictions")
def cache_predictions(self, response_methods="auto", n_jobs=None):
"""Cache the predictions for sub-estimators reports.
Parameters
----------
response_methods : {"auto", "predict", "predict_proba", "decision_function"},\
default="auto
The methods to use to compute the predictions.
n_jobs : int, default=None
The number of jobs to run in parallel. If `None`, we use the `n_jobs`
parameter when initializing the report.
Examples
--------
>>> from sklearn.datasets import make_classification
>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.model_selection import train_test_split
>>> from skore import ComparisonReport
>>> X, y = make_classification(random_state=42)
>>> X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
>>> estimator_1 = LogisticRegression()
>>> estimator_report_1 = EstimatorReport(
... estimator_1,
... X_train=X_train,
... y_train=y_train,
... X_test=X_test,
... y_test=y_test
... )
>>> estimator_2 = LogisticRegression(C=2) # Different regularization
>>> estimator_report_2 = EstimatorReport(
... estimator_2,
... X_train=X_train,
... y_train=y_train,
... X_test=X_test,
... y_test=y_test
... )
>>> report = ComparisonReport([estimator_report_1, estimator_report_2])
>>> report.cache_predictions()
>>> report._cache
{...}
"""
if n_jobs is None:
n_jobs = self.n_jobs

progress = self._progress_info["current_progress"]
main_task = self._progress_info["current_task"]

total_estimators = len(self.estimator_reports_)
progress.update(main_task, total=total_estimators)

for estimator_report in self.estimator_reports_:
# Pass the progress manager to child tasks
estimator_report._parent_progress = progress
estimator_report.cache_predictions(
response_methods=response_methods, n_jobs=n_jobs
)
progress.update(main_task, advance=1, refresh=True)

####################################################################################
# Methods related to the help and repr
####################################################################################
Expand Down