diff --git a/crates/cubecl-cpu/Cargo.toml b/crates/cubecl-cpu/Cargo.toml index b68957ba5..5c8a148fa 100644 --- a/crates/cubecl-cpu/Cargo.toml +++ b/crates/cubecl-cpu/Cargo.toml @@ -113,5 +113,8 @@ cubecl-random = { path = "../cubecl-random", version = "0.7.0", features = [ cubecl-std = { path = "../cubecl-std", version = "0.7.0", features = [ "export_tests", ] } +cubecl-scan = { path = "../cubecl-scan", version = "0.7.0", features = [ + "export_tests", +] } paste = { workspace = true } pretty_assertions = { workspace = true } diff --git a/crates/cubecl-cpu/src/lib.rs b/crates/cubecl-cpu/src/lib.rs index 497a1f169..18b477dc3 100644 --- a/crates/cubecl-cpu/src/lib.rs +++ b/crates/cubecl-cpu/src/lib.rs @@ -18,6 +18,7 @@ mod tests { cubecl_matmul::testgen_matmul_unit!(); cubecl_convolution::testgen_conv2d_accelerated!([f16: f16, f32: f32]); cubecl_reduce::testgen_shared_sum!([f16, f32, f64]); + cubecl_scan::testgen_scan!(); // Deactivated for now as it makes the CI hang // cubecl_reduce::testgen_reduce!([f16, f32, f64]); diff --git a/crates/cubecl-cuda/Cargo.toml b/crates/cubecl-cuda/Cargo.toml index 5504945bb..13afe0a49 100644 --- a/crates/cubecl-cuda/Cargo.toml +++ b/crates/cubecl-cuda/Cargo.toml @@ -124,5 +124,8 @@ cubecl-reduce = { path = "../cubecl-reduce", version = "0.7.0", features = [ cubecl-std = { path = "../cubecl-std", version = "0.7.0", features = [ "export_tests", ] } +cubecl-scan = { path = "../cubecl-scan", version = "0.7.0", features = [ + "export_tests", +] } paste = { workspace = true } pretty_assertions = { workspace = true } diff --git a/crates/cubecl-cuda/src/lib.rs b/crates/cubecl-cuda/src/lib.rs index 79d6a41b3..87e67b7ad 100644 --- a/crates/cubecl-cuda/src/lib.rs +++ b/crates/cubecl-cuda/src/lib.rs @@ -39,4 +39,5 @@ mod tests { cubecl_reduce::testgen_reduce!([f16, bf16, f32, f64]); cubecl_random::testgen_random!(); cubecl_reduce::testgen_shared_sum!([f16, bf16, f32, f64]); + cubecl_scan::testgen_scan!(); } diff --git 
a/crates/cubecl-scan/Cargo.toml b/crates/cubecl-scan/Cargo.toml new file mode 100644 index 000000000..316aa8d4b --- /dev/null +++ b/crates/cubecl-scan/Cargo.toml @@ -0,0 +1,26 @@ +[package] +authors = ["Friedrich Schirmer"] +categories = ["science", "mathematics", "algorithms"] +description = "CubeCL Scan Algorithms." +edition.workspace = true +keywords = [] +license.workspace = true +name = "cubecl-scan" +readme.workspace = true +repository = "https://github.com/tracel-ai/cubecl/tree/main/crates/cubecl-scan" +version.workspace = true + +[features] +default = ["std", "cubecl-runtime/default", "cubecl-core/default"] +export_tests = ["pretty_assertions", "rand"] +std = ["cubecl-runtime/std", "cubecl-core/std"] + +[dependencies] +cubecl-core = { path = "../cubecl-core", version = "0.7.0", default-features = false } +cubecl-runtime = { path = "../cubecl-runtime", version = "0.7.0", default-features = false } +cubecl-std = { path = "../cubecl-std", version = "0.7.0", default-features = false } +num-traits = "0.2.19" +pretty_assertions = { workspace = true, optional = true } +rand = { workspace = true, optional = true } +serde = { workspace = true } +half = { workspace = true } diff --git a/crates/cubecl-scan/src/base.rs b/crates/cubecl-scan/src/base.rs new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/crates/cubecl-scan/src/base.rs @@ -0,0 +1 @@ + diff --git a/crates/cubecl-scan/src/config.rs b/crates/cubecl-scan/src/config.rs new file mode 100644 index 000000000..90c3219a6 --- /dev/null +++ b/crates/cubecl-scan/src/config.rs @@ -0,0 +1,42 @@ +use cubecl_core::{ + CubeCount, CubeDim, Runtime, + client::ComputeClient, + prelude::{CubePrimitive, TensorHandleRef}, +}; + +#[derive(Debug, Clone)] +pub struct ScanConfig { + pub cube_count: CubeCount, + pub cube_dim: CubeDim, + pub clear_cube_dim: CubeDim, + pub line_size: u32, + pub clear_line_size: u32, + pub inclusive: bool, +} + +impl ScanConfig { + pub(crate) fn generate( + client: &ComputeClient, + 
input: &TensorHandleRef, + output: &TensorHandleRef, + ) -> ScanConfig { + // ToDo + ScanConfig::empty() + } + + fn empty() -> Self { + Self { + cube_count: CubeCount::new_single(), + cube_dim: CubeDim::new_single(), + clear_cube_dim: CubeDim::new_single(), + line_size: 1, + clear_line_size: 1, + inclusive: false, + } + } + + pub fn with_inclusive(mut self, inclusive: bool) -> Self { + self.inclusive = inclusive; + self + } +} diff --git a/crates/cubecl-scan/src/error.rs b/crates/cubecl-scan/src/error.rs new file mode 100644 index 000000000..1602a0041 --- /dev/null +++ b/crates/cubecl-scan/src/error.rs @@ -0,0 +1,48 @@ +use core::fmt; +use cubecl_core::ir::StorageType; + +#[derive(Debug, PartialEq, Eq, Clone, Hash)] +pub enum ScanError { + /// Indicates that the hardware / API doesn't support SIMT plane instructions. + PlanesUnavailable, + /// Indicates that the cube count is bigger than the max supported. + CubeCountTooLarge, + /// Indicates that min_plane_dim != max_plane_dim, thus the exact plane_dim is not fixed. + ImprecisePlaneDim, + /// Indicates that the two tensors do not hold the same number of elements. + MismatchSize { + shape_a: Vec, + shape_b: Vec, + }, + /// Indicates that the buffer type is not supported by the backend. + UnsupportedType(StorageType), + /// Indicates that we can't launch a decoupled look-back scan + /// because the atomic load/store operations are not supported. + MissingAtomicLoadStore(StorageType), +} + +impl fmt::Display for ScanError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::PlanesUnavailable => write!( + f, + "Trying to launch a kernel using plane instructions, but they are not supported by the hardware." + ), + Self::CubeCountTooLarge => write!(f, "The cube count is larger than the max supported"), + Self::ImprecisePlaneDim => write!( + f, + "Trying to launch a kernel using plane instructions, but the min and max plane dimensions are different."
+ ), + Self::MismatchSize { shape_a, shape_b } => write!( + f, + "The tensor of shape {shape_a:?} should have the same number of elements as the one with shape {shape_b:?}." + ), + Self::UnsupportedType(ty) => { + write!(f, "The type {ty} is not supported by the client") + } + Self::MissingAtomicLoadStore(ty) => { + write!(f, "Atomic load/store not supported by the client for {ty}") + } + } + } +} diff --git a/crates/cubecl-scan/src/instructions/add.rs b/crates/cubecl-scan/src/instructions/add.rs new file mode 100644 index 000000000..2d9c87293 --- /dev/null +++ b/crates/cubecl-scan/src/instructions/add.rs @@ -0,0 +1,55 @@ +use cubecl_core as cubecl; +use cubecl_core::prelude::*; + +use crate::instructions::ScanInstruction; + +#[derive(Debug, CubeType, Clone)] +pub struct Add {} + +#[cube] +impl ScanInstruction for Add { + fn aggregate_line(line: Line, #[comptime] line_size: u32) -> N { + let mut sum = N::cast_from(0); + #[unroll] + for i in 0..line_size { + sum += line[i]; + } + sum + } + + fn scan_line( + mut base: N, + line: Line, + #[comptime] line_size: u32, + #[comptime] inclusive: bool, + ) -> Line { + let mut res = Line::empty(line_size); + + #[unroll] + for i in 0..line_size { + if !inclusive { + res[i] = base; + } + + base += line[i]; + + if inclusive { + res[i] = base + } + } + + res + } + + fn scan_plane(val: N, #[comptime] inclusive: bool) -> N { + if inclusive { + plane_inclusive_sum(val) + } else { + plane_exclusive_sum(val) + } + } + + fn apply(a: N, b: N) -> N { + a + b + } +} diff --git a/crates/cubecl-scan/src/instructions/base.rs b/crates/cubecl-scan/src/instructions/base.rs new file mode 100644 index 000000000..365339bb3 --- /dev/null +++ b/crates/cubecl-scan/src/instructions/base.rs @@ -0,0 +1,18 @@ +use cubecl_core as cubecl; +use cubecl_core::prelude::*; + +#[cube] +pub trait ScanInstruction: 'static + Send + Sync + std::fmt::Debug + CubeType { + fn aggregate_line(line: Line, #[comptime] line_size: u32) -> N; + + fn scan_line( + base: N, + 
line: Line, + #[comptime] line_size: u32, + #[comptime] inclusive: bool, + ) -> Line; + + fn scan_plane(val: N, #[comptime] inclusive: bool) -> N; + + fn apply(a: N, b: N) -> N; +} diff --git a/crates/cubecl-scan/src/instructions/mod.rs b/crates/cubecl-scan/src/instructions/mod.rs new file mode 100644 index 000000000..a9b45a138 --- /dev/null +++ b/crates/cubecl-scan/src/instructions/mod.rs @@ -0,0 +1,5 @@ +mod add; +mod base; + +pub use add::*; +pub use base::*; diff --git a/crates/cubecl-scan/src/kernels/decoupled_lookback.rs b/crates/cubecl-scan/src/kernels/decoupled_lookback.rs new file mode 100644 index 000000000..e83aa2451 --- /dev/null +++ b/crates/cubecl-scan/src/kernels/decoupled_lookback.rs @@ -0,0 +1,249 @@ +use cubecl::prelude::*; +use cubecl_core as cubecl; +use cubecl_std::tensor::TensorHandle; + +use crate::{ScanError, instructions::ScanInstruction}; + +// NOTE: this is only a simple scan for now (no support for tensor dimensions) + +type Flag = u32; + +const FLAG_AGGREGATE_AVAILABLE: Flag = 1; +const FLAG_PREFIX_AVAILABLE: Flag = 2; + +#[cube] +fn to_u32(val: N, #[comptime] size: u32) -> u32 { + match size { + 1 => u32::cast_from(u8::reinterpret(val)), + 2 => u32::cast_from(u16::reinterpret(val)), + 4 => u32::reinterpret(val), + _ => panic!("Unsupported size {size}"), + } +} + +#[cube] +fn from_u32(val: u32, #[comptime] size: u32) -> N { + match size { + 1 => N::reinterpret(u8::cast_from(val)), + 2 => N::reinterpret(u16::cast_from(val)), + 4 => N::reinterpret(val), + _ => panic!("Unsupported size {size}"), + } +} + +#[cube(launch, launch_unchecked)] +pub fn decoupled_lookback_scan_kernel( + input: &Array>, + aggregates: &mut Array>, + flags: &mut Array>, + output: &mut Array>, + #[comptime] line_size: u32, + #[comptime] inclusive: bool, + #[comptime] elem_size: u32, +) { + let partition_idx = CUBE_POS; + let val = select( + ABSOLUTE_POS < input.len(), + input[ABSOLUTE_POS], + Line::::empty(line_size), + ); + + let local_aggregate = 
I::aggregate_line::(val, line_size); + let plane_scan = I::scan_plane::(local_aggregate, false); + let aggregate = plane_scan + local_aggregate; + + // Perform the aggregate broadcasting step + let aggregate_idx = partition_idx * 2; + if ABSOLUTE_POS == PLANE_DIM - 1 { + // Handle the first partition + // Aggregate + Atomic::store( + &aggregates[aggregate_idx], + to_u32::(aggregate, elem_size), + ); + // Prefix + Atomic::store( + &aggregates[aggregate_idx + 1], + to_u32::(aggregate, elem_size), + ); + + sync_storage(); + // Mark the prefix as available + Atomic::store(&flags[partition_idx], FLAG_PREFIX_AVAILABLE); + } else if UNIT_POS == PLANE_DIM - 1 { + // Handle all other partitions + Atomic::store( + &aggregates[aggregate_idx], + to_u32::(aggregate, elem_size), + ); + + sync_storage(); + // Mark the aggregate as available + Atomic::store(&flags[partition_idx], FLAG_AGGREGATE_AVAILABLE); + } + sync_cube(); + sync_storage(); + + let mut lookback_idx = partition_idx; + let mut done: u32 = 0; + let mut lookback_value = N::cast_from(0); + while lookback_idx > 0 && done == 0 { + if UNIT_POS == 0 { + let desc_idx = lookback_idx - 1; + let pred_flag = Atomic::load(&flags[desc_idx]); + + if pred_flag == FLAG_AGGREGATE_AVAILABLE { + let aggregate = Atomic::load(&aggregates[desc_idx * 2]); + lookback_value = I::apply::(lookback_value, from_u32::(aggregate, elem_size)); + lookback_idx -= 1; + } else if pred_flag == FLAG_PREFIX_AVAILABLE { + let aggregate = Atomic::load(&aggregates[desc_idx * 2 + 1]); + lookback_value = I::apply::(lookback_value, from_u32::(aggregate, elem_size)); + done = 1; + } + } + sync_cube(); + + // Broadcast the "done" state to all threads in the plane + done = plane_broadcast(done, 0); + } + sync_cube(); + // Fetch the computed lookback value into all threads in the plane + lookback_value = plane_broadcast(lookback_value, 0); + + let scan_carry = lookback_value + aggregate; + + // Mark the prefix as available + if UNIT_POS == PLANE_DIM - 1 { + // 
Prefix + Atomic::store( + &aggregates[aggregate_idx + 1], + to_u32::(scan_carry, elem_size), + ); + // Make the prefix visible before the flag is raised, as done for the stores above + sync_storage(); + Atomic::store(&flags[partition_idx], FLAG_PREFIX_AVAILABLE); + } + sync_cube(); + sync_storage(); + + let scan_res = I::scan_line::(lookback_value + plane_scan, val, line_size, inclusive); + if ABSOLUTE_POS < output.len() { + output[ABSOLUTE_POS] = scan_res; + } +} + +/// Borrowed-handle entry point: builds handles from both references and forwards to [`launch`]. +pub fn launch_ref( + client: &ComputeClient, + input: &TensorHandleRef<'_, R>, + output: &TensorHandleRef<'_, R>, + axis: usize, + inclusive: bool, +) -> Result<(), ScanError> { + let input = TensorHandle::::from_ref(input); + let output = TensorHandle::::from_ref(output); + + launch::(client, input, output, axis, inclusive) +} + +/// Validates the client features and tensor sizes, then launches the decoupled look-back scan. +pub fn launch( + client: &ComputeClient, + input: TensorHandle, + output: TensorHandle, + axis: usize, + inclusive: bool, +) -> Result<(), ScanError> { + use cubecl_core::Feature; + let atomic_elem = Atomic::::as_type_native_unchecked(); + let has_feature = |f| client.properties().feature_enabled(f); + + // Check that the client supports the provided type + if !has_feature(Feature::Type(N::as_type_native_unchecked())) { + return Err(ScanError::UnsupportedType(N::as_type_native_unchecked())); + } + + // Check that the client supports atomic load/store + if !has_feature(Feature::Type(atomic_elem)) + || !has_feature(Feature::AtomicUInt(cubecl_core::AtomicFeature::LoadStore)) + { + return Err(ScanError::MissingAtomicLoadStore( + Flag::as_type_native_unchecked(), + )); + } + + let num_elements = tensor_size(&input); + if num_elements != tensor_size(&output) { + return Err(ScanError::MismatchSize { + shape_a: input.shape.clone(), + shape_b: output.shape.clone(), + }); + } + if !precise_plane_dim::(client) { + return Err(ScanError::ImprecisePlaneDim); + } + + let hw_props = &client.properties().hardware; + let plane_size = hw_props.plane_size_max; + + // ToDo: do better line size selection for non-1 strides + let line_size = match input.strides[axis] { + 1 => { + let elem = 
N::as_type_native_unchecked(); + R::line_size_type(&elem) + .filter(|s| num_elements % (*s as usize) == 0) + .max() + .unwrap_or(1) as u32 + } + _ => 1, + }; + + let cube_dim = CubeDim::new_1d(plane_size); + + let block_elements = (line_size * plane_size) as usize; + let num_blocks = num_elements.div_ceil(block_elements); + + // Round to the granularity used by the reset kernel + let flags_per_cube = { + let elem = Flag::as_type_native_unchecked(); + // Fall back to 1, not 0: `next_multiple_of(0)` below would panic + let line_size = R::line_size_type(&elem).max().unwrap_or(1) as usize; + line_size * (plane_size as usize) + }; + let num_flags = num_blocks.next_multiple_of(flags_per_cube); + + // Overwritten before any reads, so it can contain garbage initially + // Two `u32` slots per block: the aggregate followed by the prefix + let aggregates = client.empty(num_blocks * 2 * (u32::elem_size() as usize)); + let flags = TensorHandle::::zeros(client, vec![num_flags]); + + // SAFETY: NOTE(review) lengths mirror the sizes computed above; confirm they match the buffers + unsafe { + decoupled_lookback_scan_kernel::launch::( + client, + CubeCount::Static(num_blocks as u32, 1, 1), + cube_dim, + ArrayArg::from_raw_parts::(&input.handle, num_elements, line_size as u8), + ArrayArg::from_raw_parts::(&aggregates, num_blocks * 2, 1), + ArrayArg::from_raw_parts::(&flags.handle, num_flags, 1), + ArrayArg::from_raw_parts::(&output.handle, num_elements, line_size as u8), + line_size, + inclusive, + N::elem_size(), + ); + } + + Ok(()) +} + +/// Returns the number of elements described by the shape of `handle`. +fn tensor_size(handle: &TensorHandle) -> usize { + handle.shape.iter().product::() +} + +/// Returns `true` when the minimum and maximum plane sizes reported by the client are equal. +fn precise_plane_dim(client: &ComputeClient) -> bool { + let hw_props = &client.properties().hardware; + hw_props.plane_size_min == hw_props.plane_size_max +} diff --git a/crates/cubecl-scan/src/kernels/mod.rs b/crates/cubecl-scan/src/kernels/mod.rs new file mode 100644 index 000000000..20c77ebca --- /dev/null +++ b/crates/cubecl-scan/src/kernels/mod.rs @@ -0,0 +1 @@ +pub mod decoupled_lookback; diff --git 
a/crates/cubecl-scan/src/lib.rs b/crates/cubecl-scan/src/lib.rs new file mode 100644 index 000000000..86b79f78b --- /dev/null +++ b/crates/cubecl-scan/src/lib.rs @@ -0,0 +1,46 @@ +//! This crate provides different implementations of the associative scan algorithm +//! which can run on multiple GPU backends using CubeCL. +//! +//! The commonly known prefix sum or cumsum operation is an associative scan +//! using the associative addition operator. In general, an associative scan +//! is a (parallel) scan operation with an operator that is required to be +//! associative with the following property: +//! * Let the input sequence of numbers be `x_0`, `x_1`, `x_2`, ... +//! * Let the output sequence of numbers be `y_0`, `y_1`, `y_2`, ... +//! * The inclusive output is then defined as `y_0 = x_0`, `y_1 = x_0 + x_1`, +//! `y_2 = x_0 + x_1 + x_2`, ... + +mod base; +mod config; +mod error; +pub mod instructions; +pub mod kernels; + +#[cfg(feature = "export_tests")] +pub mod tests; + +pub use base::*; +pub use config::*; +pub use error::*; + +use crate::instructions::ScanInstruction; +use cubecl_core::prelude::*; + +// ToDo: add algorithm reference to the book +// ToDo: write benchmarks (the algorithm should be almost 100% memory bound ideally) +// ToDo: abstract the algorithm selection into a strategy enum like in matmul + +pub fn associative_scan( + client: &ComputeClient, + input: TensorHandleRef<'_, R>, + output: TensorHandleRef<'_, R>, + axis: usize, + inclusive: bool, +) -> Result<(), ScanError> { + // ToDo: maybe allocate the secondary storage using client here + // ToDo: at least in CUDA, the kernels launched here should execute sequentially => important for the 3-stage impl + + kernels::decoupled_lookback::launch_ref::(client, &input, &output, axis, inclusive)?; + + Ok(()) +} diff --git a/crates/cubecl-scan/src/tests/mod.rs b/crates/cubecl-scan/src/tests/mod.rs new file mode 100644 index 000000000..1eef5d85c --- /dev/null +++ b/crates/cubecl-scan/src/tests/mod.rs @@ -0,0 +1,13 @@ 
+pub mod simple; + +#[allow(missing_docs)] +#[macro_export] +macro_rules! testgen_scan { + () => { + mod test_scan { + use super::*; + + cubecl_scan::testgen_scan_simple!(); + } + }; +} diff --git a/crates/cubecl-scan/src/tests/simple.rs b/crates/cubecl-scan/src/tests/simple.rs new file mode 100644 index 000000000..47de16714 --- /dev/null +++ b/crates/cubecl-scan/src/tests/simple.rs @@ -0,0 +1,108 @@ +use crate::{associative_scan, instructions::ScanInstruction}; +use cubecl::prelude::*; +use cubecl_core as cubecl; +use cubecl_std::tensor::TensorHandle; +use rand::{Rng, SeedableRng, distr::Uniform}; + +#[macro_export] +macro_rules! testgen_scan_simple { + () => { + mod scan_simple { + use super::*; + use cubecl_scan::*; + + $crate::testgen_scan_simple!(@group: [ty=[i8, u8, i16, u16, i32, u32], sizes=[1, 10, 100, 128, 256, 1_000, 4097, 1_000_000]]: { + Add: |a, b| a.wrapping_add(b); + }); + $crate::testgen_scan_simple!(@group: [ty=[f32], sizes=[1, 10, 100, 128, 256, 1_000, 4097]]: { + Add: |a, b| a + b; + }); + } + }; + (@group: [ty=[$($ty:ty),*], sizes=$sizes:expr]: $rest:tt) => { + $( + $crate::testgen_scan_simple!(@group: [ty=$ty, sizes=$sizes, false]: $rest); + $crate::testgen_scan_simple!(@group: [ty=$ty, sizes=$sizes, true]: $rest); + )* + }; + (@group: [ty=$ty:ty, sizes=$sizes:expr, $inclusive:literal]: { + $( $instr:ty : $op:expr ; )* + }) => { + $( + paste::paste! 
{ + #[test] + fn []() { + for size in $sizes { + let client = TestRuntime::client(&Default::default()); + let test = cubecl_scan::tests::simple::TestCase { + shape: vec![size], + stride: vec![1], + axis: 0, + inclusive: $inclusive, + }; + test.test_scan::(&client, $op); + } + } + } + )* + }; +} + +#[derive(Debug)] +pub struct TestCase { + pub shape: Vec, + pub stride: Vec, + pub axis: usize, + pub inclusive: bool, +} + +impl TestCase { + pub fn test_scan( + &self, + client: &ComputeClient, + op: impl Fn(N, N) -> N, + ) { + let len = self.shape.iter().product::(); + let data = rand::rngs::StdRng::seed_from_u64(1234) + .sample_iter(Uniform::::new(1, 20).unwrap()) + .take(len) + .map(|v| N::from_int(v)) + .collect::>(); + let expected = self.reference_scan(&data, N::from_int(0), op); + + let handle = client.create(N::as_bytes(&data)); + let input = TensorHandle::::new(handle, vec![len], self.stride.clone()); + let output = TensorHandle::::empty(client, vec![len]); + + let res = associative_scan::( + &client, + input.as_ref(), + output.as_ref(), + 0, + self.inclusive, + ); + + if res.is_err() { + return; + } + + let output_data = client.read_one(output.handle); + let output_data = &N::from_bytes(&output_data)[..len]; + + assert_eq!(&expected[..], output_data); + } + + fn reference_scan(&self, data: &[T], start: T, op: impl Fn(T, T) -> T) -> Vec { + data.iter() + .scan(start, |acc, v| { + let mut res = *acc; + *acc = op(*acc, *v); + if self.inclusive { + res = *acc; + } + + Some(res) + }) + .collect() + } +} diff --git a/crates/cubecl-wgpu/Cargo.toml b/crates/cubecl-wgpu/Cargo.toml index f7f258b31..872e3d536 100644 --- a/crates/cubecl-wgpu/Cargo.toml +++ b/crates/cubecl-wgpu/Cargo.toml @@ -154,6 +154,9 @@ cubecl-random = { path = "../cubecl-random", version = "0.7.0", features = [ cubecl-std = { path = "../cubecl-std", version = "0.7.0", features = [ "export_tests", ] } +cubecl-scan = { path = "../cubecl-scan", version = "0.7.0", features = [ + "export_tests", +] } half 
= { workspace = true } paste = { workspace = true } pretty_assertions = { workspace = true } diff --git a/crates/cubecl-wgpu/src/lib.rs b/crates/cubecl-wgpu/src/lib.rs index 608888530..c78b89780 100644 --- a/crates/cubecl-wgpu/src/lib.rs +++ b/crates/cubecl-wgpu/src/lib.rs @@ -41,6 +41,7 @@ mod tests { cubecl_attention::testgen_attention!(); cubecl_reduce::testgen_shared_sum!([f32]); cubecl_quant::testgen_quant!(); + cubecl_scan::testgen_scan!(); } #[cfg(all(test, feature = "spirv"))] @@ -62,6 +63,7 @@ mod tests_spirv { cubecl_random::testgen_random!(); cubecl_reduce::testgen_shared_sum!([f32]); cubecl_quant::testgen_quant!(); + cubecl_scan::testgen_scan!(); } #[cfg(all(test, feature = "msl"))] @@ -82,4 +84,5 @@ mod tests_msl { cubecl_reduce::testgen_reduce!(); cubecl_random::testgen_random!(); cubecl_reduce::testgen_shared_sum!([f32]); + cubecl_scan::testgen_scan!(); } diff --git a/crates/cubecl/Cargo.toml b/crates/cubecl/Cargo.toml index 2b9dc2695..e5ffd8f85 100644 --- a/crates/cubecl/Cargo.toml +++ b/crates/cubecl/Cargo.toml @@ -28,6 +28,7 @@ matmul = ["dep:cubecl-matmul"] convolution = ["dep:cubecl-convolution"] reduce = ["dep:cubecl-reduce"] random = ["dep:cubecl-random"] +scan = ["dep:cubecl-scan"] std = ["cubecl-core/std", "cubecl-wgpu?/std", "cubecl-cuda?/std"] stdlib = ["cubecl-std"] # CubeCL standard library template = ["cubecl-core/template"] @@ -58,6 +59,7 @@ cubecl-hip = { path = "../cubecl-hip", version = "0.7.0", default-features = fal cubecl-convolution = { path = "../cubecl-convolution", version = "0.7.0", default-features = false, optional = true } cubecl-matmul = { path = "../cubecl-matmul", version = "0.7.0", default-features = false, optional = true } cubecl-random = { path = "../cubecl-random", version = "0.7.0", default-features = false, optional = true } +cubecl-scan = { path = "../cubecl-scan", version = "0.7.0", default-features = false, optional = true } cubecl-reduce = { path = "../cubecl-reduce", version = "0.7.0", default-features = 
false, optional = true } cubecl-runtime = { path = "../cubecl-runtime", version = "0.7.0", default-features = false } cubecl-std = { path = "../cubecl-std", version = "0.7.0", optional = true } diff --git a/crates/cubecl/src/lib.rs b/crates/cubecl/src/lib.rs index 26551166d..b9f6cd95b 100644 --- a/crates/cubecl/src/lib.rs +++ b/crates/cubecl/src/lib.rs @@ -27,5 +27,8 @@ pub use cubecl_reduce as reduce; #[cfg(feature = "random")] pub use cubecl_random as random; +#[cfg(feature = "scan")] +pub use cubecl_scan as scan; + #[cfg(feature = "cpu")] pub use cubecl_cpu as cpu;