Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/cubecl-core/src/frontend/element/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ all_tuples!(launch_tuple, 2, 12, T, t);
/// Expand type associated with a type.
#[derive(new)]
pub struct ExpandElementTyped<T: CubeType> {
pub(crate) expand: ExpandElement,
pub expand: ExpandElement,
pub(crate) _type: PhantomData<T>,
}

Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-cpu/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ log = { workspace = true }
serde = { workspace = true }
sysinfo = { workspace = true }
tracel-llvm = { workspace = true }
paste = { workspace = true }

[dev-dependencies]
cubecl-core = { path = "../cubecl-core", version = "0.9.0-pre.1", features = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ impl<'a> Visitor<'a> {
self.location,
));

let dim = self.get_index(dim, dim.ty);
let dim = self.get_index(dim, dim.ty, true);
let offset = self.append_operation_with_result(arith::addi(first_rank, dim, self.location));
let result = self.append_operation_with_result(memref::load(
metadata_memref,
Expand Down
31 changes: 23 additions & 8 deletions crates/cubecl-cpu/src/compiler/visitor/operation/operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,14 @@ impl<'a> Visitor<'a> {
}
Operator::CopyMemory(copy_memory) => {
let memref = self.get_memory(copy_memory.input);
let in_index = self.get_index(copy_memory.in_index, copy_memory.input.ty);
let in_index = self.get_index(
copy_memory.in_index,
copy_memory.input.ty,
copy_memory.input.ty.is_vectorized(),
);
let out_memref = self.get_memory(out);
let out_index = self.get_index(copy_memory.out_index, out.ty);
let out_index =
self.get_index(copy_memory.out_index, out.ty, out.ty.is_vectorized());
if out.ty.is_vectorized() {
let result = out.ty.to_type(self.context);
let value = self.append_operation_with_result(vector::load(
Expand Down Expand Up @@ -146,8 +151,7 @@ impl<'a> Visitor<'a> {

fn visit_index(&mut self, index: &IndexOperator, out: Variable) -> Value<'a, 'a> {
assert!(index.line_size == 0);
let mut index_value = self.get_index(index.index, out.ty);
let vector_type = index.list.ty.to_type(self.context);
let mut index_value = self.get_index(index.index, out.ty, index.list.ty.is_vectorized());
if !self.is_memory(index.list) {
let to_extract = self.get_variable(index.list);
// Item of size 1
Expand All @@ -167,6 +171,10 @@ impl<'a> Visitor<'a> {
llvm::extractelement(self.context, res, to_extract, index_value, self.location);
self.append_operation_with_result(vector_extract)
} else if out.ty.is_vectorized() {
let vector_type = Type::vector(
&[out.line_size() as u64],
index.list.storage_type().to_type(self.context),
);
let memref = self.get_memory(index.list);
self.append_operation_with_result(vector::load(
self.context,
Expand All @@ -189,7 +197,11 @@ impl<'a> Visitor<'a> {
out.kind,
VariableKind::LocalMut { .. } | VariableKind::LocalConst { .. }
) {
let indices = self.get_index(index_assign.index, index_assign.value.ty);
let indices = self.get_index(
index_assign.index,
index_assign.value.ty,
out.ty.is_vectorized(),
);
let operation = if index_assign.value.ty.is_vectorized() {
vector::store(self.context, value, memref, &[indices], self.location).into()
} else {
Expand All @@ -199,14 +211,18 @@ impl<'a> Visitor<'a> {
return;
}
let operation = if index_assign.value.ty.is_vectorized() {
let indices = self.get_index(index_assign.index, index_assign.value.ty);
let indices = self.get_index(
index_assign.index,
index_assign.value.ty,
out.ty.is_vectorized(),
);
vector::store(self.context, value, memref, &[indices], self.location)
} else {
let vector_type = Type::vector(
&[out.line_size() as u64],
index_assign.value.storage_type().to_type(self.context),
);
let indices = self.get_index(index_assign.index, out.ty);
let indices = self.get_index(index_assign.index, out.ty, out.ty.is_vectorized());
let splat = self.append_operation_with_result(vector::splat(
self.context,
vector_type,
Expand Down Expand Up @@ -248,7 +264,6 @@ impl<'a> Visitor<'a> {
value,
)
};

self.insert_variable(out, value);
}

Expand Down
9 changes: 7 additions & 2 deletions crates/cubecl-cpu/src/compiler/visitor/variables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,14 +328,19 @@ impl<'a> Visitor<'a> {
}
}

pub fn get_index(&self, variable: Variable, target_item: ir::Type) -> Value<'a, 'a> {
pub fn get_index(
&self,
variable: Variable,
target_item: ir::Type,
list_is_vectorized: bool,
) -> Value<'a, 'a> {
let index = self.get_variable(variable);
let mut index = self.append_operation_with_result(index::casts(
index,
Type::index(self.context),
self.location,
));
if target_item.is_vectorized() {
if target_item.is_vectorized() && list_is_vectorized {
let vectorization = target_item.line_size() as i64;
let shift = vectorization.ilog2() as i64;
let constant = self.append_operation_with_result(arith::constant(
Expand Down
2 changes: 2 additions & 0 deletions crates/cubecl-cpu/src/frontend/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
mod unaligned_line;
pub use unaligned_line::*;
83 changes: 83 additions & 0 deletions crates/cubecl-cpu/src/frontend/unaligned_line.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
//! Frontend extension of the CPU runtime for reading and writing `Line<E>` values on scalar containers by index.
use cubecl_core::intrinsic;
use cubecl_core::ir::{IndexAssignOperator, IndexOperator, Instruction, Operator};
use cubecl_core::{self as cubecl, prelude::*};

/// Index-addressed reads and writes of whole `Line<E>` values on a container.
///
/// This file implements the trait for `Array<E>`, `Tensor<E>` and
/// `SharedMemory<E>`. Those implementations forward to helpers whose expansion
/// panics (via `todo!`) unless the container's type is `Type::Scalar`.
///
/// NOTE(review): presumably `index` is a scalar element offset that does not
/// have to be a multiple of the line size (hence "unaligned") — confirm
/// against the CPU compiler's index lowering, and consider adding a usage
/// example here since this is a public API.
#[cube]
pub trait UnalignedLine<E: CubePrimitive>:
CubeType<ExpandType = ExpandElementTyped<Self>> + Sized
{
/// Reads a `Line<E>` from `self` at `index`.
///
/// `line_size` is a comptime value selecting the line size of the result.
fn unaligned_line_read(&self, index: u32, #[comptime] line_size: u32) -> Line<E>;

/// Writes the line `value` into `self` at `index`.
fn unaligned_line_write(&mut self, index: u32, value: Line<E>);
}

/// Implements [`UnalignedLine`] for a generic container type `$type<E>`.
///
/// Both trait methods forward to the free functions `unaligned_line_read`
/// and `unaligned_line_write` defined in this file.
macro_rules! impl_unaligned_line {
($type:ident) => {
// Declares a `<Type>Expand<E>` alias for the container's expand type.
// NOTE(review): presumably the `#[cube]` expansion of the impl below refers
// to this alias by name — confirm before removing it.
paste::paste! {
type [<$type Expand>]<E> = ExpandElementTyped<$type<E>>;
}
#[cube]
impl<E: CubePrimitive> UnalignedLine<E> for $type<E> {
fn unaligned_line_read(&self, index: u32, #[comptime] line_size: u32) -> Line<E> {
unaligned_line_read::<$type<E>, E>(self, index, line_size)
}

fn unaligned_line_write(&mut self, index: u32, value: Line<E>) {
unaligned_line_write::<$type<E>, E>(self, index, value)
}
}
};
}

Comment on lines 29 to 63
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since these are now user-facing APIs, I would add more docs here, perhaps with an example.

// Container types that support unaligned line reads and writes.
impl_unaligned_line!(Array);
impl_unaligned_line!(Tensor);
impl_unaligned_line!(SharedMemory);

/// Reads a line of `line_size` elements from `this` at `index`.
///
/// The expansion creates a new local whose type is the container's type
/// turned into a line of `line_size`, registers an `Operator::UncheckedIndex`
/// instruction that writes into that local, and returns it.
///
/// # Panics
///
/// The expansion panics (via `todo!`) when the container's type is not
/// `Type::Scalar`.
#[cube]
// NOTE(review): presumably needed because the parameters are only used inside
// the `intrinsic!` body — confirm.
#[allow(unused_variables)]
fn unaligned_line_read<T: CubeType<ExpandType = ExpandElementTyped<T>>, E: CubePrimitive>(
this: &T,
index: u32,
#[comptime] line_size: u32,
) -> Line<E> {
intrinsic!(|scope| {
// Only containers of scalar (non-line) elements are supported for now.
if !matches!(this.expand.ty, cubecl::ir::Type::Scalar(_)) {
todo!("Unaligned reads are only allowed on scalar arrays for now");
}
// Output local: the container's scalar type widened to `line_size`.
let out = scope.create_local(this.expand.ty.line(line_size));
scope.register(Instruction::new(
Operator::UncheckedIndex(IndexOperator {
list: *this.expand,
index: index.expand.consume(),
// NOTE(review): the CPU visitor's `visit_index` asserts that this
// field is 0 — keep the two in sync.
line_size: 0,
unroll_factor: 1,
}),
*out,
));
out.into()
})
}

/// Writes the line `value` into `this` at `index`.
///
/// The expansion registers an `Operator::UncheckedIndexAssign` instruction
/// whose output is the container itself.
///
/// # Panics
///
/// The expansion panics (via `todo!`) when the container's type is not
/// `Type::Scalar`.
#[cube]
// NOTE(review): presumably needed because the parameters are only used inside
// the `intrinsic!` body — confirm.
#[allow(unused_variables)]
fn unaligned_line_write<T: CubeType<ExpandType = ExpandElementTyped<T>>, E: CubePrimitive>(
    this: &mut T,
    index: u32,
    value: Line<E>,
) {
    intrinsic!(|scope| {
        // Only containers of scalar (non-line) elements are supported for now.
        if !matches!(this.expand.ty, cubecl::ir::Type::Scalar(_)) {
            // The message names the write path so a failure here is not
            // mistaken for one coming from `unaligned_line_read`.
            todo!("Unaligned writes are only allowed on scalar arrays for now");
        }
        scope.register(Instruction::new(
            Operator::UncheckedIndexAssign(IndexAssignOperator {
                index: index.expand.consume(),
                value: value.expand.consume(),
                line_size: 0,
                unroll_factor: 1,
            }),
            *this.expand,
        ));
    })
}
1 change: 1 addition & 0 deletions crates/cubecl-cpu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ mod tests {
pub mod compiler;
pub mod compute;
pub mod device;
pub mod frontend;
pub mod runtime;

pub use device::CpuDevice;
Expand Down
Loading