Skip to content

Commit 6172270

Browse files
Make an async init function required for each thread
1 parent 5412ef0 commit 6172270

File tree

2 files changed

+78
-25
lines changed

2 files changed

+78
-25
lines changed

crates/cubecl-wgpu/Cargo.toml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,6 @@ web-time = { workspace = true }
4646

4747
cfg-if = { workspace = true }
4848

49-
[target.'cfg(all(target_arch = "wasm32", target_feature = "atomics"))'.dependencies]
50-
wasm-bindgen = { workspace = true }
51-
5249
[dev-dependencies]
5350
cubecl-core = { path = "../cubecl-core", version = "0.4.0", features = [
5451
"export_tests",

crates/cubecl-wgpu/src/runtime.rs

Lines changed: 78 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ impl Runtime for WgpuRuntime<WgslCompiler> {
6868
{
6969
let server = LOCAL_DEVICE.with_borrow_mut(|runtime| {
7070
runtime
71-
.entry(device.clone())
72-
.or_insert_with(|| ThreadLocalChannel::make_server(device))
71+
.get(device)
72+
.expect(&format!("The wgpu server for {device:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread"))
7373
.clone()
7474
});
7575
let server = server.borrow();
@@ -158,6 +158,32 @@ pub struct WgpuSetup {
158158
pub queue: Pdrc<wgpu::Queue>,
159159
}
160160

161+
#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))]
162+
pub async fn init_thread_server(device: WgpuDevice, options: RuntimeOptions) {
163+
let setup = create_setup_for_device::<AutoGraphicsApi, WgslCompiler>(&device).await;
164+
165+
let limits = setup.device.limits();
166+
let mem_props = MemoryDeviceProperties {
167+
max_page_size: limits.max_storage_buffer_binding_size as u64,
168+
alignment: WgpuStorage::ALIGNMENT.max(limits.min_storage_buffer_offset_alignment as u64),
169+
};
170+
let memory_management = {
171+
let mem_props = mem_props.clone();
172+
let config = options.memory_config;
173+
let storage = WgpuStorage::new(setup.device.clone());
174+
MemoryManagement::from_configuration(storage, mem_props, config)
175+
};
176+
let server = crate::compute::WgpuServer::new(
177+
memory_management,
178+
setup.device,
179+
setup.queue,
180+
setup.adapter,
181+
options.tasks_max,
182+
);
183+
184+
LOCAL_DEVICE.with_borrow_mut(|map| map.insert(device, Rc::new(RefCell::new(server))));
185+
}
186+
161187
/// Create a [`WgpuDevice`] on an existing [`WgpuSetup`].
162188
/// Useful when you want to share a device between CubeCL and other wgpu-dependent libraries.
163189
#[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))]
@@ -465,26 +491,35 @@ impl ComputeChannel<Server> for ThreadLocalChannel {
465491
) -> cubecl_runtime::storage::BindingResource<Server> {
466492
LOCAL_DEVICE.with_borrow_mut(|runtime| {
467493
let server = runtime
468-
.entry(self.device.clone())
469-
.or_insert_with(|| Self::make_server(&self.device));
494+
.get(&self.device)
495+
.expect(&format!(
496+
"The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread",
497+
self.device,
498+
));
470499
server.borrow_mut().get_resource(binding)
471500
})
472501
}
473502

474503
fn create(&self, data: &[u8]) -> cubecl_core::server::Handle {
475504
LOCAL_DEVICE.with_borrow_mut(|runtime| {
476505
let server = runtime
477-
.entry(self.device.clone())
478-
.or_insert_with(|| Self::make_server(&self.device));
506+
.get(&self.device)
507+
.expect(&format!(
508+
"The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread",
509+
self.device,
510+
));
479511
server.borrow_mut().create(data)
480512
})
481513
}
482514

483515
fn empty(&self, size: usize) -> cubecl_core::server::Handle {
484516
LOCAL_DEVICE.with_borrow_mut(|runtime| {
485517
let server = runtime
486-
.entry(self.device.clone())
487-
.or_insert_with(|| Self::make_server(&self.device));
518+
.get(&self.device)
519+
.expect(&format!(
520+
"The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread",
521+
self.device,
522+
));
488523
server.borrow_mut().empty(size)
489524
})
490525
}
@@ -498,26 +533,35 @@ impl ComputeChannel<Server> for ThreadLocalChannel {
498533
) {
499534
LOCAL_DEVICE.with_borrow_mut(|runtime| {
500535
let server = runtime
501-
.entry(self.device.clone())
502-
.or_insert_with(|| Self::make_server(&self.device));
536+
.get(&self.device)
537+
.expect(&format!(
538+
"The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread",
539+
self.device,
540+
));
503541
unsafe { server.borrow_mut().execute(kernel, count, bindings, mode) }
504542
})
505543
}
506544

507545
fn flush(&self) {
508546
LOCAL_DEVICE.with_borrow_mut(|runtime| {
509547
let server = runtime
510-
.entry(self.device.clone())
511-
.or_insert_with(|| Self::make_server(&self.device));
548+
.get(&self.device)
549+
.expect(&format!(
550+
"The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread",
551+
self.device,
552+
));
512553
server.borrow_mut().flush()
513554
})
514555
}
515556

516557
fn sync(&self) -> impl std::future::Future<Output = ()> {
517558
LOCAL_DEVICE.with_borrow_mut(|runtime| {
518559
let server = runtime
519-
.entry(self.device.clone())
520-
.or_insert_with(|| Self::make_server(&self.device))
560+
.get(&self.device)
561+
.expect(&format!(
562+
"The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread",
563+
self.device,
564+
))
521565
.clone();
522566
async move { server.borrow_mut().sync().await }
523567
})
@@ -526,8 +570,11 @@ impl ComputeChannel<Server> for ThreadLocalChannel {
526570
fn sync_elapsed(&self) -> impl std::future::Future<Output = cubecl_runtime::TimestampsResult> {
527571
LOCAL_DEVICE.with_borrow_mut(|runtime| {
528572
let server = runtime
529-
.entry(self.device.clone())
530-
.or_insert_with(|| Self::make_server(&self.device))
573+
.get(&self.device)
574+
.expect(&format!(
575+
"The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread",
576+
self.device,
577+
))
531578
.clone();
532579
async move { server.borrow_mut().sync_elapsed().await }
533580
})
@@ -536,26 +583,35 @@ impl ComputeChannel<Server> for ThreadLocalChannel {
536583
fn memory_usage(&self) -> cubecl_runtime::memory_management::MemoryUsage {
537584
LOCAL_DEVICE.with_borrow_mut(|runtime| {
538585
let server = runtime
539-
.entry(self.device.clone())
540-
.or_insert_with(|| Self::make_server(&self.device));
586+
.get(&self.device)
587+
.expect(&format!(
588+
"The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread",
589+
self.device,
590+
));
541591
server.borrow_mut().memory_usage()
542592
})
543593
}
544594

545595
fn enable_timestamps(&self) {
546596
LOCAL_DEVICE.with_borrow_mut(|runtime| {
547597
let server = runtime
548-
.entry(self.device.clone())
549-
.or_insert_with(|| Self::make_server(&self.device));
598+
.get(&self.device)
599+
.expect(&format!(
600+
"The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread",
601+
self.device,
602+
));
550603
server.borrow_mut().enable_timestamps()
551604
})
552605
}
553606

554607
fn disable_timestamps(&self) {
555608
LOCAL_DEVICE.with_borrow_mut(|runtime| {
556609
let server = runtime
557-
.entry(self.device.clone())
558-
.or_insert_with(|| Self::make_server(&self.device));
610+
.get(&self.device)
611+
.expect(&format!(
612+
"The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread",
613+
self.device,
614+
));
559615
server.borrow_mut().disable_timestamps()
560616
})
561617
}

0 commit comments

Comments
 (0)