diff --git a/Cargo.toml b/Cargo.toml
index 205aac918..8c2353b1c 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -37,6 +37,7 @@ async-channel = "2.3"
 dirs = "5.0.1"
 md5 = "0.7.0"
 sanitize-filename = "0.5"
+wasm-bindgen = "0.2"
 wasm-bindgen-futures = "0.4.45"
 weak-table = "0.3"
 web-time = "1.1.0"
diff --git a/crates/cubecl-runtime/src/channel/base.rs b/crates/cubecl-runtime/src/channel/base.rs
index 2ab8a9759..5f5710776 100644
--- a/crates/cubecl-runtime/src/channel/base.rs
+++ b/crates/cubecl-runtime/src/channel/base.rs
@@ -10,9 +10,9 @@ use alloc::vec::Vec;
 
 /// The ComputeChannel trait links the ComputeClient to the ComputeServer
 /// while ensuring thread-safety
-pub trait ComputeChannel<Server: ComputeServer>: Clone + core::fmt::Debug + Send + Sync {
+pub trait ComputeChannel<Server: ComputeServer>: Sync + Send + Clone + core::fmt::Debug {
     /// Given a binding, returns owned resource as bytes
-    fn read(&self, binding: Binding) -> impl Future<Output = Vec<u8>> + Send;
+    fn read(&self, binding: Binding) -> impl Future<Output = Vec<u8>>;
 
     /// Given a resource handle, return the storage resource.
     fn get_resource(&self, binding: Binding) -> BindingResource<Server>;
@@ -40,12 +40,12 @@ pub trait ComputeChannel<Server: ComputeServer>: Clone + core::fmt::Debug + Send
     fn flush(&self);
 
     /// Wait for the completion of every task in the server.
-    fn sync(&self) -> impl Future<Output = ()> + Send;
+    fn sync(&self) -> impl Future<Output = ()>;
 
     /// Wait for the completion of every task in the server.
     ///
     /// Returns the (approximate) total amount of GPU work done since the last sync.
-    fn sync_elapsed(&self) -> impl Future<Output = TimestampsResult> + Send;
+    fn sync_elapsed(&self) -> impl Future<Output = TimestampsResult>;
 
     /// Get the current memory usage of the server.
     fn memory_usage(&self) -> crate::memory_management::MemoryUsage;
diff --git a/crates/cubecl-runtime/src/channel/cell.rs b/crates/cubecl-runtime/src/channel/cell.rs
index 83e99a282..98a8bb83a 100644
--- a/crates/cubecl-runtime/src/channel/cell.rs
+++ b/crates/cubecl-runtime/src/channel/cell.rs
@@ -42,7 +42,7 @@ where
 
 impl<Server> ComputeChannel<Server> for RefCellComputeChannel<Server>
 where
-    Server: ComputeServer + Send,
+    Server: ComputeServer,
 {
     async fn read(&self, binding: Binding) -> Vec<u8> {
         let future = {
diff --git a/crates/cubecl-runtime/src/channel/mpsc.rs b/crates/cubecl-runtime/src/channel/mpsc.rs
index 1b1a9e546..66a3cda6a 100644
--- a/crates/cubecl-runtime/src/channel/mpsc.rs
+++ b/crates/cubecl-runtime/src/channel/mpsc.rs
@@ -6,7 +6,7 @@ use super::ComputeChannel;
 use crate::{
     memory_management::MemoryUsage,
     server::{Binding, ComputeServer, CubeCount, Handle},
-    storage::BindingResource,
+    storage::{BindingResource, ComputeStorage},
     ExecutionMode,
 };
 
@@ -50,7 +50,8 @@ where
 
 impl<Server> MpscComputeChannel<Server>
 where
-    Server: ComputeServer + 'static,
+    Server: ComputeServer + Send + 'static,
+    <Server::Storage as ComputeStorage>::Resource: Send,
 {
     /// Create a new mpsc compute channel.
     pub fn new(mut server: Server) -> Self {
@@ -123,6 +124,7 @@ impl<Server: ComputeServer> Clone for MpscComputeChannel<Server> {
 impl<Server> ComputeChannel<Server> for MpscComputeChannel<Server>
 where
     Server: ComputeServer + 'static,
+    <Server::Storage as ComputeStorage>::Resource: Send,
 {
     async fn read(&self, binding: Binding) -> Vec<u8> {
         let sender = self.state.sender.clone();
diff --git a/crates/cubecl-runtime/src/channel/mutex.rs b/crates/cubecl-runtime/src/channel/mutex.rs
index 8d4dff53f..adbc2a07d 100644
--- a/crates/cubecl-runtime/src/channel/mutex.rs
+++ b/crates/cubecl-runtime/src/channel/mutex.rs
@@ -35,7 +35,7 @@ where
 
 impl<Server> ComputeChannel<Server> for MutexComputeChannel<Server>
 where
-    Server: ComputeServer,
+    Server: ComputeServer + Send,
 {
     async fn read(&self, handle: Binding) -> Vec<u8> {
         // Nb: The order here is really important - the mutex guard has to be dropped before
diff --git a/crates/cubecl-runtime/src/memory_management/base.rs b/crates/cubecl-runtime/src/memory_management/base.rs
index 798e670be..feaa986b4 100644
--- a/crates/cubecl-runtime/src/memory_management/base.rs
+++ b/crates/cubecl-runtime/src/memory_management/base.rs
@@ -4,6 +4,7 @@ use alloc::{format, string::String};
 /// Amount of memory in use by this allocator
 /// and statistics on how much memory is reserved and
 /// wasted in total.
+#[derive(Debug)]
 pub struct MemoryUsage {
     /// The number of allocations currently active.
     pub number_allocs: u64,
diff --git a/crates/cubecl-runtime/src/server.rs b/crates/cubecl-runtime/src/server.rs
index 45a06a8c9..027cd69fd 100644
--- a/crates/cubecl-runtime/src/server.rs
+++ b/crates/cubecl-runtime/src/server.rs
@@ -14,7 +14,7 @@ use cubecl_common::benchmark::TimestampsResult;
 ///
 /// Everything in the server is mutable, therefore it should be solely accessed through the
 /// [compute channel](crate::channel::ComputeChannel) for thread safety.
-pub trait ComputeServer: Send + core::fmt::Debug
+pub trait ComputeServer: core::fmt::Debug
 where
     Self: Sized,
 {
@@ -26,7 +26,7 @@ where
     type Feature: Ord + Copy + Debug + Send + Sync;
 
     /// Given a handle, returns the owned resource as bytes.
-    fn read(&mut self, binding: Binding) -> impl Future<Output = Vec<u8>> + Send + 'static;
+    fn read(&mut self, binding: Binding) -> impl Future<Output = Vec<u8>> + 'static;
 
     /// Given a resource handle, returns the storage resource.
     fn get_resource(&mut self, binding: Binding) -> BindingResource<Self>;
@@ -57,12 +57,12 @@ where
     fn flush(&mut self);
 
     /// Wait for the completion of every task in the server.
-    fn sync(&mut self) -> impl Future<Output = ()> + Send + 'static;
+    fn sync(&mut self) -> impl Future<Output = ()> + 'static;
 
     /// Wait for the completion of every task in the server.
     ///
     /// Returns the (approximate) total amount of GPU work done since the last sync.
-    fn sync_elapsed(&mut self) -> impl Future<Output = TimestampsResult> + Send + 'static;
+    fn sync_elapsed(&mut self) -> impl Future<Output = TimestampsResult> + 'static;
 
     /// The current memory usage of the server.
     fn memory_usage(&self) -> MemoryUsage;
diff --git a/crates/cubecl-runtime/src/storage/base.rs b/crates/cubecl-runtime/src/storage/base.rs
index 0f50eb8da..9737404c2 100644
--- a/crates/cubecl-runtime/src/storage/base.rs
+++ b/crates/cubecl-runtime/src/storage/base.rs
@@ -63,10 +63,10 @@ impl StorageHandle {
 }
 
 /// Storage types are responsible for allocating and deallocating memory.
-pub trait ComputeStorage: Send {
+pub trait ComputeStorage {
     /// The resource associated type determines the way data is implemented and how
     /// it can be accessed by kernels.
-    type Resource: Send;
+    type Resource;
 
     /// The alignment memory is allocated with in this storage.
     const ALIGNMENT: u64;
diff --git a/crates/cubecl-runtime/src/tune/tune_cache.rs b/crates/cubecl-runtime/src/tune/tune_cache.rs
index a2c5e0d5f..223d46b63 100644
--- a/crates/cubecl-runtime/src/tune/tune_cache.rs
+++ b/crates/cubecl-runtime/src/tune/tune_cache.rs
@@ -117,11 +117,12 @@ impl<K: AutotuneKey> TuneCache<K> {
             } => {
                 if cfg!(autotune_persistent_cache) {
                     match checksum_matches {
-                        None => TuneCacheResult::Unchecked, // Don't know yet.
-                        Some(false) => TuneCacheResult::Miss, // Can't use this.
+                        #[cfg(autotune_persistent_cache)]
+                        None => TuneCacheResult::Unchecked, // Don't know yet.
                         Some(true) => TuneCacheResult::Hit {
                             fastest_index: *fastest_index,
                         },
+                        _ => TuneCacheResult::Miss, // Some(false) or None so we can't use this.
                     }
                 } else {
                     let _ = checksum_matches;
diff --git a/crates/cubecl-runtime/src/tune/tuner.rs b/crates/cubecl-runtime/src/tune/tuner.rs
index b3aacc331..4854bc9d7 100644
--- a/crates/cubecl-runtime/src/tune/tuner.rs
+++ b/crates/cubecl-runtime/src/tune/tuner.rs
@@ -224,7 +224,7 @@ impl<K: AutotuneKey> Tuner<K> {
     }
 }
 
-fn spawn_benchmark_task(future: impl Future<Output = ()> + Send + 'static) {
+fn spawn_benchmark_task(future: impl Future<Output = ()> + 'static) {
     // On wasm, spawn the tuning as a detached task.
     #[cfg(target_family = "wasm")]
     wasm_bindgen_futures::spawn_local(future);
diff --git a/crates/cubecl-wgpu/Cargo.toml b/crates/cubecl-wgpu/Cargo.toml
index 1f03dda59..1169a727e 100644
--- a/crates/cubecl-wgpu/Cargo.toml
+++ b/crates/cubecl-wgpu/Cargo.toml
@@ -39,6 +39,7 @@ wgpu = { version = "22.0.0", features = ["fragile-send-sync-non-atomic-wasm"] }
 
 async-channel = { workspace = true }
 derive-new = { workspace = true }
+futures-lite = { workspace = true }
 hashbrown = { workspace = true }
 log = { workspace = true }
 web-time = { workspace = true }
diff --git a/crates/cubecl-wgpu/src/compiler/base.rs b/crates/cubecl-wgpu/src/compiler/base.rs
index e040190de..05ca30eb9 100644
--- a/crates/cubecl-wgpu/src/compiler/base.rs
+++ b/crates/cubecl-wgpu/src/compiler/base.rs
@@ -1,12 +1,10 @@
-use std::sync::Arc;
-
 use cubecl_core::{
     prelude::CompiledKernel, server::ComputeServer, Compiler, ExecutionMode, Feature,
 };
 use cubecl_runtime::DeviceProperties;
 use wgpu::{Adapter, ComputePipeline, Device, Queue};
 
-use crate::WgpuServer;
+use crate::{Pdrc, WgpuServer};
 
 pub trait WgpuCompiler: Compiler {
     fn compile(
@@ -19,7 +17,7 @@ pub trait WgpuCompiler: Compiler {
         server: &mut WgpuServer<Self>,
         kernel: CompiledKernel<Self>,
         mode: ExecutionMode,
-    ) -> Arc<ComputePipeline>;
+    ) -> Pdrc<ComputePipeline>;
 
     #[allow(async_fn_in_trait)]
     async fn request_device(adapter: &Adapter) -> (Device, Queue);
diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs
index 5060466e9..a00395db9 100644
--- a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs
+++ b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs
@@ -1,10 +1,10 @@
-use std::{borrow::Cow, sync::Arc};
+use std::borrow::Cow;
 
 use super::{shader::ComputeShader, ConstantArray, Item, SharedMemory};
 use super::{LocalArray, Subgroup};
 use crate::{
     compiler::{base::WgpuCompiler, wgsl},
-    WgpuServer,
+    Pdrc, WgpuServer,
 };
 use cubecl_core::{
     ir::{self as cube, HybridAllocator, UIntKind},
@@ -73,7 +73,7 @@ impl WgpuCompiler for WgslCompiler {
         server: &mut WgpuServer<Self>,
         kernel: CompiledKernel<Self>,
         mode: ExecutionMode,
-    ) -> Arc<ComputePipeline> {
+    ) -> Pdrc<ComputePipeline> {
         let source = &kernel.source;
         let repr = kernel.repr.unwrap();
         let module = match mode {
@@ -118,7 +118,7 @@ impl WgpuCompiler for WgslCompiler {
             push_constant_ranges: &[],
         });
 
-        Arc::new(
+        Pdrc::new(
             server
                 .device
                 .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
diff --git a/crates/cubecl-wgpu/src/compute/poll.rs b/crates/cubecl-wgpu/src/compute/poll.rs
index 1bc253c29..cf8f5d19a 100644
--- a/crates/cubecl-wgpu/src/compute/poll.rs
+++ b/crates/cubecl-wgpu/src/compute/poll.rs
@@ -58,10 +58,12 @@ mod _impl {
 // On Wasm, the browser handles the polling loop, so we don't need anything.
 #[cfg(target_family = "wasm")]
 mod _impl {
+    use crate::Pdrc;
+
     #[derive(Debug)]
     pub struct WgpuPoll {}
     impl WgpuPoll {
-        pub fn new(_device: alloc::sync::Arc<wgpu::Device>) -> Self {
+        pub fn new(_device: Pdrc<wgpu::Device>) -> Self {
             Self {}
         }
         pub fn start_polling(&self) -> alloc::sync::Arc<()> {
diff --git a/crates/cubecl-wgpu/src/compute/server.rs b/crates/cubecl-wgpu/src/compute/server.rs
index 0efe93291..6a3dbec4e 100644
--- a/crates/cubecl-wgpu/src/compute/server.rs
+++ b/crates/cubecl-wgpu/src/compute/server.rs
@@ -2,8 +2,7 @@ use std::{future::Future, marker::PhantomData, num::NonZero, pin::Pin, time::Dur
 
 use super::poll::WgpuPoll;
 use super::WgpuStorage;
-use crate::compiler::base::WgpuCompiler;
-use alloc::sync::Arc;
+use crate::{compiler::base::WgpuCompiler, Pdrc};
 use cubecl_common::future;
 use cubecl_core::{compute::DebugInformation, prelude::*, server::Handle, Feature, KernelId};
 use cubecl_runtime::{
@@ -15,18 +14,23 @@ use cubecl_runtime::{
 };
 use hashbrown::HashMap;
 use web_time::Instant;
-use wgpu::{CommandEncoder, ComputePass, ComputePipeline, QuerySet, QuerySetDescriptor, QueryType};
+use wgpu::{
+    CommandEncoder, ComputePass, ComputePipeline, QuerySet, QuerySetDescriptor, QueryType,
+    WasmNotSend,
+};
 
 /// Wgpu compute server.
 #[derive(Debug)]
 pub struct WgpuServer<C: WgpuCompiler> {
     memory_management: MemoryManagement<WgpuStorage>,
-    pub(crate) device: Arc<wgpu::Device>,
-    queue: Arc<wgpu::Queue>,
+    pub(crate) device: Pdrc<wgpu::Device>,
+    queue: Pdrc<wgpu::Queue>,
+    #[allow(unused)]
+    pub(crate) adapter: Pdrc<wgpu::Adapter>,
     encoder: CommandEncoder,
     current_pass: Option<ComputePass<'static>>,
     tasks_count: usize,
-    pipelines: HashMap<KernelId, Arc<ComputePipeline>>,
+    pipelines: HashMap<KernelId, Pdrc<ComputePipeline>>,
     tasks_max: usize,
     logger: DebugLogger,
     poll: WgpuPoll,
@@ -36,6 +40,10 @@ pub struct WgpuServer<C: WgpuCompiler> {
     _compiler: PhantomData<C>,
 }
 
+trait FutureWasmNotSend<O>: Future<Output = O> + WasmNotSend {}
+
+impl<O, T: Future<Output = O> + WasmNotSend> FutureWasmNotSend<O> for T {}
+
 #[derive(Debug)]
 enum KernelTimestamps {
     Native { query_set: QuerySet, init: bool },
@@ -82,8 +90,9 @@ impl<C: WgpuCompiler> WgpuServer<C> {
     /// Create a new server.
     pub fn new(
         memory_management: MemoryManagement<WgpuStorage>,
-        device: Arc<wgpu::Device>,
-        queue: Arc<wgpu::Queue>,
+        device: Pdrc<wgpu::Device>,
+        queue: Pdrc<wgpu::Queue>,
+        adapter: Pdrc<wgpu::Adapter>,
         tasks_max: usize,
     ) -> Self {
         let logger = DebugLogger::default();
@@ -93,18 +102,22 @@ impl<C: WgpuCompiler> WgpuServer<C> {
             timestamps.enable(&device);
         }
 
+        let encoder = create_encoder(&device);
+        let poll = WgpuPoll::new(device.clone());
+
         Self {
             memory_management,
-            device: device.clone(),
-            queue: queue.clone(),
-            encoder: create_encoder(&device),
+            device,
+            queue,
+            adapter,
+            encoder,
             current_pass: None,
             tasks_count: 0,
             storage_locked: MemoryLock::default(),
             pipelines: HashMap::new(),
             tasks_max,
             logger,
-            poll: WgpuPoll::new(device.clone()),
+            poll,
             duration_profiled: None,
             timestamps,
             _compiler: PhantomData,
@@ -113,9 +126,9 @@ impl<C: WgpuCompiler> WgpuServer<C> {
     fn pipeline(
         &mut self,
-        kernel: <Self as ComputeServer>::Kernel,
+        kernel: <WgpuServer<C> as ComputeServer>::Kernel,
         mode: ExecutionMode,
-    ) -> Arc<ComputePipeline> {
+    ) -> Pdrc<ComputePipeline> {
         let mut kernel_id = kernel.id();
         kernel_id.mode(mode);
@@ -192,7 +205,7 @@ impl<C: WgpuCompiler> WgpuServer<C> {
         }
     }
 
-    fn sync_queue(&mut self) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>> {
+    fn sync_queue(&mut self) -> Pin<Box<dyn FutureWasmNotSend<()> + 'static>> {
         self.flush();
 
         #[cfg(target_family = "wasm")]
@@ -220,7 +233,7 @@ impl<C: WgpuCompiler> WgpuServer<C> {
 
     fn sync_queue_elapsed(
         &mut self,
-    ) -> Pin<Box<dyn Future<Output = TimestampsResult> + Send + 'static>> {
+    ) -> Pin<Box<dyn FutureWasmNotSend<TimestampsResult> + 'static>> {
         self.clear_compute_pass();
 
         enum TimestampMethod {
@@ -301,14 +314,15 @@ impl<C: WgpuCompiler> ComputeServer for WgpuServer<C> {
     type Storage = WgpuStorage;
     type Feature = Feature;
 
-    fn read(&mut self, binding: server::Binding) -> impl Future<Output = Vec<u8>> + Send + 'static {
+    fn read(&mut self, binding: server::Binding) -> impl Future<Output = Vec<u8>> + 'static {
         let rb = self.get_resource(binding);
         let resource = rb.resource();
         self.clear_compute_pass();
+
         self.read_wgpu_buffer(&resource.buffer, resource.offset(), resource.size())
     }
 
-    fn get_resource(&mut self, binding: server::Binding) -> BindingResource<Self> {
+    fn get_resource(&mut self, binding: server::Binding) -> BindingResource<WgpuServer<C>> {
         // Keep track of any buffer that might be used in the wgpu queue, as we cannot copy into them
         // after they have any outstanding compute work. Calling get_resource repeatedly
         // will add duplicates to this, but that is ok.
@@ -376,7 +390,7 @@ impl<C: WgpuCompiler> ComputeServer for WgpuServer<C> {
 
     unsafe fn execute(
         &mut self,
-        kernel: Self::Kernel,
+        kernel: <WgpuServer<C> as ComputeServer>::Kernel,
         count: CubeCount,
         bindings: Vec<server::Binding>,
         mode: ExecutionMode,
diff --git a/crates/cubecl-wgpu/src/compute/storage.rs b/crates/cubecl-wgpu/src/compute/storage.rs
index 307d047f0..b4359893c 100644
--- a/crates/cubecl-wgpu/src/compute/storage.rs
+++ b/crates/cubecl-wgpu/src/compute/storage.rs
@@ -1,12 +1,14 @@
 use cubecl_runtime::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
 use hashbrown::HashMap;
-use std::{num::NonZeroU64, sync::Arc};
+use std::num::NonZeroU64;
+
+use crate::Pdrc;
 
 /// Buffer storage for wgpu.
 pub struct WgpuStorage {
-    memory: HashMap<StorageId, Arc<wgpu::Buffer>>,
+    memory: HashMap<StorageId, Pdrc<wgpu::Buffer>>,
     deallocations: Vec<StorageId>,
-    device: Arc<wgpu::Device>,
+    device: Pdrc<wgpu::Device>,
 }
 
 impl core::fmt::Debug for WgpuStorage {
@@ -15,44 +17,10 @@ impl core::fmt::Debug for WgpuStorage {
     }
 }
 
-/// The memory resource that can be allocated for wgpu.
-#[derive(new)]
-pub struct WgpuResource {
-    /// The wgpu buffer.
-    pub buffer: Arc<wgpu::Buffer>,
-
-    offset: u64,
-    size: u64,
-}
-
-impl WgpuResource {
-    /// Return the binding view of the buffer.
-    pub fn as_wgpu_bind_resource(&self) -> wgpu::BindingResource {
-        let binding = wgpu::BufferBinding {
-            buffer: &self.buffer,
-            offset: self.offset,
-            size: Some(
-                NonZeroU64::new(self.size).expect("0 size resources are not yet supported."),
-            ),
-        };
-        wgpu::BindingResource::Buffer(binding)
-    }
-
-    /// Return the buffer size.
-    pub fn size(&self) -> u64 {
-        self.size
-    }
-
-    /// Return the buffer offset.
-    pub fn offset(&self) -> u64 {
-        self.offset
-    }
-}
-
 /// Keeps actual wgpu buffer references in a hashmap with ids as key.
 impl WgpuStorage {
     /// Create a new storage on the given [device](wgpu::Device).
-    pub fn new(device: Arc<wgpu::Device>) -> Self {
+    pub fn new(device: Pdrc<wgpu::Device>) -> Self {
         Self {
             memory: HashMap::new(),
             deallocations: Vec::new(),
@@ -86,7 +54,7 @@ impl ComputeStorage for WgpuStorage {
 
     fn alloc(&mut self, size: u64) -> StorageHandle {
         let id = StorageId::new();
-        let buffer = Arc::new(self.device.create_buffer(&wgpu::BufferDescriptor {
+        let buffer = Pdrc::new(self.device.create_buffer(&wgpu::BufferDescriptor {
             label: None,
             size,
             usage: wgpu::BufferUsages::COPY_DST
@@ -104,3 +72,36 @@ impl ComputeStorage for WgpuStorage {
         self.deallocations.push(id);
     }
 }
+
+/// The memory resource that can be allocated for wgpu.
+#[derive(new)]
+pub struct WgpuResource {
+    /// The wgpu buffer.
+    pub buffer: Pdrc<wgpu::Buffer>,
+    offset: u64,
+    size: u64,
+}
+
+impl WgpuResource {
+    /// Return the binding view of the buffer.
+    pub fn as_wgpu_bind_resource(&self) -> wgpu::BindingResource {
+        let binding = wgpu::BufferBinding {
+            buffer: &self.buffer,
+            offset: self.offset,
+            size: Some(
+                NonZeroU64::new(self.size).expect("0 size resources are not yet supported."),
+            ),
+        };
+        wgpu::BindingResource::Buffer(binding)
+    }
+
+    /// Return the buffer size.
+    pub fn size(&self) -> u64 {
+        self.size
+    }
+
+    /// Return the buffer offset.
+    pub fn offset(&self) -> u64 {
+        self.offset
+    }
+}
diff --git a/crates/cubecl-wgpu/src/lib.rs b/crates/cubecl-wgpu/src/lib.rs
index 4b9c444df..e1dcf676a 100644
--- a/crates/cubecl-wgpu/src/lib.rs
+++ b/crates/cubecl-wgpu/src/lib.rs
@@ -20,6 +20,18 @@ pub use runtime::*;
 #[cfg(feature = "spirv")]
 pub use compiler::spirv;
 
+/// Platform-dependent reference counting. Uses [`alloc::sync::Arc`] on all platforms
+/// except `wasm32` with the `atomics` target feature enabled, where [`alloc::rc::Rc`]
+/// is used instead.
+#[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))]
+type Pdrc<T> = alloc::sync::Arc<T>;
+
+/// Platform-dependent reference counting. Uses [`alloc::sync::Arc`] on all platforms
+/// except `wasm32` with the `atomics` target feature enabled, where [`alloc::rc::Rc`]
+/// is used instead.
+#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))]
+type Pdrc<T> = alloc::rc::Rc<T>;
+
 #[cfg(test)]
 mod tests {
     pub type TestRuntime = crate::WgpuRuntime<crate::WgslCompiler>;
diff --git a/crates/cubecl-wgpu/src/runtime.rs b/crates/cubecl-wgpu/src/runtime.rs
index 224a9d108..584660588 100644
--- a/crates/cubecl-wgpu/src/runtime.rs
+++ b/crates/cubecl-wgpu/src/runtime.rs
@@ -1,19 +1,24 @@
 use std::marker::PhantomData;
+#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))]
+use std::{cell::RefCell, rc::Rc};
 
 use crate::{
     compiler::{base::WgpuCompiler, wgsl::WgslCompiler},
     compute::{WgpuServer, WgpuStorage},
-    AutoGraphicsApi, GraphicsApi, WgpuDevice,
+    AutoGraphicsApi, GraphicsApi, Pdrc, WgpuDevice,
 };
-use alloc::sync::Arc;
-use cubecl_common::future;
+#[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))]
+use cubecl_core::future;
+#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))]
+use cubecl_core::{channel::ComputeChannel, server::ComputeServer};
 use cubecl_core::{Feature, Runtime};
-pub use cubecl_runtime::memory_management::MemoryConfiguration;
-use cubecl_runtime::DeviceProperties;
-use cubecl_runtime::{channel::MutexComputeChannel, client::ComputeClient, ComputeRuntime};
+#[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))]
+use cubecl_runtime::channel::MutexComputeChannel;
 use cubecl_runtime::{
-    memory_management::{MemoryDeviceProperties, MemoryManagement},
+    client::ComputeClient,
+    memory_management::{MemoryConfiguration, MemoryDeviceProperties, MemoryManagement},
     storage::ComputeStorage,
+    ComputeRuntime, DeviceProperties,
 };
 use wgpu::RequestAdapterOptions;
@@ -25,23 +30,78 @@ pub struct WgpuRuntime<C: WgpuCompiler>(PhantomData<C>);
 
 type Server = WgpuServer<WgslCompiler>;
 
+#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))]
+thread_local! {
+    static LOCAL_DEVICE: RefCell<hashbrown::HashMap<WgpuDevice, Rc<RefCell<Server>>>> = RefCell::new(hashbrown::HashMap::default());
+}
+
+#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))]
+static RUNTIME: ComputeRuntime<WgpuDevice, Server, ThreadLocalChannel> = ComputeRuntime::new();
+
 /// The compute instance is shared across all [wgpu runtimes](WgpuRuntime).
+#[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))]
 static RUNTIME: ComputeRuntime<WgpuDevice, Server, MutexComputeChannel<Server>> = ComputeRuntime::new();
 
 impl Runtime for WgpuRuntime<WgslCompiler> {
     type Compiler = WgslCompiler;
-    type Server = WgpuServer<WgslCompiler>;
+    type Server = Server;
 
-    type Channel = MutexComputeChannel<WgpuServer<WgslCompiler>>;
+    #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))]
+    type Channel = MutexComputeChannel<Server>;
+    #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))]
+    type Channel = ThreadLocalChannel;
     type Device = WgpuDevice;
 
     fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel> {
         RUNTIME.client(device, move || {
-            let setup = future::block_on(create_setup_for_device::<AutoGraphicsApi, WgslCompiler>(
-                device,
-            ));
-            create_client_on_setup(setup, RuntimeOptions::default())
+            #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))]
+            {
+                let setup = future::block_on(create_setup_for_device::<
+                    AutoGraphicsApi,
+                    WgslCompiler,
+                >(device));
+                create_client_on_setup(setup, RuntimeOptions::default())
+            }
+
+            #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))]
+            {
+                let server = LOCAL_DEVICE.with_borrow_mut(|runtime| {
+                    runtime
+                        .get(device)
+                        .expect(&format!("The wgpu server for {device:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread"))
+                        .clone()
+                });
+                let server = server.borrow();
+
+                let limits = server.device.limits();
+                let mem_props = MemoryDeviceProperties {
+                    max_page_size: limits.max_storage_buffer_binding_size as u64,
+                    alignment: WgpuStorage::ALIGNMENT
+                        .max(limits.min_storage_buffer_offset_alignment as u64),
+                };
+
+                let features = server.device.features();
+                let mut device_props = DeviceProperties::new(&[], mem_props);
+
+                if features.contains(wgpu::Features::SUBGROUP)
+                    && server.adapter.get_info().device_type != wgpu::DeviceType::Cpu
+                {
+                    device_props.register_feature(Feature::Subcube);
+                }
+                <WgslCompiler as WgpuCompiler>::register_features(
+                    &server.adapter,
+                    &server.device,
+                    &mut device_props,
+                );
+
+                ComputeClient::new(
+                    ThreadLocalChannel {
+                        device: device.clone(),
+                    },
+                    device_props,
+                )
+            }
         })
     }
@@ -89,17 +149,47 @@ impl Default for RuntimeOptions {
 #[derive(Clone, Debug)]
 pub struct WgpuSetup {
     /// The underlying wgpu instance.
-    pub instance: Arc<wgpu::Instance>,
+    pub instance: Pdrc<wgpu::Instance>,
     /// The selected 'adapter'. This corresponds to a physical device.
-    pub adapter: Arc<wgpu::Adapter>,
+    pub adapter: Pdrc<wgpu::Adapter>,
     /// The wgpu device Burn will use. Nb: There can only be one device per adapter.
-    pub device: Arc<wgpu::Device>,
+    pub device: Pdrc<wgpu::Device>,
     /// The queue Burn commands will be submitted to.
-    pub queue: Arc<wgpu::Queue>,
+    pub queue: Pdrc<wgpu::Queue>,
+}
+
+#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))]
+pub async fn init_thread_server(device: WgpuDevice, options: RuntimeOptions) {
+    if !LOCAL_DEVICE.with_borrow(|map| map.contains_key(&device)) {
+        let setup = create_setup_for_device::<AutoGraphicsApi, WgslCompiler>(&device).await;
+
+        let limits = setup.device.limits();
+        let mem_props = MemoryDeviceProperties {
+            max_page_size: limits.max_storage_buffer_binding_size as u64,
+            alignment: WgpuStorage::ALIGNMENT
+                .max(limits.min_storage_buffer_offset_alignment as u64),
+        };
+        let memory_management = {
+            let mem_props = mem_props.clone();
+            let config = options.memory_config;
+            let storage = WgpuStorage::new(setup.device.clone());
+            MemoryManagement::from_configuration(storage, mem_props, config)
+        };
+        let server = crate::compute::WgpuServer::new(
+            memory_management,
+            setup.device,
+            setup.queue,
+            setup.adapter,
+            options.tasks_max,
+        );
+
+        LOCAL_DEVICE.with_borrow_mut(|map| map.insert(device, Rc::new(RefCell::new(server))));
+    }
 }
 
 /// Create a [`WgpuDevice`] on an existing [`WgpuSetup`].
 /// Useful when you want to share a device between CubeCL and other wgpu-dependent libraries.
+#[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))]
 pub fn init_device(setup: WgpuSetup, options: RuntimeOptions) -> WgpuDevice {
     let device_id = WgpuDevice::Existing(setup.device.as_ref().global_id());
     let client = create_client_on_setup(setup, options);
@@ -109,6 +199,7 @@ pub fn init_device(setup: WgpuSetup, options: RuntimeOptions) -> WgpuDevice {
 
 /// Like [`init_setup_async`], but synchronous.
 /// On wasm, it is necessary to use [`init_setup_async`] instead.
+#[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))]
 pub fn init_setup<G: GraphicsApi>(device: &WgpuDevice, options: RuntimeOptions) -> WgpuSetup {
     cfg_if::cfg_if! {
         if #[cfg(target_family = "wasm")] {
@@ -123,6 +214,7 @@
 /// Initialize a client on the given device with the given options.
 /// This function is useful to configure the runtime options
 /// or to pick a different graphics API.
+#[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))]
 pub async fn init_setup_async<G: GraphicsApi>(
     device: &WgpuDevice,
     options: RuntimeOptions,
@@ -134,6 +226,7 @@ pub async fn init_setup_async<G: GraphicsApi>(
     return_setup
 }
 
+#[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))]
 pub(crate) fn create_client_on_setup<C: WgpuCompiler>(
     setup: WgpuSetup,
     options: RuntimeOptions,
 ) -> ComputeClient<WgpuServer<C>, MutexComputeChannel<WgpuServer<C>>> {
@@ -155,6 +248,7 @@ pub(crate) fn create_client_on_setup<C: WgpuCompiler>(
         memory_management,
         setup.device.clone(),
         setup.queue,
+        setup.adapter.clone(),
         options.tasks_max,
     );
     let channel = MutexComputeChannel::new(server);
@@ -185,10 +279,10 @@ pub(crate) async fn create_setup_for_device<G: GraphicsApi, C: WgpuCompiler>(
     );
 
     WgpuSetup {
-        instance: Arc::new(instance),
-        adapter: Arc::new(adapter),
-        device: Arc::new(device),
-        queue: Arc::new(queue),
+        instance: Pdrc::new(instance),
+        adapter: Pdrc::new(adapter),
+        device: Pdrc::new(device),
+        queue: Pdrc::new(queue),
     }
 }
@@ -338,3 +432,190 @@ fn get_device_override() -> Option<WgpuDevice> {
         override_device
     })
 }
+
+#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))]
+#[derive(Debug, Clone)]
+pub struct ThreadLocalChannel {
+    device: WgpuDevice,
+}
+
+#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))]
+impl ThreadLocalChannel {
+    fn make_server(device: &WgpuDevice) -> Rc<RefCell<Server>> {
+        let setup = futures_lite::future::block_on(create_setup_for_device::<
+            AutoGraphicsApi,
+            WgslCompiler,
+        >(device));
+
+        let limits = setup.device.limits();
+        let mem_props = MemoryDeviceProperties {
+            max_page_size: limits.max_storage_buffer_binding_size as u64,
+            alignment: WgpuStorage::ALIGNMENT
+                .max(limits.min_storage_buffer_offset_alignment as u64),
+        };
+
+        let options = RuntimeOptions::default();
+        let memory_management = {
+            let mem_props = mem_props.clone();
+            let config = options.memory_config;
+            let storage = WgpuStorage::new(setup.device.clone());
+            MemoryManagement::from_configuration(storage, mem_props, config)
+        };
+        let server = crate::compute::WgpuServer::new(
+            memory_management,
+            setup.device,
+            setup.queue,
+            setup.adapter,
+            options.tasks_max,
+        );
+
+        Rc::new(RefCell::new(server))
+    }
+}
+
+#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))]
+impl ComputeChannel<Server> for ThreadLocalChannel {
+    fn read(
+        &self,
+        binding: cubecl_core::server::Binding,
+    ) -> impl std::future::Future<Output = Vec<u8>> {
+        LOCAL_DEVICE.with_borrow_mut(|runtime| {
+            let server = runtime
+                .entry(self.device.clone())
+                .or_insert_with(|| Self::make_server(&self.device))
+                .clone();
+            async move { server.borrow_mut().read(binding).await }
+        })
+    }
+
+    fn get_resource(
+        &self,
+        binding: cubecl_core::server::Binding,
+    ) -> cubecl_runtime::storage::BindingResource<Server> {
+        LOCAL_DEVICE.with_borrow_mut(|runtime| {
+            let server = runtime
+                .get(&self.device)
+                .expect(&format!(
+                    "The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread",
+                    self.device,
+                ));
+            server.borrow_mut().get_resource(binding)
+        })
+    }
+
+    fn create(&self, data: &[u8]) -> cubecl_core::server::Handle {
+        LOCAL_DEVICE.with_borrow_mut(|runtime| {
+            let server = runtime
+                .get(&self.device)
+                .expect(&format!(
+                    "The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread",
+                    self.device,
+                ));
+            server.borrow_mut().create(data)
+        })
+    }
+
+    fn empty(&self, size: usize) -> cubecl_core::server::Handle {
+        LOCAL_DEVICE.with_borrow_mut(|runtime| {
+            let server = runtime
+                .get(&self.device)
+                .expect(&format!(
+                    "The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread",
+                    self.device,
+                ));
+            server.borrow_mut().empty(size)
+        })
+    }
+
+    unsafe fn execute(
+        &self,
+        kernel: <Server as ComputeServer>::Kernel,
+        count: cubecl_core::CubeCount,
+        bindings: Vec<cubecl_core::server::Binding>,
+        mode: cubecl_core::ExecutionMode,
+    ) {
+        LOCAL_DEVICE.with_borrow_mut(|runtime| {
+            let server = runtime
+                .get(&self.device)
+                .expect(&format!(
+                    "The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread",
+                    self.device,
+                ));
+            unsafe { server.borrow_mut().execute(kernel, count, bindings, mode) }
+        })
+    }
+
+    fn flush(&self) {
+        LOCAL_DEVICE.with_borrow_mut(|runtime| {
+            let server = runtime
+                .get(&self.device)
+                .expect(&format!(
+                    "The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread",
+                    self.device,
+                ));
+            server.borrow_mut().flush()
+        })
+    }
+
+    fn sync(&self) -> impl std::future::Future<Output = ()> {
+        LOCAL_DEVICE.with_borrow_mut(|runtime| {
+            let server = runtime
+                .get(&self.device)
+                .expect(&format!(
+                    "The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread",
+                    self.device,
+                ))
+                .clone();
+            async move { server.borrow_mut().sync().await }
+        })
+    }
+
+    fn sync_elapsed(&self) -> impl std::future::Future<Output = cubecl_common::benchmark::TimestampsResult> {
+        LOCAL_DEVICE.with_borrow_mut(|runtime| {
+            let server = runtime
+                .get(&self.device)
+                .expect(&format!(
+                    "The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread",
+                    self.device,
+                ))
+                .clone();
+            async move { server.borrow_mut().sync_elapsed().await }
+        })
+    }
+
+    fn memory_usage(&self) -> cubecl_runtime::memory_management::MemoryUsage {
+        LOCAL_DEVICE.with_borrow_mut(|runtime| {
+            let server = runtime
+                .get(&self.device)
+                .expect(&format!(
+                    "The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread",
+                    self.device,
+                ));
+            server.borrow_mut().memory_usage()
+        })
+    }
+
+    fn enable_timestamps(&self) {
+        LOCAL_DEVICE.with_borrow_mut(|runtime| {
+            let server = runtime
+                .get(&self.device)
+                .expect(&format!(
+                    "The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread",
+                    self.device,
+                ));
+            server.borrow_mut().enable_timestamps()
+        })
+    }
+
+    fn disable_timestamps(&self) {
+        LOCAL_DEVICE.with_borrow_mut(|runtime| {
+            let server = runtime
+                .get(&self.device)
+                .expect(&format!(
+                    "The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread",
+                    self.device,
+                ));
+            server.borrow_mut().disable_timestamps()
+        })
+    }
+}
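
For reference, a minimal sketch of how the new `init_thread_server` entry point introduced by this diff would be called on a `wasm32` build with the `atomics` target feature. Only `init_thread_server`, `RuntimeOptions`, and `WgpuDevice` come from the code above; the function name `init_worker_thread` and the use of the default device are illustrative assumptions, and the worker-spawning machinery around it is out of scope here.

// Hypothetical call site (not part of the diff): runs once on each wasm worker thread.
use cubecl_wgpu::{init_thread_server, RuntimeOptions, WgpuDevice};

async fn init_worker_thread() {
    // Assumes the default device; any `WgpuDevice` value works the same way.
    let device = WgpuDevice::default();
    // Must be awaited once per thread before any computation on that thread,
    // as the `expect` messages in `ThreadLocalChannel` require.
    init_thread_server(device.clone(), RuntimeOptions::default()).await;
    // Afterwards, `WgpuRuntime::client(&device)` on this thread resolves to a
    // client backed by the thread-local server via `ThreadLocalChannel`.
}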