fix(prediction error plot): Do not cache plot if random_state is None #1394

Merged · 7 commits · Mar 10, 2025
21 changes: 14 additions & 7 deletions skore/src/skore/sklearn/_comparison/metrics_accessor.py
@@ -1201,12 +1201,15 @@ def _get_display(
display : display_class
The display.
"""
# build the cache key components to finally create a tuple that will be used
# to check if the metric has already been computed
cache_key_parts: list[Any] = [self._parent._hash, display_class.__name__]
cache_key_parts.extend(display_kwargs.values())
cache_key_parts.append(data_source)
cache_key = tuple(cache_key_parts)
if "random_state" in display_kwargs and display_kwargs["random_state"] is None:
cache_key = None
else:
# build the cache key components to finally create a tuple that will be used
# to check if the metric has already been computed
cache_key_parts: list[Any] = [self._parent._hash, display_class.__name__]
cache_key_parts.extend(display_kwargs.values())
cache_key_parts.append(data_source)
cache_key = tuple(cache_key_parts)

assert self._progress_info is not None, "Progress info not set"
progress = self._progress_info["current_progress"]
@@ -1250,7 +1253,11 @@ def _get_display(
data_source=data_source,
**display_kwargs,
)
self._parent._cache[cache_key] = display

# Unless random_state is an int (i.e. the call is deterministic),
# we do not cache
if cache_key is not None:
self._parent._cache[cache_key] = display

return display
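
For context, the cache-key logic added in this hunk can be read as the following standalone sketch; the helper name _build_cache_key is illustrative only and not part of the skore API:

from typing import Any, Optional


def _build_cache_key(
    parent_hash: str,
    display_class_name: str,
    display_kwargs: dict[str, Any],
    data_source: str,
) -> Optional[tuple]:
    """Return a hashable cache key, or None when the call is not reproducible."""
    if "random_state" in display_kwargs and display_kwargs["random_state"] is None:
        # A None random_state makes the resulting plot non-deterministic,
        # so signal the caller to skip caching entirely.
        return None
    cache_key_parts: list[Any] = [parent_hash, display_class_name]
    cache_key_parts.extend(display_kwargs.values())
    cache_key_parts.append(data_source)
    return tuple(cache_key_parts)

The caller then stores the display only when a key was produced, i.e. if cache_key is not None: self._parent._cache[cache_key] = display.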

19 changes: 13 additions & 6 deletions skore/src/skore/sklearn/_cross_validation/metrics_accessor.py
@@ -959,11 +959,14 @@ def _get_display(
display : display_class
The display.
"""
# Create a list of cache key components and then convert to tuple
cache_key_parts: list[Any] = [self._parent._hash, display_class.__name__]
cache_key_parts.extend(display_kwargs.values())
cache_key_parts.append(data_source)
cache_key = tuple(cache_key_parts)
if "random_state" in display_kwargs and display_kwargs["random_state"] is None:
cache_key = None
else:
# Create a list of cache key components and then convert to tuple
cache_key_parts: list[Any] = [self._parent._hash, display_class.__name__]
cache_key_parts.extend(display_kwargs.values())
cache_key_parts.append(data_source)
cache_key = tuple(cache_key_parts)

assert self._progress_info is not None, "Progress info not set"
progress = self._progress_info["current_progress"]
@@ -1003,7 +1006,11 @@ def _get_display(
data_source=data_source,
**display_kwargs,
)
self._parent._cache[cache_key] = display

# Unless random_state is an int (i.e. the call is deterministic),
# we do not cache
if cache_key is not None:
self._parent._cache[cache_key] = display

return display

22 changes: 15 additions & 7 deletions skore/src/skore/sklearn/_estimator/metrics_accessor.py
@@ -1611,13 +1611,17 @@ def _get_display(

# build the cache key components to finally create a tuple that will be used
# to check if the metric has already been computed
cache_key_parts: list[Any] = [self._parent._hash, display_class.__name__]
cache_key_parts.extend(display_kwargs.values())
if data_source_hash is not None:
cache_key_parts.append(data_source_hash)

if "random_state" in display_kwargs and display_kwargs["random_state"] is None:
cache_key = None
else:
cache_key_parts.append(data_source)
cache_key = tuple(cache_key_parts)
cache_key_parts: list[Any] = [self._parent._hash, display_class.__name__]
cache_key_parts.extend(display_kwargs.values())
if data_source_hash is not None:
cache_key_parts.append(data_source_hash)
else:
cache_key_parts.append(data_source)
cache_key = tuple(cache_key_parts)

if cache_key in self._parent._cache:
display = self._parent._cache[cache_key]
@@ -1642,7 +1646,11 @@ def _get_display(
data_source=data_source,
**display_kwargs,
)
self._parent._cache[cache_key] = display

# Unless random_state is an int (i.e. the call is deterministic),
# we do not cache
if cache_key is not None:
self._parent._cache[cache_key] = display

return display
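
The estimator hunk above follows the same pattern with one difference: when a data_source_hash is available (external data passed via X/y), it is used in the key instead of the symbolic data_source name. A simplified sketch, again with an illustrative helper name rather than the actual skore API:

from typing import Any, Optional


def _build_estimator_cache_key(
    parent_hash: str,
    display_class_name: str,
    display_kwargs: dict[str, Any],
    data_source: str,
    data_source_hash: Optional[str],
) -> Optional[tuple]:
    if "random_state" in display_kwargs and display_kwargs["random_state"] is None:
        return None
    cache_key_parts: list[Any] = [parent_hash, display_class_name]
    cache_key_parts.extend(display_kwargs.values())
    # Prefer the hash of explicitly passed data; otherwise fall back to the
    # symbolic data source name (e.g. "test" or "X_y").
    if data_source_hash is not None:
        cache_key_parts.append(data_source_hash)
    else:
        cache_key_parts.append(data_source)
    return tuple(cache_key_parts)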

18 changes: 18 additions & 0 deletions skore/tests/unit/sklearn/comparison/test_comparison.py
@@ -712,3 +712,21 @@ def test_comparison_report_plots(

# Ensure plot is callable
display.plot()


def test_random_state(regression_model):
"""If random_state is None (the default) the call should not be cached."""
estimator, X_train, X_test, y_train, y_test = regression_model
estimator_report = EstimatorReport(
estimator,
X_train=X_train,
y_train=y_train,
X_test=X_test,
y_test=y_test,
)

report = ComparisonReport([estimator_report, estimator_report])

report.metrics.prediction_error()
# skore should store the y_pred of the internal estimators, but not the plot
assert report._cache == {}
22 changes: 10 additions & 12 deletions skore/tests/unit/sklearn/estimator/test_estimator.py
@@ -357,7 +357,7 @@ def test_estimator_report_plot_roc(binary_classification_data):
def test_estimator_report_display_binary_classification(
pyplot, binary_classification_data, display
):
"""General behaviour of the function creating display on binary classification."""
"""The call to display functions should be cached."""
estimator, X_test, y_test = binary_classification_data
report = EstimatorReport(estimator, X_test=X_test, y_test=y_test)
assert hasattr(report.metrics, display)
@@ -369,23 +369,22 @@ def test_estimator_report_display_binary_classification(

@pytest.mark.parametrize("display", ["prediction_error"])
def test_estimator_report_display_regression(pyplot, regression_data, display):
"""General behaviour of the function creating display on regression."""
"""The call to display functions should be cached, as long as the arguments make it
reproducible."""
estimator, X_test, y_test = regression_data
report = EstimatorReport(estimator, X_test=X_test, y_test=y_test)
assert hasattr(report.metrics, display)
display_first_call = getattr(report.metrics, display)()
display_first_call = getattr(report.metrics, display)(random_state=0)
assert report._cache != {}
display_second_call = getattr(report.metrics, display)()
display_second_call = getattr(report.metrics, display)(random_state=0)
assert display_first_call is display_second_call


@pytest.mark.parametrize("display", ["roc", "precision_recall"])
def test_estimator_report_display_binary_classification_external_data(
pyplot, binary_classification_data, display
):
"""General behaviour of the function creating display on binary classification
when passing external data.
"""
"""The call to display functions should be cached when passing external data."""
estimator, X_test, y_test = binary_classification_data
report = EstimatorReport(estimator)
assert hasattr(report.metrics, display)
@@ -403,18 +402,17 @@ def test_estimator_report_display_binary_classification_external_data(
def test_estimator_report_display_regression_external_data(
pyplot, regression_data, display
):
"""General behaviour of the function creating display on regression when passing
external data.
"""
"""The call to display functions should be cached when passing external data,
as long as the arguments make it reproducible."""
estimator, X_test, y_test = regression_data
report = EstimatorReport(estimator)
assert hasattr(report.metrics, display)
display_first_call = getattr(report.metrics, display)(
data_source="X_y", X=X_test, y=y_test
data_source="X_y", X=X_test, y=y_test, random_state=0
)
assert report._cache != {}
display_second_call = getattr(report.metrics, display)(
data_source="X_y", X=X_test, y=y_test
data_source="X_y", X=X_test, y=y_test, random_state=0
)
assert display_first_call is display_second_call

11 changes: 11 additions & 0 deletions skore/tests/unit/sklearn/plot/test_prediction_error.py
@@ -234,3 +234,14 @@ def test_prediction_error_display_kwargs(pyplot, regression_data):
display = report.metrics.prediction_error(subsample=0.5)
display.plot()
assert len(display.scatter_.get_offsets()) == expected_subsample


def test_random_state(regression_data):
"""If random_state is None (the default) the call should not be cached."""
estimator, X_train, X_test, y_train, y_test = regression_data
report = EstimatorReport(
estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
)
report.metrics.prediction_error()
# skore should store the y_pred, but not the plot
assert len(report._cache) == 1
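
By contrast, when an integer random_state is passed the call is deterministic, so the display is cached and a second call returns the very same object (this is what test_estimator_report_display_regression above exercises). A focused counterpart check could look like the sketch below; it is not part of this diff and the test name is illustrative:

def test_random_state_cached(regression_data):
    """With an integer random_state the prediction error display should be cached."""
    estimator, X_train, X_test, y_train, y_test = regression_data
    report = EstimatorReport(
        estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
    )
    display_first_call = report.metrics.prediction_error(random_state=0)
    display_second_call = report.metrics.prediction_error(random_state=0)
    # The deterministic call is cached, so both calls return the same object.
    assert display_first_call is display_second_call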
14 changes: 12 additions & 2 deletions skore/tests/unit/sklearn/test_cross_validation.py
@@ -276,12 +276,22 @@ def test_cross_validation_report_display_regression(pyplot, regression_data, dis
estimator, X, y = regression_data
report = CrossValidationReport(estimator, X, y, cv_splitter=2)
assert hasattr(report.metrics, display)
display_first_call = getattr(report.metrics, display)()
display_first_call = getattr(report.metrics, display)(random_state=0)
assert report._cache != {}
display_second_call = getattr(report.metrics, display)()
display_second_call = getattr(report.metrics, display)(random_state=0)
assert display_first_call is display_second_call


def test_random_state(regression_data):
"""If random_state is None (the default) the call should not be cached."""
estimator, X, y = regression_data
report = CrossValidationReport(estimator, X, y, cv_splitter=2)

report.metrics.prediction_error()
# skore should store the y_pred of the internal estimators, but not the plot
assert report._cache == {}


########################################################################################
# Check the metrics methods
########################################################################################