Skip to content

Commit 6410e79

Browse files
committed
Add the ability to customize thread spawning
As an alternative to `ThreadPoolBuilder::build()` and `build_global()`, the new `spawn()` and `spawn_global()` methods take a closure which will be responsible for spawning the actual threads. This is called with a `ThreadBuilder` argument that provides the thread index, name, and stack size, with the expectation to call its `run()` method in the new thread. The motivating use cases for this are: - experimental WASM threading, to be externally implemented. - scoped threads, like the new test using `scoped_tls`.
1 parent 939d3ee commit 6410e79

File tree

5 files changed

+222
-36
lines changed

5 files changed

+222
-36
lines changed

rayon-core/Cargo.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@ num_cpus = "1.2"
1919
lazy_static = "1"
2020
crossbeam-deque = "0.6.3"
2121
crossbeam-queue = "0.1.2"
22+
crossbeam-utils = "0.6.5"
2223

2324
[dev-dependencies]
2425
rand = "0.6"
2526
rand_xorshift = "0.1"
27+
scoped-tls = "1.0"
2628

2729
[target.'cfg(unix)'.dev-dependencies]
2830
libc = "0.2"
@@ -49,3 +51,7 @@ path = "tests/scope_join.rs"
4951
[[test]]
5052
name = "simple_panic"
5153
path = "tests/simple_panic.rs"
54+
55+
[[test]]
56+
name = "scoped_threadpool"
57+
path = "tests/scoped_threadpool.rs"

rayon-core/src/lib.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ mod test;
6464
#[cfg(rayon_unstable)]
6565
pub mod internal;
6666
pub use join::{join, join_context};
67+
pub use registry::ThreadBuilder;
6768
pub use scope::{scope, Scope};
6869
pub use scope::{scope_fifo, ScopeFifo};
6970
pub use spawn::{spawn, spawn_fifo};
@@ -185,6 +186,21 @@ impl ThreadPoolBuilder {
185186
ThreadPool::build(self)
186187
}
187188

189+
/// Create a new `ThreadPool` initialized using this configuration and a
190+
/// custom function for spawning threads.
191+
///
192+
/// Note that the threads will not exit until after the pool is dropped. It
193+
/// is up to the caller to wait for thread termination if that is important
194+
/// for any invariants. For instance, threads created in `crossbeam::scope`
195+
/// will be joined before that scope returns, and this will block indefinitely
196+
/// if the pool is leaked.
197+
pub fn spawn(
198+
self,
199+
spawn: impl FnMut(ThreadBuilder) -> io::Result<()>,
200+
) -> Result<ThreadPool, ThreadPoolBuildError> {
201+
ThreadPool::build_spawn(self, spawn)
202+
}
203+
188204
/// Initializes the global thread pool. This initialization is
189205
/// **optional**. If you do not call this function, the thread pool
190206
/// will be automatically initialized with the default
@@ -208,6 +224,21 @@ impl ThreadPoolBuilder {
208224
Ok(())
209225
}
210226

227+
/// Initializes the global thread pool using a custom function for spawning
228+
/// threads.
229+
///
230+
/// Note that the global thread pool doesn't terminate until the entire process
231+
/// exits! If this is used with something like `crossbeam::scope` that tries to
232+
/// join threads, that will block indefinitely.
233+
pub fn spawn_global(
234+
self,
235+
spawn: impl FnMut(ThreadBuilder) -> io::Result<()>,
236+
) -> Result<(), ThreadPoolBuildError> {
237+
let registry = registry::spawn_global_registry(self, spawn)?;
238+
registry.wait_until_primed();
239+
Ok(())
240+
}
241+
211242
/// Get the number of threads that will be used for the thread
212243
/// pool. See `num_threads()` for more information.
213244
fn get_num_threads(&self) -> usize {

rayon-core/src/registry.rs

Lines changed: 107 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ use sleep::Sleep;
1111
use std::any::Any;
1212
use std::cell::Cell;
1313
use std::collections::hash_map::DefaultHasher;
14+
use std::fmt;
1415
use std::hash::Hasher;
16+
use std::io;
1517
use std::mem;
1618
use std::ptr;
1719
#[allow(deprecated)]
@@ -24,6 +26,49 @@ use unwind;
2426
use util::leak;
2527
use {ErrorKind, ExitHandler, PanicHandler, StartHandler, ThreadPoolBuildError, ThreadPoolBuilder};
2628

29+
/// Thread builder used for customization via `ThreadPoolBuilder::spawn`.
30+
pub struct ThreadBuilder {
31+
name: Option<String>,
32+
stack_size: Option<usize>,
33+
worker: Worker<JobRef>,
34+
registry: Arc<Registry>,
35+
index: usize,
36+
}
37+
38+
impl ThreadBuilder {
39+
/// Get the index of this thread in the pool, within `0..num_threads`.
40+
pub fn index(&self) -> usize {
41+
self.index
42+
}
43+
44+
/// Get the string that was specified by `ThreadPoolBuilder::name()`.
45+
pub fn name(&self) -> Option<&str> {
46+
self.name.as_ref().map(String::as_str)
47+
}
48+
49+
/// Get the value that was specified by `ThreadPoolBuilder::stack_size()`.
50+
pub fn stack_size(&self) -> Option<usize> {
51+
self.stack_size
52+
}
53+
54+
/// Execute the main loop for this thread. This will not return until the
55+
/// thread pool is dropped.
56+
pub fn run(self) {
57+
unsafe { main_loop(self.worker, self.registry, self.index) }
58+
}
59+
}
60+
61+
impl fmt::Debug for ThreadBuilder {
62+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
63+
f.debug_struct("ThreadBuilder")
64+
.field("pool", &self.registry.id())
65+
.field("index", &self.index)
66+
.field("name", &self.name)
67+
.field("stack_size", &self.stack_size)
68+
.finish()
69+
}
70+
}
71+
2772
pub(super) struct Registry {
2873
thread_infos: Vec<ThreadInfo>,
2974
sleep: Sleep,
@@ -58,39 +103,50 @@ static THE_REGISTRY_SET: Once = ONCE_INIT;
58103
/// initialization has not already occurred, use the default
59104
/// configuration.
60105
fn global_registry() -> &'static Arc<Registry> {
61-
THE_REGISTRY_SET.call_once(|| unsafe { init_registry(ThreadPoolBuilder::new()).unwrap() });
62-
unsafe { THE_REGISTRY.expect("The global thread pool has not been initialized.") }
106+
unsafe {
107+
THE_REGISTRY.unwrap_or_else(|| {
108+
init_global_registry(ThreadPoolBuilder::new())
109+
.expect("The global thread pool has not been initialized.")
110+
})
111+
}
63112
}
64113

