Skip to content
Closed
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ serde_json = { version = "1.0.119", default-features = false }
dashmap = "5.5.3"
hashbrown = "0.14.5"
spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] }
rayon = "1"

getrandom = { version = "0.2.15", default-features = false }
rand = { version = "0.8.5", default-features = false, features = [
Expand Down Expand Up @@ -73,6 +74,7 @@ pretty_assertions = "1.4"
# Async
embassy-futures = { version = "0.1.1" } # for no-std
futures-lite = { version = "2.3.0", default-features = false }
futures = "0.3.31"

[profile.dev]
opt-level = 2
1 change: 1 addition & 0 deletions crates/cubecl-runtime/src/memory_management/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use alloc::{format, string::String};
/// Amount of memory in use by this allocator
/// and statistics on how much memory is reserved and
/// wasted in total.
#[derive(Debug)]
pub struct MemoryUsage {
/// The number of allocations currently active.
pub number_allocs: u64,
Expand Down
5 changes: 3 additions & 2 deletions crates/cubecl-runtime/src/tune/tune_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,12 @@ impl<K: AutotuneKey> TuneCache<K> {
} => {
if cfg!(autotune_persistent_cache) {
match checksum_matches {
None => TuneCacheResult::Unchecked, // Don't know yet.
Some(false) => TuneCacheResult::Miss, // Can't use this.
#[cfg(autotune_persistent_cache)]
None => TuneCacheResult::Unchecked, // Don't know yet.
Some(true) => TuneCacheResult::Hit {
fastest_index: *fastest_index,
},
_ => TuneCacheResult::Miss, // Some(false) or None so we can't use this.
}
} else {
let _ = checksum_matches;
Expand Down
4 changes: 4 additions & 0 deletions crates/cubecl-wgpu/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ web-time = { workspace = true }

cfg-if = { workspace = true }

[target.'cfg(all(target_arch = "wasm32", target_feature = "atomics"))'.dependencies]
futures = { workspace = true }
rayon = { workspace = true }

[dev-dependencies]
cubecl-core = { path = "../cubecl-core", version = "0.4.0", features = [
"export_tests",
Expand Down
10 changes: 4 additions & 6 deletions crates/cubecl-wgpu/src/compiler/base.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,23 @@
use std::sync::Arc;

use cubecl_core::{
prelude::CompiledKernel, server::ComputeServer, Compiler, ExecutionMode, Feature,
};
use cubecl_runtime::DeviceProperties;
use wgpu::{Adapter, ComputePipeline, Device, Queue};

use crate::WgpuServer;
use crate::{Pdrc, WgpuServer, WgpuServerInner};

pub trait WgpuCompiler: Compiler {
fn compile(
server: &mut WgpuServer<Self>,
server: &mut WgpuServerInner<Self>,
kernel: <WgpuServer<Self> as ComputeServer>::Kernel,
mode: ExecutionMode,
) -> CompiledKernel<Self>;

fn create_pipeline(
server: &mut WgpuServer<Self>,
server: &mut WgpuServerInner<Self>,
kernel: CompiledKernel<Self>,
mode: ExecutionMode,
) -> Arc<ComputePipeline>;
) -> Pdrc<ComputePipeline>;

#[allow(async_fn_in_trait)]
async fn request_device(adapter: &Adapter) -> (Device, Queue);
Expand Down
11 changes: 6 additions & 5 deletions crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use std::{borrow::Cow, sync::Arc};
use std::borrow::Cow;

use super::{shader::ComputeShader, ConstantArray, Item, SharedMemory};
use super::{LocalArray, Subgroup};
use crate::{
compiler::{base::WgpuCompiler, wgsl},
WgpuServer,
};
use crate::{Pdrc, WgpuServerInner};
use cubecl_core::{
ir::{self as cube, HybridAllocator, UIntKind},
prelude::CompiledKernel,
Expand Down Expand Up @@ -70,10 +71,10 @@ impl cubecl_core::Compiler for WgslCompiler {

impl WgpuCompiler for WgslCompiler {
fn create_pipeline(
server: &mut WgpuServer<Self>,
server: &mut WgpuServerInner<Self>,
kernel: CompiledKernel<Self>,
mode: ExecutionMode,
) -> Arc<ComputePipeline> {
) -> Pdrc<ComputePipeline> {
let source = &kernel.source;
let repr = kernel.repr.unwrap();
let module = match mode {
Expand Down Expand Up @@ -118,7 +119,7 @@ impl WgpuCompiler for WgslCompiler {
push_constant_ranges: &[],
});

Arc::new(
Pdrc::new(
server
.device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
Expand All @@ -136,7 +137,7 @@ impl WgpuCompiler for WgslCompiler {
}

fn compile(
_server: &mut WgpuServer<Self>,
_server: &mut WgpuServerInner<Self>,
kernel: <WgpuServer<Self> as ComputeServer>::Kernel,
mode: ExecutionMode,
) -> CompiledKernel<Self> {
Expand Down
4 changes: 3 additions & 1 deletion crates/cubecl-wgpu/src/compute/poll.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,12 @@ mod _impl {
// On Wasm, the browser handles the polling loop, so we don't need anything.
#[cfg(target_family = "wasm")]
mod _impl {
use crate::Pdrc;

#[derive(Debug)]
pub struct WgpuPoll {}
impl WgpuPoll {
pub fn new(_device: alloc::sync::Arc<wgpu::Device>) -> Self {
pub fn new(_device: Pdrc<wgpu::Device>) -> Self {
Self {}
}
pub fn start_polling(&self) -> alloc::sync::Arc<()> {
Expand Down
Loading
Loading