Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/cubecl-core/src/frontend/element/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ all_tuples!(launch_tuple, 2, 12, T, t);
/// Expand type associated with a type.
#[derive(new)]
pub struct ExpandElementTyped<T: CubeType> {
pub(crate) expand: ExpandElement,
pub expand: ExpandElement,
pub(crate) _type: PhantomData<T>,
}

Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-cpu/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ log = { workspace = true }
serde = { workspace = true }
sysinfo = { workspace = true }
tracel-llvm = { workspace = true }
paste = { workspace = true }

[dev-dependencies]
cubecl-core = { path = "../cubecl-core", version = "0.9.0-pre.1", features = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ impl<'a> Visitor<'a> {
self.location,
));

let dim = self.get_index(dim, dim.ty);
let dim = self.get_index(dim, dim.ty, true);
let offset = self.append_operation_with_result(arith::addi(first_rank, dim, self.location));
let result = self.append_operation_with_result(memref::load(
metadata_memref,
Expand Down
31 changes: 23 additions & 8 deletions crates/cubecl-cpu/src/compiler/visitor/operation/operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,14 @@ impl<'a> Visitor<'a> {
}
Operator::CopyMemory(copy_memory) => {
let memref = self.get_memory(copy_memory.input);
let in_index = self.get_index(copy_memory.in_index, copy_memory.input.ty);
let in_index = self.get_index(
copy_memory.in_index,
copy_memory.input.ty,
copy_memory.input.ty.is_vectorized(),
);
let out_memref = self.get_memory(out);
let out_index = self.get_index(copy_memory.out_index, out.ty);
let out_index =
self.get_index(copy_memory.out_index, out.ty, out.ty.is_vectorized());
if out.ty.is_vectorized() {
let result = out.ty.to_type(self.context);
let value = self.append_operation_with_result(vector::load(
Expand Down Expand Up @@ -146,8 +151,7 @@ impl<'a> Visitor<'a> {

fn visit_index(&mut self, index: &IndexOperator, out: Variable) -> Value<'a, 'a> {
assert!(index.line_size == 0);
let mut index_value = self.get_index(index.index, out.ty);
let vector_type = index.list.ty.to_type(self.context);
let mut index_value = self.get_index(index.index, out.ty, index.list.ty.is_vectorized());
if !self.is_memory(index.list) {
let to_extract = self.get_variable(index.list);
// Item of size 1
Expand All @@ -167,6 +171,10 @@ impl<'a> Visitor<'a> {
llvm::extractelement(self.context, res, to_extract, index_value, self.location);
self.append_operation_with_result(vector_extract)
} else if out.ty.is_vectorized() {
let vector_type = Type::vector(
&[out.line_size() as u64],
index.list.storage_type().to_type(self.context),
);
let memref = self.get_memory(index.list);
self.append_operation_with_result(vector::load(
self.context,
Expand All @@ -189,7 +197,11 @@ impl<'a> Visitor<'a> {
out.kind,
VariableKind::LocalMut { .. } | VariableKind::LocalConst { .. }
) {
let indices = self.get_index(index_assign.index, index_assign.value.ty);
let indices = self.get_index(
index_assign.index,
index_assign.value.ty,
out.ty.is_vectorized(),
);
let operation = if index_assign.value.ty.is_vectorized() {
vector::store(self.context, value, memref, &[indices], self.location).into()
} else {
Expand All @@ -199,14 +211,18 @@ impl<'a> Visitor<'a> {
return;
}
let operation = if index_assign.value.ty.is_vectorized() {
let indices = self.get_index(index_assign.index, index_assign.value.ty);
let indices = self.get_index(
index_assign.index,
index_assign.value.ty,
out.ty.is_vectorized(),
);
vector::store(self.context, value, memref, &[indices], self.location)
} else {
let vector_type = Type::vector(
&[out.line_size() as u64],
index_assign.value.storage_type().to_type(self.context),
);
let indices = self.get_index(index_assign.index, out.ty);
let indices = self.get_index(index_assign.index, out.ty, out.ty.is_vectorized());
let splat = self.append_operation_with_result(vector::splat(
self.context,
vector_type,
Expand Down Expand Up @@ -248,7 +264,6 @@ impl<'a> Visitor<'a> {
value,
)
};

self.insert_variable(out, value);
}

Expand Down
9 changes: 7 additions & 2 deletions crates/cubecl-cpu/src/compiler/visitor/variables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,14 +328,19 @@ impl<'a> Visitor<'a> {
}
}

pub fn get_index(&self, variable: Variable, target_item: ir::Type) -> Value<'a, 'a> {
pub fn get_index(
&self,
variable: Variable,
target_item: ir::Type,
list_is_vectorized: bool,
) -> Value<'a, 'a> {
let index = self.get_variable(variable);
let mut index = self.append_operation_with_result(index::casts(
index,
Type::index(self.context),
self.location,
));
if target_item.is_vectorized() {
if target_item.is_vectorized() && list_is_vectorized {
let vectorization = target_item.line_size() as i64;
let shift = vectorization.ilog2() as i64;
let constant = self.append_operation_with_result(arith::constant(
Expand Down
2 changes: 2 additions & 0 deletions crates/cubecl-cpu/src/frontend/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
mod unaligned_line;
pub use unaligned_line::*;
119 changes: 119 additions & 0 deletions crates/cubecl-cpu/src/frontend/unaligned_line.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
use cubecl_core::intrinsic;
use cubecl_core::ir::{IndexAssignOperator, IndexOperator, Instruction, Operator};
use cubecl_core::{self as cubecl, prelude::*};

/// An extension trait for expanding the cubecl frontend with the ability to
/// request unaligned line reads and writes
///
/// Typically in cubecl, a buffer is declared as having a certain line size
/// at kernel compilation time. The buffer can then be indexed to produce
/// lines that are aligned to the line_size.
///
/// This trait allows the user to request a line_read from a buffer where the
/// start of the read is not aligned to the line_read requested.
///
/// As an example, imagine a buffer of scalar length 4. With line_size = 1,
/// this could be illustrated like so
/// [1, 2, 3, 4]
///
/// Imagine the same buffer, now with line_size = 2.
/// [[1, 2], [3, 4]]
///
/// Lines can now be accessed from this buffer, but only those that that are aligned
/// with the line_size. I.e. we can get the lines [1, 2] or [3, 4], but not [2, 3]
///
/// This trait allows you to treat the buffer as having no line_size = 1, yet asking
/// for a line of some kernel-compile-time known length at some offset in the buffer.
/// I.e. if for the buffer `buf = [1, 2, 3, 4]`, `buf.unaligned_line_read(1, 2)`
/// will produce the line `[2, 3]`.
#[cube]
pub trait UnalignedLine<E: CubePrimitive>: CubeType + Sized {
/// Perform an unchecked read of a line of the given length at the given index
///
/// # Safety
/// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index..index+line_size is
/// always in bounds
fn unaligned_line_read(&self, index: u32, #[comptime] line_size: u32) -> Line<E>;

/// Perform an unchecked write of a line of the given length at the given index
///
/// # Safety
/// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index..index+line_size is
/// always in bounds
fn unaligned_line_write(&mut self, index: u32, value: Line<E>);
}

macro_rules! impl_unaligned_line {
($type:ident) => {
paste::paste! {
type [<$type Expand>]<E> = ExpandElementTyped<$type<E>>;
}
#[cube]
impl<E: CubePrimitive> UnalignedLine<E> for $type<E> {
fn unaligned_line_read(&self, index: u32, #[comptime] line_size: u32) -> Line<E> {
unaligned_line_read::<$type<E>, E>(self, index, line_size)
}

fn unaligned_line_write(&mut self, index: u32, value: Line<E>) {
unaligned_line_write::<$type<E>, E>(self, index, value)
}
}
};
}

impl_unaligned_line!(Array);
impl_unaligned_line!(Tensor);
impl_unaligned_line!(SharedMemory);

// TODO: Maybe impl unaligned IO on slices?
// The last dimension will have to be contiguous for this to make sense,
// as the unaligned IO isn't gather / scatter from arbitrary memory locations
// and still needs the loaded elements to be contiguous

#[cube]
#[allow(unused_variables)]
fn unaligned_line_read<T: CubeType<ExpandType = ExpandElementTyped<T>>, E: CubePrimitive>(
this: &T,
index: u32,
#[comptime] line_size: u32,
) -> Line<E> {
intrinsic!(|scope| {
if !matches!(this.expand.ty, cubecl::ir::Type::Scalar(_)) {
todo!("Unaligned reads are only allowed on scalar arrays for now");
}
let out = scope.create_local(this.expand.ty.line(line_size));
scope.register(Instruction::new(
Operator::UncheckedIndex(IndexOperator {
list: *this.expand,
index: index.expand.consume(),
line_size: 0,
unroll_factor: 1,
}),
*out,
));
out.into()
})
}

#[cube]
#[allow(unused_variables)]
fn unaligned_line_write<T: CubeType<ExpandType = ExpandElementTyped<T>>, E: CubePrimitive>(
this: &mut T,
index: u32,
value: Line<E>,
) {
intrinsic!(|scope| {
if !matches!(this.expand.ty, cubecl::ir::Type::Scalar(_)) {
todo!("Unaligned reads are only allowed on scalar arrays for now");
}
scope.register(Instruction::new(
Operator::UncheckedIndexAssign(IndexAssignOperator {
index: index.expand.consume(),
value: value.expand.consume(),
line_size: 0,
unroll_factor: 1,
}),
*this.expand,
));
})
}
1 change: 1 addition & 0 deletions crates/cubecl-cpu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ mod tests {
pub mod compiler;
pub mod compute;
pub mod device;
pub mod frontend;
pub mod runtime;

pub use device::CpuDevice;
Expand Down
7 changes: 5 additions & 2 deletions crates/cubecl-cpu/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use cubecl_core::{
server::ServerUtilities,
};
use cubecl_runtime::{
DeviceProperties,
DeviceProperties, Features,
logging::ServerLogger,
memory_management::{
HardwareProperties, MemoryDeviceProperties, MemoryManagement, MemoryManagementOptions,
Expand Down Expand Up @@ -82,7 +82,10 @@ impl DeviceState for CpuServer {
);

let mut device_props = DeviceProperties::new(
Default::default(),
Features {
unaligned_io: true,
..Default::default()
},
mem_properties,
topology,
TimingMethod::Device,
Expand Down
3 changes: 3 additions & 0 deletions crates/cubecl-runtime/src/features.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ pub struct Features {
pub ldmatrix: BTreeSet<StorageType>,
/// Types supported by stmatrix, if any
pub stmatrix: BTreeSet<StorageType>,
/// Whether Lines can be read from / stored to addresses not aligned
/// with the line_size
pub unaligned_io: bool,
}

/// Operations allowed for this type. CMMA is defined separately.
Expand Down
Loading