Skip to content

Commit 79da465

Browse files
glemaitre, ogrisel, thomasjpfan
authored and committed
FIX propagate configuration to workers in parallel (#25363)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org> Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
1 parent 1d6d023 commit 79da465

40 files changed

+314
-116
lines changed

benchmarks/bench_saga.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77
import time
88
import os
99

10-
from joblib import Parallel
11-
from sklearn.utils.fixes import delayed
10+
from sklearn.utils.parallel import delayed, Parallel
1211
import matplotlib.pyplot as plt
1312
import numpy as np
1413

build_tools/azure/linting.sh

+10-5
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,15 @@ then
3434
exit 1
3535
fi
3636

37-
joblib_import="$(git grep -l -A 10 -E "joblib import.+delayed" -- "*.py" ":!sklearn/utils/_joblib.py" ":!sklearn/utils/fixes.py")"
38-
39-
if [ ! -z "$joblib_import" ]; then
40-
echo "Use from sklearn.utils.fixes import delayed instead of joblib delayed. The following files contains imports to joblib.delayed:"
41-
echo "$joblib_import"
37+
joblib_delayed_import="$(git grep -l -A 10 -E "joblib import.+delayed" -- "*.py" ":!sklearn/utils/_joblib.py" ":!sklearn/utils/parallel.py")"
38+
if [ ! -z "$joblib_delayed_import" ]; then
39+
echo "Use from sklearn.utils.parallel import delayed instead of joblib delayed. The following files contains imports to joblib.delayed:"
40+
echo "$joblib_delayed_import"
41+
exit 1
42+
fi
43+
joblib_Parallel_import="$(git grep -l -A 10 -E "joblib import.+Parallel" -- "*.py" ":!sklearn/utils/_joblib.py" ":!sklearn/utils/parallel.py")"
44+
if [ ! -z "$joblib_Parallel_import" ]; then
45+
echo "Use from sklearn.utils.parallel import Parallel instead of joblib Parallel. The following files contains imports to joblib.Parallel:"
46+
echo "$joblib_Parallel_import"
4247
exit 1
4348
fi

doc/modules/classes.rst

+7
Original file line numberDiff line numberDiff line change
@@ -1666,9 +1666,16 @@ Utilities from joblib:
16661666
:toctree: generated/
16671667
:template: function.rst
16681668

1669+
utils.parallel.delayed
16691670
utils.parallel_backend
16701671
utils.register_parallel_backend
16711672

1673+
.. autosummary::
1674+
:toctree: generated/
1675+
:template: class.rst
1676+
1677+
utils.parallel.Parallel
1678+
16721679

16731680
Recently deprecated
16741681
===================

doc/whats_new/v1.2.rst

+17
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,16 @@ Version 1.2.1
99

1010
**In Development**
1111

12+
Changes impacting all modules
13+
-----------------------------
14+
15+
- |Fix| Fix a bug where the current configuration was ignored in estimators using
16+
`n_jobs > 1`. This bug was triggered for tasks dispatched by the auxiliary
17+
thread of `joblib` as :func:`sklearn.get_config` used to access an empty thread
18+
local configuration instead of the configuration visible from the thread where
19+
`joblib.Parallel` was first called.
20+
:pr:`25363` by :user:`Guillaume Lemaitre <glemaitre>`.
21+
1222
Changed models
1323
--------------
1424

@@ -139,6 +149,13 @@ Changelog
139149
boolean. The type is maintained, instead of converting to `float64.`
140150
:pr:`25147` by :user:`Tim Head <betatim>`.
141151

152+
- |API| :func:`utils.fixes.delayed` is deprecated in 1.2.1 and will be removed
153+
in 1.5. Instead, import :func:`utils.parallel.delayed` and use it in
154+
conjunction with the newly introduced :class:`utils.parallel.Parallel`
155+
to ensure proper propagation of the scikit-learn configuration to
156+
the workers.
157+
:pr:`25363` by :user:`Guillaume Lemaitre <glemaitre>`.
158+
142159
.. _changes_1_2:
143160

144161
Version 1.2.0

sklearn/calibration.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
from math import log
1616
import numpy as np
17-
from joblib import Parallel
1817

1918
from scipy.special import expit
2019
from scipy.special import xlogy
@@ -36,7 +35,7 @@
3635
)
3736

