From 326fd641190f16f741a305d36b3cb4ceb547c846 Mon Sep 17 00:00:00 2001 From: Adam Reichold Date: Sun, 12 May 2024 10:49:05 +0200 Subject: [PATCH] Allow to temporarily set the current registry even if it is not associated with a worker thread --- rayon-core/src/registry.rs | 71 ++++++++++++++++++++++++++++++++++--- rayon-core/src/scope/mod.rs | 24 ++++++++----- 2 files changed, 83 insertions(+), 12 deletions(-) diff --git a/rayon-core/src/registry.rs b/rayon-core/src/registry.rs index d30f815bd..8ecb4fbab 100644 --- a/rayon-core/src/registry.rs +++ b/rayon-core/src/registry.rs @@ -161,7 +161,7 @@ static THE_REGISTRY_SET: Once = Once::new(); /// Starts the worker threads (if that has not already happened). If /// initialization has not already occurred, use the default /// configuration. -pub(super) fn global_registry() -> &'static Arc { +fn global_registry() -> &'static Arc { set_global_registry(default_global_registry) .or_else(|err| unsafe { THE_REGISTRY.as_ref().ok_or(err) }) .expect("The global thread pool has not been initialized.") @@ -217,6 +217,36 @@ fn default_global_registry() -> Result, ThreadPoolBuildError> { result } +// This is used to temporarily overwrite the current registry. +// +// This is either null, a pointer to the global registry if it was +// ever used to access the global registry or a pointer to a +// registry which is temporarily made current because the current +// thread is not a worker thread but is running a scope associated +// with a specific thread pool. +thread_local! 
{ + static CURRENT_REGISTRY: Cell<*const Arc> = const { Cell::new(ptr::null()) }; +} + +#[cold] +fn set_current_registry_to_global_registry() -> *const Arc { + let global = global_registry(); + + CURRENT_REGISTRY.with(|current_registry| current_registry.set(global)); + + global +} + +pub(super) fn current_registry() -> *const Arc { + let mut current = CURRENT_REGISTRY.with(Cell::get); + + if current.is_null() { + current = set_current_registry_to_global_registry(); + } + + current +} + struct Terminator<'a>(&'a Arc); impl<'a> Drop for Terminator<'a> { @@ -315,7 +345,7 @@ impl Registry { unsafe { let worker_thread = WorkerThread::current(); let registry = if worker_thread.is_null() { - global_registry() + &*current_registry() } else { &(*worker_thread).registry }; @@ -323,6 +353,39 @@ impl Registry { } } + /// Optionally install a specific registry as the current one. + /// + /// This is used when a thread which is not a worker executes + /// a scope which should use the specific thread pool instead of + /// the global one. + pub(super) fn with_current(registry: Option<&Arc>, f: F) -> R + where + F: FnOnce() -> R, + { + struct Guard { + current: *const Arc, + } + + impl Guard { + fn new(registry: &Arc) -> Self { + let current = + CURRENT_REGISTRY.with(|current_registry| current_registry.replace(registry)); + + Self { current } + } + } + + impl Drop for Guard { + fn drop(&mut self) { + CURRENT_REGISTRY.with(|current_registry| current_registry.set(self.current)); + } + } + + let _guard = registry.map(Guard::new); + + f() + } + /// Returns the number of threads in the current registry. This /// is better than `Registry::current().num_threads()` because it /// avoids incrementing the `Arc`. 
@@ -330,7 +393,7 @@ impl Registry { unsafe { let worker_thread = WorkerThread::current(); if worker_thread.is_null() { - global_registry().num_threads() + (*current_registry()).num_threads() } else { (*worker_thread).registry.num_threads() } @@ -946,7 +1009,7 @@ where // invalidated until we return. op(&*owner_thread, false) } else { - global_registry().in_worker(op) + (*current_registry()).in_worker(op) } } } diff --git a/rayon-core/src/scope/mod.rs b/rayon-core/src/scope/mod.rs index 1d8732fea..8dd0234be 100644 --- a/rayon-core/src/scope/mod.rs +++ b/rayon-core/src/scope/mod.rs @@ -8,7 +8,7 @@ use crate::broadcast::BroadcastContext; use crate::job::{ArcJob, HeapJob, JobFifo, JobRef}; use crate::latch::{CountLatch, Latch}; -use crate::registry::{global_registry, in_worker, Registry, WorkerThread}; +use crate::registry::{current_registry, in_worker, Registry, WorkerThread}; use crate::unwind; use std::any::Any; use std::fmt; @@ -416,9 +416,11 @@ pub(crate) fn do_in_place_scope<'scope, OP, R>(registry: Option<&Arc>, where OP: FnOnce(&Scope<'scope>) -> R, { - let thread = unsafe { WorkerThread::current().as_ref() }; - let scope = Scope::<'scope>::new(thread, registry); - scope.base.complete(thread, || op(&scope)) + Registry::with_current(registry, || { + let thread = unsafe { WorkerThread::current().as_ref() }; + let scope = Scope::<'scope>::new(thread, registry); + scope.base.complete(thread, || op(&scope)) + }) } /// Creates a "fork-join" scope `s` with FIFO order, and invokes the @@ -453,9 +455,11 @@ pub(crate) fn do_in_place_scope_fifo<'scope, OP, R>(registry: Option<&Arc) -> R, { - let thread = unsafe { WorkerThread::current().as_ref() }; - let scope = ScopeFifo::<'scope>::new(thread, registry); - scope.base.complete(thread, || op(&scope)) + Registry::with_current(registry, || { + let thread = unsafe { WorkerThread::current().as_ref() }; + let scope = ScopeFifo::<'scope>::new(thread, registry); + scope.base.complete(thread, || op(&scope)) + }) } impl<'scope> 
Scope<'scope> { @@ -625,7 +629,11 @@ impl<'scope> ScopeBase<'scope> { fn new(owner: Option<&WorkerThread>, registry: Option<&Arc>) -> Self { let registry = registry.unwrap_or_else(|| match owner { Some(owner) => owner.registry(), - None => global_registry(), + // SAFETY: `current_registry` will either return a pointer to + // the global registry which has a 'static lifetime or + // to a temporary one kept alive by `with_current`. + // In both cases we can safely dereference it here to clone the `Arc`. + None => unsafe { &*current_registry() }, }); ScopeBase {