Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions crates/cubecl-cpu/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -113,5 +113,8 @@ cubecl-random = { path = "../cubecl-random", version = "0.7.0", features = [
cubecl-std = { path = "../cubecl-std", version = "0.7.0", features = [
"export_tests",
] }
cubecl-scan = { path = "../cubecl-scan", version = "0.7.0", features = [
"export_tests",
] }
paste = { workspace = true }
pretty_assertions = { workspace = true }
1 change: 1 addition & 0 deletions crates/cubecl-cpu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ mod tests {
cubecl_matmul::testgen_matmul_unit!();
cubecl_convolution::testgen_conv2d_accelerated!([f16: f16, f32: f32]);
cubecl_reduce::testgen_shared_sum!([f16, f32, f64]);
cubecl_scan::testgen_scan!();

// Deactivated for now as it makes the CI hang
// cubecl_reduce::testgen_reduce!([f16, f32, f64]);
Expand Down
3 changes: 3 additions & 0 deletions crates/cubecl-cuda/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -124,5 +124,8 @@ cubecl-reduce = { path = "../cubecl-reduce", version = "0.7.0", features = [
cubecl-std = { path = "../cubecl-std", version = "0.7.0", features = [
"export_tests",
] }
cubecl-scan = { path = "../cubecl-scan", version = "0.7.0", features = [
"export_tests",
] }
paste = { workspace = true }
pretty_assertions = { workspace = true }
1 change: 1 addition & 0 deletions crates/cubecl-cuda/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,5 @@ mod tests {
cubecl_reduce::testgen_reduce!([f16, bf16, f32, f64]);
cubecl_random::testgen_random!();
cubecl_reduce::testgen_shared_sum!([f16, bf16, f32, f64]);
cubecl_scan::testgen_scan!();
}
26 changes: 26 additions & 0 deletions crates/cubecl-scan/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
[package]
authors = ["Friedrich Schirmer"]
categories = ["science", "mathematics", "algorithms"]
description = "CubeCL Scan Algorithms."
edition.workspace = true
keywords = []
license.workspace = true
name = "cubecl-scan"
readme.workspace = true
repository = "https://github.com/tracel-ai/cubecl/tree/main/crates/cubecl-scan"
version.workspace = true

[features]
default = ["std", "cubecl-runtime/default", "cubecl-core/default"]
export_tests = ["pretty_assertions", "rand"]
std = ["cubecl-runtime/std", "cubecl-core/std"]

[dependencies]
cubecl-core = { path = "../cubecl-core", version = "0.7.0", default-features = false }
cubecl-runtime = { path = "../cubecl-runtime", version = "0.7.0", default-features = false }
cubecl-std = { path = "../cubecl-std", version = "0.7.0", default-features = false }
num-traits = "0.2.19"
pretty_assertions = { workspace = true, optional = true }
rand = { workspace = true, optional = true }
serde = { workspace = true }
half = { workspace = true }
1 change: 1 addition & 0 deletions crates/cubecl-scan/src/base.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

42 changes: 42 additions & 0 deletions crates/cubecl-scan/src/config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
use cubecl_core::{
CubeCount, CubeDim, Runtime,
client::ComputeClient,
prelude::{CubePrimitive, TensorHandleRef},
};

/// Launch configuration for a scan.
///
/// NOTE(review): the kernels that consume this configuration are not visible
/// from this file, so the per-field notes below are partly inferred from the
/// field names and types — confirm against the launch code.
#[derive(Debug, Clone)]
pub struct ScanConfig {
/// Cube count used for the launch.
pub cube_count: CubeCount,
/// Cube dimensions used for the launch.
pub cube_dim: CubeDim,
/// Cube dimensions of a separate "clear" step (presumably a kernel that
/// resets intermediate buffers — TODO confirm).
pub clear_cube_dim: CubeDim,
/// Line (vectorization) size; presumably applies to the scan itself — TODO confirm.
pub line_size: u32,
/// Line (vectorization) size of the "clear" step — TODO confirm.
pub clear_line_size: u32,
/// Whether the scan is inclusive (`true`) or exclusive (`false`).
/// Set through `with_inclusive`.
pub inclusive: bool,
}

impl ScanConfig {
pub(crate) fn generate<R: Runtime, N: CubePrimitive>(
client: &ComputeClient<R::Server, R::Channel>,
input: &TensorHandleRef<R>,
output: &TensorHandleRef<R>,
) -> ScanConfig {
// ToDo
ScanConfig::empty()
}

fn empty() -> Self {
Self {
cube_count: CubeCount::new_single(),
cube_dim: CubeDim::new_single(),
clear_cube_dim: CubeDim::new_single(),
line_size: 1,
clear_line_size: 1,
inclusive: false,
}
}

pub fn with_inclusive(mut self, inclusive: bool) -> Self {
self.inclusive = inclusive;
self
}
}
47 changes: 47 additions & 0 deletions crates/cubecl-scan/src/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
use core::fmt;
use cubecl_core::ir::StorageType;

/// Reasons why a scan kernel cannot be launched.
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub enum ScanError {
/// Indicates that the hardware / API doesn't support SIMT plane instructions.
PlanesUnavailable,
/// Indicates that the cube count is bigger than the max supported.
CubeCountTooLarge,
/// Indicates that min_plane_dim != max_plane_dim, thus the exact plane_dim is not fixed.
ImprecisePlaneDim,
/// Indicates that two tensors that should have the same number of elements
/// have incompatible shapes.
MismatchSize {
/// Shape of the first tensor.
shape_a: Vec<usize>,
/// Shape of the tensor it is compared against.
shape_b: Vec<usize>,
},
/// Indicates that the buffer type is not supported by the backend.
UnsupportedType(StorageType),
/// Indicates that we can't launch a decoupled look-back scan
/// because the atomic load/store operations are not supported.
MissingAtomicLoadStore(StorageType),
}

/// Human-readable descriptions of each [`ScanError`] variant.
impl fmt::Display for ScanError {
    /// Writes a one-sentence explanation of the error to `f`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Wording fixed: "there are not supported" -> "they are not supported".
            Self::PlanesUnavailable => write!(
                f,
                "Trying to launch a kernel using plane instructions, but they are not supported by the hardware."
            ),
            Self::CubeCountTooLarge => write!(f, "The cube count is larger than the max supported"),
            Self::ImprecisePlaneDim => write!(
                f,
                "Trying to launch a kernel using plane instructions, but the min and max plane dimensions are different."
            ),
            Self::MismatchSize { shape_a, shape_b } => write!(
                f,
                "The tensor of shape {shape_a:?} should have the same number of elements as the one with shape {shape_b:?}."
            ),
            Self::UnsupportedType(ty) => {
                write!(f, "The type {ty} is not supported by the client")
            }
            Self::MissingAtomicLoadStore(ty) => {
                write!(f, "Atomic load/store not supported by the client for {ty}")
            }
        }
    }
}
55 changes: 55 additions & 0 deletions crates/cubecl-scan/src/instructions/add.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
use cubecl_core as cubecl;
use cubecl_core::prelude::*;

use crate::instructions::ScanInstruction;

/// Scan instruction that combines values by addition, i.e. a prefix sum.
#[derive(Debug, CubeType, Clone)]
pub struct Add {}

#[cube]
impl ScanInstruction for Add {
    // Sums the `line_size` elements of `line` into a single value.
    fn aggregate_line<N: Numeric>(line: Line<N>, #[comptime] line_size: u32) -> N {
        let mut total = N::cast_from(0);
        #[unroll]
        for lane in 0..line_size {
            total += line[lane];
        }
        total
    }

    // Prefix-sums the elements of `line`, seeded with `base`.
    //
    // Exclusive: element `k` of the result is `base` plus every element before `k`.
    // Inclusive: element `k` of the result additionally includes element `k` itself.
    fn scan_line<N: Numeric>(
        mut base: N,
        line: Line<N>,
        #[comptime] line_size: u32,
        #[comptime] inclusive: bool,
    ) -> Line<N> {
        let mut scanned = Line::empty(line_size);

        #[unroll]
        for lane in 0..line_size {
            if inclusive {
                // Accumulate first so the stored value contains the current element.
                base += line[lane];
                scanned[lane] = base;
            } else {
                // Store the running total before the current element joins it.
                scanned[lane] = base;
                base += line[lane];
            }
        }

        scanned
    }

    // Prefix sum of `val` across the plane, inclusive or exclusive.
    fn scan_plane<N: Numeric>(val: N, #[comptime] inclusive: bool) -> N {
        if inclusive {
            plane_inclusive_sum(val)
        } else {
            plane_exclusive_sum(val)
        }
    }

    // The binary operation behind the scan: plain addition.
    fn apply<N: Numeric>(a: N, b: N) -> N {
        a + b
    }
}
18 changes: 18 additions & 0 deletions crates/cubecl-scan/src/instructions/base.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
use cubecl_core as cubecl;
use cubecl_core::prelude::*;

/// Operation used to drive a scan, expressed as `#[cube]` functions.
///
/// The `Add` implementation (prefix sum) is the reference used for the
/// per-method notes below.
#[cube]
pub trait ScanInstruction: 'static + Send + Sync + std::fmt::Debug + CubeType {
/// Combines the `line_size` elements of `line` into a single value.
///
/// For `Add` this is the sum of all elements of the line.
fn aggregate_line<N: Numeric>(line: Line<N>, #[comptime] line_size: u32) -> N;

/// Scans the elements of `line`, starting from the running value `base`,
/// and returns a line of `line_size` results.
///
/// For `Add`, element `i` of the result is `base` plus the elements before
/// `i` when `inclusive` is `false`, and additionally includes element `i`
/// when `inclusive` is `true`.
fn scan_line<N: Numeric>(
base: N,
line: Line<N>,
#[comptime] line_size: u32,
#[comptime] inclusive: bool,
) -> Line<N>;

/// Scans `val` at plane level, inclusively or exclusively.
///
/// For `Add` this forwards to `plane_inclusive_sum` / `plane_exclusive_sum`.
fn scan_plane<N: Numeric>(val: N, #[comptime] inclusive: bool) -> N;

/// Combines two values with the instruction's binary operation
/// (`a + b` for `Add`).
fn apply<N: Numeric>(a: N, b: N) -> N;
}
5 changes: 5 additions & 0 deletions crates/cubecl-scan/src/instructions/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
mod add;
mod base;

pub use add::*;
pub use base::*;
Loading
Loading