3837
from .utils.multiclass import check_classification_targets
39-
from .utils.fixes import delayed
38+
from .utils.parallel import delayed, Parallel
4039
from .utils._param_validation import StrOptions, HasMethods, Hidden
4140
from .utils.validation import (
4241
_check_fit_params,

sklearn/cluster/_mean_shift.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,12 @@
1616

1717
import numpy as np
1818
import warnings
19-
from joblib import Parallel
2019
from numbers import Integral, Real
2120

2221
from collections import defaultdict
2322
from ..utils._param_validation import Interval
2423
from ..utils.validation import check_is_fitted
25-
from ..utils.fixes import delayed
24+
from ..utils.parallel import delayed, Parallel
2625
from ..utils import check_random_state, gen_batches, check_array
2726
from ..base import BaseEstimator, ClusterMixin
2827
from ..neighbors import NearestNeighbors

sklearn/compose/_column_transformer.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
import numpy as np
1313
from scipy import sparse
14-
from joblib import Parallel
1514

1615
from ..base import clone, TransformerMixin
1716
from ..utils._estimator_html_repr import _VisualBlock
@@ -24,7 +23,7 @@
2423
from ..utils import check_pandas_support
2524
from ..utils.metaestimators import _BaseComposition
2625
from ..utils.validation import check_array, check_is_fitted, _check_feature_names_in
27-
from ..utils.fixes import delayed
26+
from ..utils.parallel import delayed, Parallel
2827

2928

3029
__all__ = ["ColumnTransformer", "make_column_transformer", "make_column_selector"]

sklearn/covariance/_graph_lasso.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from numbers import Integral, Real
1414
import numpy as np
1515
from scipy import linalg
16-
from joblib import Parallel
1716

1817
from . import empirical_covariance, EmpiricalCovariance, log_likelihood
1918

@@ -23,7 +22,7 @@
2322
check_random_state,
2423
check_scalar,
2524
)
26-
from ..utils.fixes import delayed
25+
from ..utils.parallel import delayed, Parallel
2726
from ..utils._param_validation import Interval, StrOptions
2827

2928
# mypy error: Module 'sklearn.linear_model' has no attribute '_cd_fast'

sklearn/decomposition/_dict_learning.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,15 @@
1313

1414
import numpy as np
1515
from scipy import linalg
16-
from joblib import Parallel, effective_n_jobs
16+
from joblib import effective_n_jobs
1717

1818
from ..base import BaseEstimator, TransformerMixin, ClassNamePrefixFeaturesOutMixin
1919
from ..utils import check_array, check_random_state, gen_even_slices, gen_batches
2020
from ..utils import deprecated
2121
from ..utils._param_validation import Hidden, Interval, StrOptions
2222
from ..utils.extmath import randomized_svd, row_norms, svd_flip
2323
from ..utils.validation import check_is_fitted
24-
from ..utils.fixes import delayed
24+
from ..utils.parallel import delayed, Parallel
2525
from ..linear_model import Lasso, orthogonal_mp_gram, LassoLars, Lars
2626

2727

sklearn/decomposition/_lda.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515
import numpy as np
1616
import scipy.sparse as sp
1717
from scipy.special import gammaln, logsumexp
18-
from joblib import Parallel, effective_n_jobs
18+
from joblib import effective_n_jobs
1919

2020
from ..base import BaseEstimator, TransformerMixin, ClassNamePrefixFeaturesOutMixin
2121
from ..utils import check_random_state, gen_batches, gen_even_slices
2222
from ..utils.validation import check_non_negative
2323
from ..utils.validation import check_is_fitted
24-
from ..utils.fixes import delayed
24+
from ..utils.parallel import delayed, Parallel
2525
from ..utils._param_validation import Interval, StrOptions
2626

