Skip to content

Commit cfe243e

Browse files
Added back the send bound
1 parent feab17f commit cfe243e

File tree

3 files changed

+104
-13
lines changed

3 files changed

+104
-13
lines changed

crates/cubecl-runtime/src/channel/base.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use alloc::vec::Vec;
1010

1111
/// The ComputeChannel trait links the ComputeClient to the ComputeServer
1212
/// while ensuring thread-safety
13-
pub trait ComputeChannel<Server: ComputeServer>: Clone + core::fmt::Debug {
13+
pub trait ComputeChannel<Server: ComputeServer>: Send + Clone + core::fmt::Debug {
1414
/// Given a binding, returns owned resource as bytes
1515
fn read(&self, binding: Binding) -> impl Future<Output = Vec<u8>>;
1616

crates/cubecl-runtime/src/channel/cell.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,8 @@ where
108108
self.server.borrow_mut().disable_timestamps();
109109
}
110110
}
111+
112+
/// This is unsafe, since no concurrency is supported by the `RefCell` channel.
113+
/// However using this channel should only be done in single threaded environments such as `no-std`.
114+
unsafe impl<Server: ComputeServer> Send for RefCellComputeChannel<Server> {}
115+
unsafe impl<Server: ComputeServer> Sync for RefCellComputeChannel<Server> {}

crates/cubecl-wgpu/src/runtime.rs

Lines changed: 98 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -414,18 +414,56 @@ fn get_device_override() -> Option<WgpuDevice> {
414414
})
415415
}
416416

417+
#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))]
417418
#[derive(Debug, Clone)]
418419
pub struct ThreadLocalChannel {
419420
device: WgpuDevice,
420421
}
421422

