diff --git a/crates/cubecl-core/src/frontend/element/base.rs b/crates/cubecl-core/src/frontend/element/base.rs index 9e981ff32..7b0d840cb 100644 --- a/crates/cubecl-core/src/frontend/element/base.rs +++ b/crates/cubecl-core/src/frontend/element/base.rs @@ -167,7 +167,7 @@ all_tuples!(launch_tuple, 2, 12, T, t); /// Expand type associated with a type. #[derive(new)] pub struct ExpandElementTyped { - pub(crate) expand: ExpandElement, + pub expand: ExpandElement, pub(crate) _type: PhantomData, } diff --git a/crates/cubecl-cpu/Cargo.toml b/crates/cubecl-cpu/Cargo.toml index 7a9e2c37d..13ef3c3c2 100644 --- a/crates/cubecl-cpu/Cargo.toml +++ b/crates/cubecl-cpu/Cargo.toml @@ -97,6 +97,7 @@ log = { workspace = true } serde = { workspace = true } sysinfo = { workspace = true } tracel-llvm = { workspace = true } +paste = { workspace = true } [dev-dependencies] cubecl-core = { path = "../cubecl-core", version = "0.9.0-pre.1", features = [ diff --git a/crates/cubecl-cpu/src/compiler/visitor/operation/metadata.rs b/crates/cubecl-cpu/src/compiler/visitor/operation/metadata.rs index 040f38359..1d3f2fb2c 100644 --- a/crates/cubecl-cpu/src/compiler/visitor/operation/metadata.rs +++ b/crates/cubecl-cpu/src/compiler/visitor/operation/metadata.rs @@ -45,7 +45,7 @@ impl<'a> Visitor<'a> { self.location, )); - let dim = self.get_index(dim, dim.ty); + let dim = self.get_index(dim, dim.ty, true); let offset = self.append_operation_with_result(arith::addi(first_rank, dim, self.location)); let result = self.append_operation_with_result(memref::load( metadata_memref, diff --git a/crates/cubecl-cpu/src/compiler/visitor/operation/operator.rs b/crates/cubecl-cpu/src/compiler/visitor/operation/operator.rs index a375976c3..409065e6b 100644 --- a/crates/cubecl-cpu/src/compiler/visitor/operation/operator.rs +++ b/crates/cubecl-cpu/src/compiler/visitor/operation/operator.rs @@ -23,9 +23,14 @@ impl<'a> Visitor<'a> { } Operator::CopyMemory(copy_memory) => { let memref = self.get_memory(copy_memory.input); - let in_index = self.get_index(copy_memory.in_index, copy_memory.input.ty); + let in_index = self.get_index( + copy_memory.in_index, + copy_memory.input.ty, + copy_memory.input.ty.is_vectorized(), + ); let out_memref = self.get_memory(out); - let out_index = self.get_index(copy_memory.out_index, out.ty); + let out_index = + self.get_index(copy_memory.out_index, out.ty, out.ty.is_vectorized()); if out.ty.is_vectorized() { let result = out.ty.to_type(self.context); let value = self.append_operation_with_result(vector::load( @@ -146,8 +151,7 @@ impl<'a> Visitor<'a> { fn visit_index(&mut self, index: &IndexOperator, out: Variable) -> Value<'a, 'a> { assert!(index.line_size == 0); - let mut index_value = self.get_index(index.index, out.ty); - let vector_type = index.list.ty.to_type(self.context); + let mut index_value = self.get_index(index.index, out.ty, index.list.ty.is_vectorized()); if !self.is_memory(index.list) { let to_extract = self.get_variable(index.list); // Item of size 1 @@ -167,6 +171,10 @@ impl<'a> Visitor<'a> { llvm::extractelement(self.context, res, to_extract, index_value, self.location); self.append_operation_with_result(vector_extract) } else if out.ty.is_vectorized() { + let vector_type = Type::vector( + &[out.line_size() as u64], + index.list.storage_type().to_type(self.context), + ); let memref = self.get_memory(index.list); self.append_operation_with_result(vector::load( self.context, @@ -189,7 +197,11 @@ impl<'a> Visitor<'a> { out.kind, VariableKind::LocalMut { .. } | VariableKind::LocalConst { .. } ) { - let indices = self.get_index(index_assign.index, index_assign.value.ty); + let indices = self.get_index( + index_assign.index, + index_assign.value.ty, + out.ty.is_vectorized(), + ); let operation = if index_assign.value.ty.is_vectorized() { vector::store(self.context, value, memref, &[indices], self.location).into() } else { @@ -199,14 +211,18 @@ impl<'a> Visitor<'a> { return; } let operation = if index_assign.value.ty.is_vectorized() { - let indices = self.get_index(index_assign.index, index_assign.value.ty); + let indices = self.get_index( + index_assign.index, + index_assign.value.ty, + out.ty.is_vectorized(), + ); vector::store(self.context, value, memref, &[indices], self.location) } else { let vector_type = Type::vector( &[out.line_size() as u64], index_assign.value.storage_type().to_type(self.context), ); - let indices = self.get_index(index_assign.index, out.ty); + let indices = self.get_index(index_assign.index, out.ty, out.ty.is_vectorized()); let splat = self.append_operation_with_result(vector::splat( self.context, vector_type, @@ -248,7 +264,6 @@ impl<'a> Visitor<'a> { value, ) }; - self.insert_variable(out, value); } diff --git a/crates/cubecl-cpu/src/compiler/visitor/variables.rs b/crates/cubecl-cpu/src/compiler/visitor/variables.rs index c4bde8a44..e1955d363 100644 --- a/crates/cubecl-cpu/src/compiler/visitor/variables.rs +++ b/crates/cubecl-cpu/src/compiler/visitor/variables.rs @@ -328,14 +328,19 @@ impl<'a> Visitor<'a> { } } - pub fn get_index(&self, variable: Variable, target_item: ir::Type) -> Value<'a, 'a> { + pub fn get_index( + &self, + variable: Variable, + target_item: ir::Type, + list_is_vectorized: bool, + ) -> Value<'a, 'a> { let index = self.get_variable(variable); let mut index = self.append_operation_with_result(index::casts( index, Type::index(self.context), self.location, )); - if target_item.is_vectorized() { + if target_item.is_vectorized() && list_is_vectorized { let vectorization = target_item.line_size() as i64; let shift = vectorization.ilog2() as i64; let constant = self.append_operation_with_result(arith::constant( diff --git a/crates/cubecl-cpu/src/frontend/mod.rs b/crates/cubecl-cpu/src/frontend/mod.rs new file mode 100644 index 000000000..ab7e6290f --- /dev/null +++ b/crates/cubecl-cpu/src/frontend/mod.rs @@ -0,0 +1,2 @@ +mod unaligned_line; +pub use unaligned_line::*; diff --git a/crates/cubecl-cpu/src/frontend/unaligned_line.rs b/crates/cubecl-cpu/src/frontend/unaligned_line.rs new file mode 100644 index 000000000..35430be00 --- /dev/null +++ b/crates/cubecl-cpu/src/frontend/unaligned_line.rs @@ -0,0 +1,119 @@ +use cubecl_core::intrinsic; +use cubecl_core::ir::{IndexAssignOperator, IndexOperator, Instruction, Operator}; +use cubecl_core::{self as cubecl, prelude::*}; + +/// An extension trait for expanding the cubecl frontend with the ability to +/// request unaligned line reads and writes +/// +/// Typically in cubecl, a buffer is declared as having a certain line size +/// at kernel compilation time. The buffer can then be indexed to produce +/// lines that are aligned to the line_size. +/// +/// This trait allows the user to request a line_read from a buffer where the +/// start of the read is not aligned to the line_read requested. +/// +/// As an example, imagine a buffer of scalar length 4. With line_size = 1, +/// this could be illustrated like so +/// [1, 2, 3, 4] +/// +/// Imagine the same buffer, now with line_size = 2. +/// [[1, 2], [3, 4]] +/// +/// Lines can now be accessed from this buffer, but only those that that are aligned +/// with the line_size. I.e. we can get the lines [1, 2] or [3, 4], but not [2, 3] +/// +/// This trait allows you to treat the buffer as having no line_size = 1, yet asking +/// for a line of some kernel-compile-time known length at some offset in the buffer. +/// I.e. if for the buffer `buf = [1, 2, 3, 4]`, `buf.unaligned_line_read(1, 2)` +/// will produce the line `[2, 3]`. +#[cube] +pub trait UnalignedLine: CubeType + Sized { + /// Perform an unchecked read of a line of the given length at the given index + /// + /// # Safety + /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index..index+line_size is + /// always in bounds + fn unaligned_line_read(&self, index: u32, #[comptime] line_size: u32) -> Line; + + /// Perform an unchecked write of a line of the given length at the given index + /// + /// # Safety + /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index..index+line_size is + /// always in bounds + fn unaligned_line_write(&mut self, index: u32, value: Line); +} + +macro_rules! impl_unaligned_line { + ($type:ident) => { + paste::paste! { + type [<$type Expand>] = ExpandElementTyped<$type>; + } + #[cube] + impl UnalignedLine for $type { + fn unaligned_line_read(&self, index: u32, #[comptime] line_size: u32) -> Line { + unaligned_line_read::<$type, E>(self, index, line_size) + } + + fn unaligned_line_write(&mut self, index: u32, value: Line) { + unaligned_line_write::<$type, E>(self, index, value) + } + } + }; +} + +impl_unaligned_line!(Array); +impl_unaligned_line!(Tensor); +impl_unaligned_line!(SharedMemory); + +// TODO: Maybe impl unaligned IO on slices? +// The last dimension will have to be contiguous for this to make sense, +// as the unaligned IO isn't gather / scatter from arbitrary memory locations +// and still needs the loaded elements to be contiguous + +#[cube] +#[allow(unused_variables)] +fn unaligned_line_read>, E: CubePrimitive>( + this: &T, + index: u32, + #[comptime] line_size: u32, +) -> Line { + intrinsic!(|scope| { + if !matches!(this.expand.ty, cubecl::ir::Type::Scalar(_)) { + todo!("Unaligned reads are only allowed on scalar arrays for now"); + } + let out = scope.create_local(this.expand.ty.line(line_size)); + scope.register(Instruction::new( + Operator::UncheckedIndex(IndexOperator { + list: *this.expand, + index: index.expand.consume(), + line_size: 0, + unroll_factor: 1, + }), + *out, + )); + out.into() + }) +} + +#[cube] +#[allow(unused_variables)] +fn unaligned_line_write>, E: CubePrimitive>( + this: &mut T, + index: u32, + value: Line, +) { + intrinsic!(|scope| { + if !matches!(this.expand.ty, cubecl::ir::Type::Scalar(_)) { + todo!("Unaligned reads are only allowed on scalar arrays for now"); + } + scope.register(Instruction::new( + Operator::UncheckedIndexAssign(IndexAssignOperator { + index: index.expand.consume(), + value: value.expand.consume(), + line_size: 0, + unroll_factor: 1, + }), + *this.expand, + )); + }) +} diff --git a/crates/cubecl-cpu/src/lib.rs b/crates/cubecl-cpu/src/lib.rs index 0a96f2f57..a7c9def4f 100644 --- a/crates/cubecl-cpu/src/lib.rs +++ b/crates/cubecl-cpu/src/lib.rs @@ -24,6 +24,7 @@ mod tests { pub mod compiler; pub mod compute; pub mod device; +pub mod frontend; pub mod runtime; pub use device::CpuDevice; diff --git a/crates/cubecl-cpu/src/runtime.rs b/crates/cubecl-cpu/src/runtime.rs index 31a68b164..56d6f2ab4 100644 --- a/crates/cubecl-cpu/src/runtime.rs +++ b/crates/cubecl-cpu/src/runtime.rs @@ -6,7 +6,7 @@ use cubecl_core::{ server::ServerUtilities, }; use cubecl_runtime::{ - DeviceProperties, + DeviceProperties, Features, logging::ServerLogger, memory_management::{ HardwareProperties, MemoryDeviceProperties, MemoryManagement, MemoryManagementOptions, @@ -82,7 +82,10 @@ impl DeviceState for CpuServer { ); let mut device_props = DeviceProperties::new( - Default::default(), + Features { + unaligned_io: true, + ..Default::default() + }, mem_properties, topology, TimingMethod::Device, diff --git a/crates/cubecl-runtime/src/features.rs b/crates/cubecl-runtime/src/features.rs index 6e97c7c51..ba85d9717 100644 --- a/crates/cubecl-runtime/src/features.rs +++ b/crates/cubecl-runtime/src/features.rs @@ -36,6 +36,9 @@ pub struct Features { pub ldmatrix: BTreeSet, /// Types supported by stmatrix, if any pub stmatrix: BTreeSet, + /// Whether Lines can be read from / stored to addresses not aligned + /// with the line_size + pub unaligned_io: bool, } /// Operations allowed for this type. CMMA is defined separately.