-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add config manager to manage progress bar (#1334)
closes #1264 Add a config management functions (i.e. `get_config`, `set_config` and `config_context`) to globally set up some behaviour. We first intend to disable the progress bar globally using this helper. From scikit-learn, there are corner case to propagate such configuration in parallel processing. We therefore use the same pattern by implementing `Parallel` and `delayed` and propagate our configuration within those function.
- Loading branch information
Showing
10 changed files
with
423 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
"""Global configuration state and functions for management.""" | ||
|
||
import threading | ||
import time | ||
from contextlib import contextmanager | ||
|
||
_global_config = { | ||
"show_progress": True, | ||
} | ||
_threadlocal = threading.local() | ||
|
||
|
||
def _get_threadlocal_config(): | ||
"""Get a threadlocal **mutable** configuration. | ||
If the configuration does not exist, copy the default global configuration. | ||
""" | ||
if not hasattr(_threadlocal, "global_config"): | ||
_threadlocal.global_config = _global_config.copy() | ||
return _threadlocal.global_config | ||
|
||
|
||
def get_config(): | ||
"""Retrieve current values for configuration set by :func:`set_config`. | ||
Returns | ||
------- | ||
config : dict | ||
Keys are parameter names that can be passed to :func:`set_config`. | ||
See Also | ||
-------- | ||
config_context : Context manager for global skore configuration. | ||
set_config : Set global skore configuration. | ||
Examples | ||
-------- | ||
>>> import skore | ||
>>> config = skore.get_config() | ||
>>> config.keys() | ||
dict_keys([...]) | ||
""" | ||
# Return a copy of the threadlocal configuration so that users will | ||
# not be able to modify the configuration with the returned dict. | ||
return _get_threadlocal_config().copy() | ||
|
||
|
||
def set_config( | ||
show_progress: bool = None, | ||
): | ||
"""Set global skore configuration. | ||
Parameters | ||
---------- | ||
show_progress : bool, default=None | ||
If True, show progress bars. Otherwise, do not show them. | ||
See Also | ||
-------- | ||
config_context : Context manager for global skore configuration. | ||
get_config : Retrieve current values of the global configuration. | ||
Examples | ||
-------- | ||
>>> from skore import set_config | ||
>>> set_config(show_progress=False) # doctest: +SKIP | ||
""" | ||
local_config = _get_threadlocal_config() | ||
|
||
if show_progress is not None: | ||
local_config["show_progress"] = show_progress | ||
|
||
|
||
@contextmanager | ||
def config_context( | ||
*, | ||
show_progress: bool = None, | ||
): | ||
"""Context manager for global skore configuration. | ||
Parameters | ||
---------- | ||
show_progress : bool, default=None | ||
If True, show progress bars. Otherwise, do not show them. | ||
Yields | ||
------ | ||
None. | ||
See Also | ||
-------- | ||
set_config : Set global skore configuration. | ||
get_config : Retrieve current values of the global configuration. | ||
Notes | ||
----- | ||
All settings, not just those presently modified, will be returned to | ||
their previous values when the context manager is exited. | ||
Examples | ||
-------- | ||
>>> import skore | ||
>>> from sklearn.datasets import make_classification | ||
>>> from sklearn.model_selection import train_test_split | ||
>>> from sklearn.linear_model import LogisticRegression | ||
>>> from skore import CrossValidationReport | ||
>>> with skore.config_context(show_progress=False): | ||
... X, y = make_classification(random_state=42) | ||
... estimator = LogisticRegression() | ||
... report = CrossValidationReport(estimator, X=X, y=y, cv_splitter=2) | ||
""" | ||
old_config = get_config() | ||
set_config( | ||
show_progress=show_progress, | ||
) | ||
|
||
try: | ||
yield | ||
finally: | ||
set_config(**old_config) | ||
|
||
|
||
def _set_show_progress_for_testing(show_progress, sleep_duration): | ||
"""Set the value of show_progress for testing purposes after some waiting. | ||
This function should exist in a Python module rather than in tests, otherwise | ||
joblib will not be able to pickle it. | ||
""" | ||
with config_context(show_progress=show_progress): | ||
time.sleep(sleep_duration) | ||
return get_config()["show_progress"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
"""Customizations of :mod:`joblib` and :mod:`threadpoolctl` tools for skore usage.""" | ||
|
||
import functools | ||
import warnings | ||
from functools import update_wrapper | ||
|
||
import joblib | ||
|
||
from skore._config import config_context, get_config | ||
|
||
# Global threadpool controller instance that can be used to locally limit the number of | ||
# threads without looping through all shared libraries every time. | ||
# It should not be accessed directly and _get_threadpool_controller should be used | ||
# instead. | ||
_threadpool_controller = None | ||
|
||
|
||
def _with_config_and_warning_filters(delayed_func, config, warning_filters): | ||
"""Attach a config to a delayed function.""" | ||
if hasattr(delayed_func, "with_config_and_warning_filters"): | ||
return delayed_func.with_config_and_warning_filters(config, warning_filters) | ||
else: | ||
warnings.warn( | ||
( | ||
"`skore.utils._parallel.Parallel` needs to be used in " | ||
"conjunction with `skore.utils._parallel.delayed` instead of " | ||
"`joblib.delayed` to correctly propagate the skore configuration to " | ||
"the joblib workers." | ||
), | ||
UserWarning, | ||
stacklevel=2, | ||
) | ||
return delayed_func | ||
|
||
|
||
class Parallel(joblib.Parallel): | ||
"""Tweak of :class:`joblib.Parallel` that propagates the skore configuration. | ||
This subclass of :class:`joblib.Parallel` ensures that the active configuration | ||
(thread-local) of skore is propagated to the parallel workers for the | ||
duration of the execution of the parallel tasks. | ||
The API does not change and you can refer to :class:`joblib.Parallel` | ||
documentation for more details. | ||
""" | ||
|
||
def __call__(self, iterable): | ||
"""Dispatch the tasks and return the results. | ||
Parameters | ||
---------- | ||
iterable : iterable | ||
Iterable containing tuples of (delayed_function, args, kwargs) that should | ||
be consumed. | ||
Returns | ||
------- | ||
results : list | ||
List of results of the tasks. | ||
""" | ||
# Capture the thread-local skore configuration at the time | ||
# Parallel.__call__ is issued since the tasks can be dispatched | ||
# in a different thread depending on the backend and on the value of | ||
# pre_dispatch and n_jobs. | ||
config = get_config() | ||
warning_filters = warnings.filters | ||
iterable_with_config_and_warning_filters = ( | ||
( | ||
_with_config_and_warning_filters(delayed_func, config, warning_filters), | ||
args, | ||
kwargs, | ||
) | ||
for delayed_func, args, kwargs in iterable | ||
) | ||
return super().__call__(iterable_with_config_and_warning_filters) | ||
|
||
|
||
# remove when https://github.com/joblib/joblib/issues/1071 is fixed | ||
def delayed(function): | ||
"""Capture the arguments of a function to delay its execution. | ||
This alternative to `joblib.delayed` is meant to be used in conjunction | ||
with `skore.utils._parallel.Parallel`. The latter captures the skore | ||
configuration by calling `skore.get_config()` in the current thread, prior to | ||
dispatching the first task. The captured configuration is then propagated and | ||
enabled for the duration of the execution of the delayed function in the | ||
joblib workers. | ||
Parameters | ||
---------- | ||
function : callable | ||
The function to be delayed. | ||
Returns | ||
------- | ||
output: tuple | ||
Tuple containing the delayed function, the positional arguments, and the | ||
keyword arguments. | ||
""" | ||
|
||
@functools.wraps(function) | ||
def delayed_function(*args, **kwargs): | ||
return _FuncWrapper(function), args, kwargs | ||
|
||
return delayed_function | ||
|
||
|
||
class _FuncWrapper: | ||
"""Load the global configuration before calling the function.""" | ||
|
||
def __init__(self, function): | ||
self.function = function | ||
update_wrapper(self, self.function) | ||
|
||
def with_config_and_warning_filters(self, config, warning_filters): | ||
self.config = config | ||
self.warning_filters = warning_filters | ||
return self | ||
|
||
def __call__(self, *args, **kwargs): | ||
config = getattr(self, "config", {}) | ||
warning_filters = getattr(self, "warning_filters", []) | ||
if not config or not warning_filters: | ||
warnings.warn( | ||
( | ||
"`skore.utils._parallel.delayed` should be used with" | ||
" `skore.utils._parallel.Parallel` to make it possible to" | ||
" propagate the skore configuration of the current thread to" | ||
" the joblib workers." | ||
), | ||
UserWarning, | ||
stacklevel=2, | ||
) | ||
|
||
with config_context(**config), warnings.catch_warnings(): | ||
warnings.filters = warning_filters | ||
return self.function(*args, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.