423+
#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))]
424+
impl ThreadLocalChannel {
425+
fn make_server(device: &WgpuDevice) -> Rc<RefCell<Server>> {
426+
let setup = future::block_on(create_setup_for_device::<AutoGraphicsApi, WgslCompiler>(
427+
device,
428+
));
429+
430+
let limits = setup.device.limits();
431+
let mem_props = MemoryDeviceProperties {
432+
max_page_size: limits.max_storage_buffer_binding_size as u64,
433+
alignment: WgpuStorage::ALIGNMENT
434+
.max(limits.min_storage_buffer_offset_alignment as u64),
435+
};
436+
437+
let options = RuntimeOptions::default();
438+
let memory_management = {
439+
let mem_props = mem_props.clone();
440+
let config = options.memory_config;
441+
let storage = WgpuStorage::new(setup.device.clone());
442+
MemoryManagement::from_configuration(storage, mem_props, config)
443+
};
444+
let server = crate::compute::WgpuServer::new(
445+
memory_management,
446+
setup.device.clone(),
447+
setup.queue,
448+
options.tasks_max,
449+
);
450+
451+
Rc::new(RefCell::new(server))
452+
}
453+
}
454+
455+
#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))]
422456
impl ComputeChannel<Server> for ThreadLocalChannel {
423457
fn read(
424458
&self,
425459
binding: cubecl_core::server::Binding,
426460
) -> impl std::future::Future<Output = Vec<u8>> {
427461
LOCAL_RUNTIME.with(|runtime| {
428-
let server = runtime.borrow()[&self.device].clone();
462+
let server = runtime
463+
.borrow_mut()
464+
.entry(self.device.clone())
465+
.or_insert_with(|| Self::make_server(&self.device))
466+
.clone();
429467
async move { server.borrow_mut().read(binding).await }
430468
})
431469
}
@@ -435,18 +473,35 @@ impl ComputeChannel<Server> for ThreadLocalChannel {
435473
binding: cubecl_core::server::Binding,
436474
) -> cubecl_runtime::storage::BindingResource<Server> {
437475
LOCAL_RUNTIME.with(|runtime| {
438-
runtime.borrow()[&self.device]
476+
runtime
477+
.borrow_mut()
478+
.entry(self.device.clone())
479+
.or_insert_with(|| Self::make_server(&self.device))
439480
.borrow_mut()
440481
.get_resource(binding)
441482
})
442483
}
443484

444485
fn create(&self, data: &[u8]) -> cubecl_core::server::Handle {
445-
LOCAL_RUNTIME.with(|runtime| runtime.borrow()[&self.device].borrow_mut().create(data))
486+
LOCAL_RUNTIME.with(|runtime| {
487+
runtime
488+
.borrow_mut()
489+
.entry(self.device.clone())
490+
.or_insert_with(|| Self::make_server(&self.device))
491+
.borrow_mut()
492+
.create(data)
493+
})
446494
}
447495

448496
fn empty(&self, size: usize) -> cubecl_core::server::Handle {
449-
LOCAL_RUNTIME.with(|runtime| runtime.borrow()[&self.device].borrow_mut().empty(size))
497+
LOCAL_RUNTIME.with(|runtime| {
498+
runtime
499+
.borrow_mut()
500+
.entry(self.device.clone())
501+
.or_insert_with(|| Self::make_server(&self.device))
502+
.borrow_mut()
503+
.empty(size)
504+
})
450505
}
451506

452507
unsafe fn execute(
@@ -457,45 +512,76 @@ impl ComputeChannel<Server> for ThreadLocalChannel {
457512
mode: cubecl_core::ExecutionMode,
458513
) {
459514
LOCAL_RUNTIME.with(|runtime| {
460-
let runtime = runtime.borrow();
461-
let mut server = runtime[&self.device].borrow_mut();
515+
let mut runtime = runtime.borrow_mut();
516+
let mut server = runtime
517+
.entry(self.device.clone())
518+
.or_insert_with(|| Self::make_server(&self.device))
519+
.borrow_mut();
462520
unsafe { server.execute(kernel, count, bindings, mode) }
463521
})
464522
}
465523

466524
fn flush(&self) {
467-
LOCAL_RUNTIME.with(|runtime| runtime.borrow()[&self.device].borrow_mut().flush())
525+
LOCAL_RUNTIME.with(|runtime| {
526+
runtime
527+
.borrow_mut()
528+
.entry(self.device.clone())
529+
.or_insert_with(|| Self::make_server(&self.device))
530+
.borrow_mut()
531+
.flush()
532+
})
468533
}
469534

470535
fn sync(&self) -> impl std::future::Future<Output = ()> {
471536
LOCAL_RUNTIME.with(|runtime| {
472-
let server = runtime.borrow()[&self.device].clone();
537+
let server = runtime
538+
.borrow_mut()
539+
.entry(self.device.clone())
540+
.or_insert_with(|| Self::make_server(&self.device))
541+
.clone();
473542
async move { server.borrow_mut().sync().await }
474543
})
475544
}
476545

477546
fn sync_elapsed(&self) -> impl std::future::Future<Output = cubecl_runtime::TimestampsResult> {
478547
LOCAL_RUNTIME.with(|runtime| {
479-
let server = runtime.borrow()[&self.device].clone();
548+
let server = runtime
549+
.borrow_mut()
550+
.entry(self.device.clone())
551+
.or_insert_with(|| Self::make_server(&self.device))
552+
.clone();
480553
async move { server.borrow_mut().sync_elapsed().await }
481554
})
482555
}
483556

484557
fn memory_usage(&self) -> cubecl_runtime::memory_management::MemoryUsage {
485-
LOCAL_RUNTIME.with(|runtime| runtime.borrow()[&self.device].borrow_mut().memory_usage())
558+
LOCAL_RUNTIME.with(|runtime| {
559+
runtime
560+
.borrow_mut()
561+
.entry(self.device.clone())
562+
.or_insert_with(|| Self::make_server(&self.device))
563+
.borrow_mut()
564+
.memory_usage()
565+
})
486566
}
487567

488568
fn enable_timestamps(&self) {
489569
LOCAL_RUNTIME.with(|runtime| {
490-
runtime.borrow()[&self.device]
570+
runtime
571+
.borrow_mut()
572+
.entry(self.device.clone())
573+
.or_insert_with(|| Self::make_server(&self.device))
491574
.borrow_mut()
492575
.enable_timestamps()
493576
})
494577
}
495578

496579
fn disable_timestamps(&self) {
497580
LOCAL_RUNTIME.with(|runtime| {
498-
runtime.borrow()[&self.device]
581+
runtime
582+
.borrow_mut()
583+
.entry(self.device.clone())
584+
.or_insert_with(|| Self::make_server(&self.device))
499585
.borrow_mut()
500586
.disable_timestamps()
501587
})

0 commit comments

Comments
 (0)