2727
from ._online_lda_fast import (

sklearn/decomposition/tests/test_dict_learning.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,14 @@
55
from functools import partial
66
import itertools
77

8-
from joblib import Parallel
9-
108
import sklearn
119

1210
from sklearn.base import clone
1311

1412
from sklearn.exceptions import ConvergenceWarning
1513

1614
from sklearn.utils import check_array
15+
from sklearn.utils.parallel import Parallel
1716

1817
from sklearn.utils._testing import assert_allclose
1918
from sklearn.utils._testing import assert_array_almost_equal

sklearn/ensemble/_bagging.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
from warnings import warn
1313
from functools import partial
1414

15-
from joblib import Parallel
16-
1715
from ._base import BaseEnsemble, _partition_estimators
1816
from ..base import ClassifierMixin, RegressorMixin
1917
from ..metrics import r2_score, accuracy_score
@@ -25,7 +23,7 @@
2523
from ..utils.random import sample_without_replacement
2624
from ..utils._param_validation import Interval, HasMethods, StrOptions
2725
from ..utils.validation import has_fit_parameter, check_is_fitted, _check_sample_weight
28-
from ..utils.fixes import delayed
26+
from ..utils.parallel import delayed, Parallel
2927

3028

3129
__all__ = ["BaggingClassifier", "BaggingRegressor"]

sklearn/ensemble/_forest.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ class calls the ``fit`` method of each sub-estimator on random samples
4848
import numpy as np
4949
from scipy.sparse import issparse
5050
from scipy.sparse import hstack as sparse_hstack
51-
from joblib import Parallel
5251

5352
from ..base import is_classifier
5453
from ..base import ClassifierMixin, MultiOutputMixin, RegressorMixin, TransformerMixin
@@ -66,7 +65,7 @@ class calls the ``fit`` method of each sub-estimator on random samples
6665
from ..utils import check_random_state, compute_sample_weight
6766
from ..exceptions import DataConversionWarning
6867
from ._base import BaseEnsemble, _partition_estimators
69-
from ..utils.fixes import delayed
68+
from ..utils.parallel import delayed, Parallel
7069
from ..utils.multiclass import check_classification_targets, type_of_target
7170
from ..utils.validation import (
7271
check_is_fitted,

sklearn/ensemble/_stacking.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from numbers import Integral
99

1010
import numpy as np
11-
from joblib import Parallel
1211
import scipy.sparse as sparse
1312

1413
from ..base import clone
@@ -33,7 +32,7 @@
3332
from ..utils.metaestimators import available_if
3433
from ..utils.validation import check_is_fitted
3534
from ..utils.validation import column_or_1d
36-
from ..utils.fixes import delayed
35+
from ..utils.parallel import delayed, Parallel
3736
from ..utils._param_validation import HasMethods, StrOptions
3837
from ..utils.validation import _check_feature_names_in
3938

sklearn/ensemble/_voting.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818

1919
import numpy as np
2020

21-
from joblib import Parallel
22-
2321
from ..base import ClassifierMixin
2422
from ..base import RegressorMixin
2523
from ..base import TransformerMixin
@@ -36,7 +34,7 @@
3634
from ..utils._param_validation import StrOptions
3735
from ..exceptions import NotFittedError
3836
from ..utils._estimator_html_repr import _VisualBlock
39-
from ..utils.fixes import delayed
37+
from ..utils.parallel import delayed, Parallel
4038

4139

4240
class _BaseVoting(TransformerMixin, _BaseHeterogeneousEnsemble):

sklearn/ensemble/tests/test_forest.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,15 @@
1717
from typing import Dict, Any
1818

1919
import numpy as np
20-
from joblib import Parallel
2120
from scipy.sparse import csr_matrix
2221
from scipy.sparse import csc_matrix
2322
from scipy.sparse import coo_matrix
2423
from scipy.special import comb
2524

26-
import pytest
27-
2825
import joblib
2926

27+
import pytest
28+
3029
import sklearn
3130
from sklearn.dummy import DummyRegressor
3231
from sklearn.metrics import mean_poisson_deviance
@@ -55,6 +54,7 @@
5554
>>>>>>> c3fca81536 (FIX Support read-only sparse datasets for `Tree`-based estimators (#25341))
5655
from sklearn.model_selection import GridSearchCV
5756
from sklearn.svm import LinearSVC
57+
from sklearn.utils.parallel import Parallel
5858
from sklearn.utils.validation import check_random_state
5959

6060
from sklearn.metrics import mean_squared_error

sklearn/feature_selection/_rfe.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@
88

99
import numpy as np
1010
from numbers import Integral, Real
11-
from joblib import Parallel, effective_n_jobs
11+
from joblib import effective_n_jobs
1212

1313

1414
from ..utils.metaestimators import available_if
1515
from ..utils.metaestimators import _safe_split
1616
from ..utils._param_validation import HasMethods, Interval
1717
from ..utils._tags import _safe_tags
1818
from ..utils.validation import check_is_fitted
19-
from ..utils.fixes import delayed
19+
from ..utils.parallel import delayed, Parallel
2020
from ..base import BaseEstimator
2121
from ..base import MetaEstimatorMixin
2222
from ..base import clone

sklearn/inspection/_permutation_importance.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Permutation importance for estimators."""
22
import numbers
33
import numpy as np
4-
from joblib import Parallel
54

65
from ..ensemble._bagging import _generate_indices
76
from ..metrics import check_scoring
@@ -10,7 +9,7 @@
109
from ..utils import Bunch, _safe_indexing
1110
from ..utils import check_random_state
1211
from ..utils import check_array
13-
from ..utils.fixes import delayed
12+
from ..utils.parallel import delayed, Parallel
1413

1514

1615
def _weights_scorer(scorer, estimator, X, y, sample_weight):

sklearn/inspection/_plot/partial_dependence.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import numpy as np
77
from scipy import sparse
88
from scipy.stats.mstats import mquantiles
9-
from joblib import Parallel
109

1110
from .. import partial_dependence
1211
from .._pd_utils import _check_feature_names, _get_feature_index
@@ -16,7 +15,7 @@
1615
from ...utils import check_matplotlib_support # noqa
1716
from ...utils import check_random_state
1817
from ...utils import _safe_indexing
19-
from ...utils.fixes import delayed
18+
from ...utils.parallel import delayed, Parallel
2019
from ...utils._encode import _unique
2120

2221

sklearn/linear_model/_base.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from scipy import sparse
2626
from scipy.sparse.linalg import lsqr
2727
from scipy.special import expit
28-
from joblib import Parallel
2928
from numbers import Integral
3029

3130
from ..base import BaseEstimator, ClassifierMixin, RegressorMixin, MultiOutputMixin
@@ -40,7 +39,7 @@
4039
from ..utils._seq_dataset import ArrayDataset32, CSRDataset32
4140
from ..utils._seq_dataset import ArrayDataset64, CSRDataset64
4241
from ..utils.validation import check_is_fitted, _check_sample_weight
43-
from ..utils.fixes import delayed
42+
from ..utils.parallel import delayed, Parallel
4443

4544
# TODO: bayesian_ridge_regression and bayesian_regression_ard
4645
# should be squashed into its respective objects.

sklearn/linear_model/_coordinate_descent.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import numpy as np
1616
from scipy import sparse
17-
from joblib import Parallel, effective_n_jobs
17+
from joblib import effective_n_jobs
1818

1919
from ._base import LinearModel, _pre_fit
2020
from ..base import RegressorMixin, MultiOutputMixin
@@ -30,7 +30,7 @@
3030
check_is_fitted,
3131
column_or_1d,
3232
)
33-
from ..utils.fixes import delayed
33+
from ..utils.parallel import delayed, Parallel
3434

3535
# mypy error: Module 'sklearn.linear_model' has no attribute '_cd_fast'
3636
from . import _cd_fast as cd_fast # type: ignore

sklearn/linear_model/_least_angle.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import numpy as np
1717
from scipy import linalg, interpolate
1818
from scipy.linalg.lapack import get_lapack_funcs
19-
from joblib import Parallel
2019

2120
from ._base import LinearModel, LinearRegression
2221
from ._base import _deprecate_normalize, _preprocess_data
@@ -28,7 +27,7 @@
2827
from ..utils._param_validation import Hidden, Interval, StrOptions
2928
from ..model_selection import check_cv
3029
from ..exceptions import ConvergenceWarning
31-
from ..utils.fixes import delayed
30+
from ..utils.parallel import delayed, Parallel
3231

3332
SOLVE_TRIANGULAR_ARGS = {"check_finite": False}
3433

0 commit comments

Comments
 (0)