65114
/// Starts the worker threads (if that has not already happened) with
66115
/// the given builder.
67116
pub(super) fn init_global_registry(
68117
builder: ThreadPoolBuilder,
69-
) -> Result<&'static Registry, ThreadPoolBuildError> {
70-
let mut called = false;
71-
let mut init_result = Ok(());;
72-
THE_REGISTRY_SET.call_once(|| unsafe {
73-
init_result = init_registry(builder);
74-
called = true;
75-
});
76-
if called {
77-
init_result?;
78-
Ok(&**global_registry())
79-
} else {
80-
Err(ThreadPoolBuildError::new(
81-
ErrorKind::GlobalPoolAlreadyInitialized,
82-
))
83-
}
118+
) -> Result<&'static Arc<Registry>, ThreadPoolBuildError> {
119+
set_global_registry(|| Registry::new(builder))
120+
}
121+
122+
/// Starts the worker threads (if that has not already happened) with
123+
/// the given builder.
124+
pub(super) fn spawn_global_registry(
125+
builder: ThreadPoolBuilder,
126+
spawn: impl FnMut(ThreadBuilder) -> io::Result<()>,
127+
) -> Result<&'static Arc<Registry>, ThreadPoolBuildError> {
128+
set_global_registry(|| Registry::spawn(builder, spawn))
84129
}
85130

86-
/// Initializes the global registry with the given builder.
87-
/// Meant to be called from within the `THE_REGISTRY_SET` once
88-
/// function. Declared `unsafe` because it writes to `THE_REGISTRY` in
89-
/// an unsynchronized fashion.
90-
unsafe fn init_registry(builder: ThreadPoolBuilder) -> Result<(), ThreadPoolBuildError> {
91-
let registry = Registry::new(builder)?;
92-
THE_REGISTRY = Some(leak(registry));
93-
Ok(())
131+
/// Starts the worker threads (if that has not already happened)
132+
/// by creating a registry with the given callback.
133+
fn set_global_registry<F>(registry: F) -> Result<&'static Arc<Registry>, ThreadPoolBuildError>
134+
where
135+
F: FnOnce() -> Result<Arc<Registry>, ThreadPoolBuildError>,
136+
{
137+
let mut result = Err(ThreadPoolBuildError::new(
138+
ErrorKind::GlobalPoolAlreadyInitialized,
139+
));
140+
THE_REGISTRY_SET.call_once(|| {
141+
result = registry().map(|registry| {
142+
let registry = leak(registry);
143+
unsafe {
144+
THE_REGISTRY = Some(registry);
145+
}
146+
registry
147+
});
148+
});
149+
result
94150
}
95151

96152
struct Terminator<'a>(&'a Arc<Registry>);
@@ -102,7 +158,24 @@ impl<'a> Drop for Terminator<'a> {
102158
}
103159

