-
Notifications
You must be signed in to change notification settings - Fork 28.5k
[SPARK-51711][ML][PYTHON][CONNECT] Propagates the active remote spark session to new threads to fix CrossValidator #50507
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
base: master
Are you sure you want to change the base?
Conversation
Hi @zhengruifeng , could you review this PR? It fixes a bug in SparkML via SparkConnect. The bug is reproducible with the code example in the PR description. Thanks! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@@ -434,7 +434,7 @@ def _fit(self, dataset: Union[pd.DataFrame, DataFrame]) -> "CrossValidatorModel" | |||
|
|||
tasks = _parallelFitTasks(est, train, eva, validation, epm) | |||
if not is_remote(): | |||
tasks = list(map(inheritable_thread_target, tasks)) | |||
tasks = list(map(inheritable_thread_target(dataset.sparkSession), tasks)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this necessary?
it is under if not is_remote()
branch
@@ -86,7 +86,7 @@ private[connect] class MLCache extends Logging { | |||
|
|||
private[connect] object MLCache { | |||
// The maximum number of distinct items in the cache. | |||
private val MAX_CACHED_ITEMS = 100 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this related?
if no, I think we should file a separate PR to increase the value
Co-authored-by: Ruifeng Zheng <ruifengz@foxmail.com>
@xi-db have we tried just replacing |
What changes were proposed in this pull request?
In SparkML with Spark Connect, the
_parallelFitTasks
fails when runningCrossValidator
fitting, as the active remote spark session is not properly propagated to the new threads.Before the PR, this code will fail in the line
cvModel = cv.fit(data)
:It fails because the active remote spark session is not properly set on that thread:
With this fix, the above code snippet works correctly.
Why are the changes needed?
It fixes a bug with CrossValidator fitting.
Does this PR introduce any user-facing change?
No.
How was this patch tested?
New test.
Was this patch authored or co-authored using generative AI tooling?
No.