104160
impl Registry {
105-
pub(super) fn new(mut builder: ThreadPoolBuilder) -> Result<Arc<Self>, ThreadPoolBuildError> {
161+
pub(super) fn new(builder: ThreadPoolBuilder) -> Result<Arc<Self>, ThreadPoolBuildError> {
162+
Registry::spawn(builder, |thread| {
163+
let mut b = thread::Builder::new();
164+
if let Some(ref name) = thread.name {
165+
b = b.name(name.clone());
166+
}
167+
if let Some(stack_size) = thread.stack_size {
168+
b = b.stack_size(stack_size);
169+
}
170+
b.spawn(|| thread.run())?;
171+
Ok(())
172+
})
173+
}
174+
175+
pub(super) fn spawn(
176+
mut builder: ThreadPoolBuilder,
177+
mut spawn: impl FnMut(ThreadBuilder) -> io::Result<()>,
178+
) -> Result<Arc<Self>, ThreadPoolBuildError> {
106179
let n_threads = builder.get_num_threads();
107180
let breadth_first = builder.get_breadth_first();
108181

@@ -130,15 +203,14 @@ impl Registry {
130203
let t1000 = Terminator(&registry);
131204

132205
for (index, worker) in workers.into_iter().enumerate() {
133-
let registry = registry.clone();
134-
let mut b = thread::Builder::new();
135-
if let Some(name) = builder.get_thread_name(index) {
136-
b = b.name(name);
137-
}
138-
if let Some(stack_size) = builder.get_stack_size() {
139-
b = b.stack_size(stack_size);
140-
}
141-
if let Err(e) = b.spawn(move || unsafe { main_loop(worker, registry, index) }) {
206+
let thread = ThreadBuilder {
207+
name: builder.get_thread_name(index),
208+
stack_size: builder.get_stack_size(),
209+
registry: registry.clone(),
210+
worker,
211+
index,
212+
};
213+
if let Err(e) = spawn(thread) {
142214
return Err(ThreadPoolBuildError::new(ErrorKind::IOError(e)));
143215
}
144216
}

rayon-core/src/thread_pool/mod.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@ use registry::{Registry, WorkerThread};
88
use spawn;
99
use std::error::Error;
1010
use std::fmt;
11+
use std::io;
1112
use std::sync::Arc;
1213
#[allow(deprecated)]
1314
use Configuration;
1415
use {scope, Scope};
1516
use {scope_fifo, ScopeFifo};
16-
use {ThreadPoolBuildError, ThreadPoolBuilder};
17+
use {ThreadBuilder, ThreadPoolBuildError, ThreadPoolBuilder};
1718

1819
mod internal;
1920
mod test;
@@ -66,6 +67,14 @@ impl ThreadPool {
6667
Ok(ThreadPool { registry })
6768
}
6869

70+
pub(super) fn build_spawn(
71+
builder: ThreadPoolBuilder,
72+
spawn: impl FnMut(ThreadBuilder) -> io::Result<()>,
73+
) -> Result<ThreadPool, ThreadPoolBuildError> {
74+
let registry = Registry::spawn(builder, spawn)?;
75+
Ok(ThreadPool { registry })
76+
}
77+
6978
/// Returns a handle to the global thread pool. This is the pool
7079
/// that Rayon will use by default when you perform a `join()` or
7180
/// `scope()` operation, if no other thread-pool is installed. If

rayon-core/tests/scoped_threadpool.rs

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
extern crate crossbeam_utils;
2+
extern crate rayon_core;
3+
4+
#[macro_use]
5+
extern crate scoped_tls;
6+
7+
use crossbeam_utils::thread;
8+
use rayon_core::ThreadPoolBuilder;
9+
10+
#[derive(PartialEq, Eq, Debug)]
11+
struct Local(i32);
12+
13+
scoped_thread_local!(static LOCAL: Local);
14+
15+
#[test]
16+
fn scoped_tls_missing() {
17+
LOCAL.set(&Local(42), || {
18+
let pool = ThreadPoolBuilder::new()
19+
.build()
20+
.expect("thread pool created");
21+
22+
// `LOCAL` is not set in the pool.
23+
pool.install(|| {
24+
assert!(!LOCAL.is_set());
25+
});
26+
});
27+
}
28+
29+
#[test]
30+
fn scoped_tls_threadpool() {
31+
LOCAL.set(&Local(42), || {
32+
LOCAL.with(|x| {
33+
thread::scope(|scope| {
34+
let pool = ThreadPoolBuilder::new()
35+
.spawn(move |thread| {
36+
scope
37+
.builder()
38+
.spawn(move |_| {
39+
// Borrow the same local value in the thread pool.
40+
LOCAL.set(x, || thread.run())
41+
})
42+
.map(|_| ())
43+
})
44+
.expect("thread pool created");
45+
46+
// The pool matches our local value.
47+
pool.install(|| {
48+
assert!(LOCAL.is_set());
49+
LOCAL.with(|y| {
50+
assert_eq!(x, y);
51+
});
52+
});
53+
54+
// If we change our local value, the pool is not affected.
55+
LOCAL.set(&Local(-1), || {
56+
pool.install(|| {
57+
assert!(LOCAL.is_set());
58+
LOCAL.with(|y| {
59+
assert_eq!(x, y);
60+
});
61+
});
62+
});
63+
})
64+
.expect("scope threads ok");
65+
// `thread::scope` will wait for the threads to exit before returning.
66+
});
67+
});
68+
}

0 commit comments

Comments
 (0)