diff --git a/crates/cubecl-attention/src/base.rs b/crates/cubecl-attention/src/base.rs index 6dbf67a49..bf47c05bc 100644 --- a/crates/cubecl-attention/src/base.rs +++ b/crates/cubecl-attention/src/base.rs @@ -28,6 +28,7 @@ pub fn launch( query: TensorHandle>, key: TensorHandle>, value: TensorHandle>, + mask: Option>>, out: TensorHandle>, ) -> Result<(), AttentionSetupError> { launch_ref::( @@ -36,6 +37,7 @@ pub fn launch( &query.as_ref(), &key.as_ref(), &value.as_ref(), + &mask.as_ref().map(|m| m.as_ref()), &out.as_ref(), ) } @@ -47,10 +49,11 @@ pub fn launch_ref( query: &TensorHandleRef, key: &TensorHandleRef, value: &TensorHandleRef, + mask: &Option>, out: &TensorHandleRef, ) -> Result<(), AttentionSetupError> { match strategy { - Strategy::Tmp => launch_tmp::(client, query, key, value, out), + Strategy::Tmp => launch_tmp::(client, query, key, value, mask, out), } } @@ -59,6 +62,7 @@ pub fn launch_tmp( query: &TensorHandleRef, key: &TensorHandleRef, value: &TensorHandleRef, + mask: &Option>, out: &TensorHandleRef, ) -> Result<(), AttentionSetupError> { let line_sizes = AvailableLineSizes::from_elem_types::( @@ -81,7 +85,8 @@ pub fn launch_tmp( num_heads: query.shape[2], head_dim: query.shape[3], val_dim: value.shape[3], - masked: false, + masked: mask.is_some(), + causal: false, }; let tile_size = AttentionTileSize { @@ -123,6 +128,9 @@ pub fn launch_tmp( query.as_tensor_arg(line_sizes.query), key.as_tensor_arg(line_sizes.key), value.as_tensor_arg(line_sizes.value), + mask.as_ref() + .map(|it| it.as_tensor_arg(line_sizes.out)) + .into(), ), out.as_tensor_arg(line_sizes.out), cube_count_plan.as_args(), diff --git a/crates/cubecl-attention/src/components/args.rs b/crates/cubecl-attention/src/components/args.rs index 822018608..982ef2792 100644 --- a/crates/cubecl-attention/src/components/args.rs +++ b/crates/cubecl-attention/src/components/args.rs @@ -1,7 +1,7 @@ use cubecl::prelude::*; use cubecl_core::{self as cubecl}; use cubecl_std::{ - CubeOption, CubeOptionExpand, + CubeOption, CubeOptionArgs, CubeOptionExpand, tensor::r#virtual::{VirtualTensorOperations, VirtualTensorOperationsExpand}, }; @@ -16,7 +16,7 @@ pub trait ConcreteInputsFactory: LaunchArg { query: &'a TensorHandleRef<'a, R>, key: &'a TensorHandleRef<'a, R>, value: &'a TensorHandleRef<'a, R>, - // mask: &'a TensorHandleRef<'a, R>, + mask: &'a Option>, selection: &AttentionSelection, problem: &AttentionProblem, line_sizes: &AttentionLineSizes, @@ -38,206 +38,278 @@ pub trait ConcreteOutputFactory: LaunchArg { /// Arguments for the attention algorithm. pub trait AttentionArgs: Send + Sync + 'static + Clone { /// Type used for the input. - type Input: LaunchArg + CubeType; + type Input: LaunchArg + CubeType; /// Type used for the output. type Output: LaunchArg + CubeType; /// Inner state that is used to create [tensor inputs](TensorInput) and /// [tensor outputs](TensorOutput) . - type State: CubeType; + type State: CubeType; /// Init the state. - fn init_state( - input: &Self::Input, + fn init_state( + input: &Self::Input, output: &mut Self::Output, - ) -> Self::State; + ) -> Self::State; + + /// Whether the mask argument is present. Returns `CubeOption` to allow matching at + /// comptime + fn has_mask( + state: &Self::State, + ) -> CubeOption<()>; /// Read the line of the query tensor using the state at the given coordinate. - fn read_query( - state: &Self::State, + fn read_query( + state: &Self::State, coordinate: u32, ) -> Line; /// Read the line of the key tensor using the state at the given coordinate. 
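The changes to base.rs above thread an optional mask handle through launch, launch_ref, and launch_tmp, and set problem.masked from its presence. A minimal caller-side sketch of that plumbing follows; the generic parameters and the mask_handle value are assumptions, since the concrete type arguments are not visible in this rendering of the diff:

    // Hypothetical usage; `client`, `query`, `key`, `value`, and `out` are assumed
    // to be already-allocated tensor handles of the appropriate element types.
    let mask = if apply_padding_mask {
        Some(mask_handle) // assumed: a (batch, heads, seq_q, seq_kv) tensor of mask values
    } else {
        None
    };

    // `launch` forwards `&mask.as_ref().map(|m| m.as_ref())` to `launch_ref`,
    // which in turn sets `masked: mask.is_some()` on the `AttentionProblem`.
    launch::<R, AS>(&Strategy::Tmp, &client, query, key, value, mask, out)
        .expect("attention setup failed");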
- fn read_key( - state: &Self::State, + fn read_key( + state: &Self::State, coordinate: u32, ) -> Line; /// Read the line of the value tensor using the state at the given coordinate. - fn read_value( - state: &Self::State, + fn read_value( + state: &Self::State, coordinate: u32, ) -> Line; + /// Read the line of the mask tensor using the state at the given coordinate. + fn read_mask( + state: &Self::State, + coordinate: u32, + ) -> Line; /// Read the line of the query tensor using the state at the given coordinate. - fn read_window_query( - state: &Self::State, + fn read_window_query( + state: &Self::State, start: u32, end: u32, ) -> Slice>; - /// Read the line of the key tensor using the state at the given coordinate. - fn read_window_key( - state: &Self::State, + fn read_window_key( + state: &Self::State, start: u32, end: u32, ) -> Slice>; - /// Read the line of the value tensor using the state at the given coordinate. - fn read_window_value( - state: &Self::State, + fn read_window_value( + state: &Self::State, start: u32, end: u32, ) -> Slice>; + /// Read the line of the mask tensor using the state at the given coordinate. + fn read_window_mask( + state: &Self::State, + start: u32, + end: u32, + ) -> Slice>; /// Reinterpret query as tensor map - fn as_tensor_map_query( - state: &Self::State, + fn as_tensor_map_query( + state: &Self::State, ) -> CubeOption>; - /// Reinterpret key as tensor map - fn as_tensor_map_key( - state: &Self::State, + fn as_tensor_map_key( + state: &Self::State, ) -> CubeOption>; - /// Reinterpret value as tensor map - fn as_tensor_map_value( - state: &Self::State, + fn as_tensor_map_value( + state: &Self::State, ) -> CubeOption>; + /// Reinterpret mask as tensor map + fn as_tensor_map_mask( + state: &Self::State, + ) -> CubeOption>; /// Write the line to the output at the given coordinate using the state. - fn write_out( - state: &mut Self::State, + fn write_out( + state: &mut Self::State, coordinate: u32, value: Line, ); /// Get the rank of the query tensor using the state. - fn rank_query(state: &Self::State) -> u32; + fn rank_query( + state: &Self::State, + ) -> u32; /// Get the rank of the key tensor using the state. - fn rank_key(state: &Self::State) -> u32; + fn rank_key( + state: &Self::State, + ) -> u32; /// Get the rank of the value tensor using the state. - fn rank_value(state: &Self::State) -> u32; + fn rank_value( + state: &Self::State, + ) -> u32; + /// Get the rank of the mask tensor using the state. + fn rank_mask( + state: &Self::State, + ) -> u32; /// Get the rank of the out tensor using the state. - fn rank_out(state: &Self::State) -> u32; + fn rank_out( + state: &Self::State, + ) -> u32; /// Get the length of the query tensor using the state. - fn len_query(state: &Self::State) -> u32; + fn len_query( + state: &Self::State, + ) -> u32; /// Get the length of the key tensor using the state. - fn len_key(state: &Self::State) -> u32; + fn len_key( + state: &Self::State, + ) -> u32; /// Get the length of the value tensor using the state. - fn len_value(state: &Self::State) -> u32; + fn len_value( + state: &Self::State, + ) -> u32; + /// Get the length of the mask tensor using the state. + fn len_mask( + state: &Self::State, + ) -> u32; /// Get the length of the out tensor using the state. - fn len_out(state: &Self::State) -> u32; + fn len_out( + state: &Self::State, + ) -> u32; /// Get the buffer length of the query tensor using the state. 
- fn buffer_len_query( - state: &Self::State, + fn buffer_len_query( + state: &Self::State, ) -> u32; /// Get the buffer length of the key tensor using the state. - fn buffer_len_key( - state: &Self::State, + fn buffer_len_key( + state: &Self::State, ) -> u32; /// Get the buffer length of the value tensor using the state. - fn buffer_len_value( - state: &Self::State, + fn buffer_len_value( + state: &Self::State, + ) -> u32; + /// Get the buffer length of the mask tensor using the state. + fn buffer_len_mask( + state: &Self::State, ) -> u32; /// Get the buffer length of the out tensor using the state. - fn buffer_len_out( - state: &Self::State, + fn buffer_len_out( + state: &Self::State, ) -> u32; /// Get the shape of the query tensor using the state. - fn shape_query( - state: &Self::State, + fn shape_query( + state: &Self::State, axis: u32, ) -> u32; /// Get the shape of the key tensor using the state. - fn shape_key( - state: &Self::State, + fn shape_key( + state: &Self::State, axis: u32, ) -> u32; /// Get the shape of the value tensor using the state. - fn shape_value( - state: &Self::State, + fn shape_value( + state: &Self::State, + axis: u32, + ) -> u32; + /// Get the shape of the mask tensor using the state. + fn shape_mask( + state: &Self::State, axis: u32, ) -> u32; /// Get the shape of the out tensor using the state. - fn shape_out( - state: &Self::State, + fn shape_out( + state: &Self::State, axis: u32, ) -> u32; /// Get the stride of the query tensor using the state. - fn stride_query( - state: &Self::State, + fn stride_query( + state: &Self::State, axis: u32, ) -> u32; /// Get the stride of the key tensor using the state. - fn stride_key( - state: &Self::State, + fn stride_key( + state: &Self::State, axis: u32, ) -> u32; /// Get the stride of the value tensor using the state. - fn stride_value( - state: &Self::State, + fn stride_value( + state: &Self::State, + axis: u32, + ) -> u32; + /// Get the stride of the mask tensor using the state. + fn stride_mask( + state: &Self::State, axis: u32, ) -> u32; /// Get the stride of the out tensor using the state. - fn stride_out( - state: &Self::State, + fn stride_out( + state: &Self::State, axis: u32, ) -> u32; - fn line_size_query( - state: &Self::State, + /// Get the line size of the query tensor using the state. + fn line_size_query( + state: &Self::State, + ) -> comptime_type!(u32); + /// Get the line size of the key tensor using the state. + fn line_size_key( + state: &Self::State, ) -> comptime_type!(u32); - fn line_size_key( - state: &Self::State, + /// Get the line size of the value tensor using the state. + fn line_size_value( + state: &Self::State, ) -> comptime_type!(u32); - fn line_size_value( - state: &Self::State, + /// Get the line size of the mask tensor using the state. + fn line_size_mask( + state: &Self::State, ) -> comptime_type!(u32); - fn line_size_out( - state: &Self::State, + /// Get the line size of the out tensor using the state. + fn line_size_out( + state: &Self::State, ) -> comptime_type!(u32); } /// Tensor input representation. /// /// You can use the tensor input as if it was a pointer to the actually tensor. 
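The TensorQuery/TensorKey/TensorValue/TensorMask wrappers defined next all share one pattern: each holds a raw pointer to the AttentionArgs state and forwards every accessor to the corresponding trait method. A stripped-down, non-cube sketch of that pattern, with illustrative types and names only:

    struct State {
        data: Vec<f32>,
    }

    struct MaskView {
        state: *const State,
    }

    impl MaskView {
        fn new(state: &State) -> Self {
            // The reference is stored as a raw pointer; the state is owned by
            // the kernel scope and outlives every view created from it.
            Self { state: state as *const State }
        }

        fn read(&self, i: usize) -> f32 {
            // Safety: the state outlives this view (see `new`).
            unsafe { (*self.state).data[i] }
        }
    }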
-pub struct TensorQuery { - state: *const GA::State, +pub struct TensorQuery { + state: *const GA::State, } -pub struct TensorKey { - state: *const GA::State, +pub struct TensorKey { + state: *const GA::State, } -pub struct TensorValue { - state: *const GA::State, +pub struct TensorValue { + state: *const GA::State, } -impl VirtualTensorOperations - for TensorQuery +pub struct TensorMask { + state: *const GA::State, +} + +impl + VirtualTensorOperations for TensorQuery { } -impl VirtualTensorOperations - for TensorKey +impl + VirtualTensorOperations for TensorKey { } -impl VirtualTensorOperations - for TensorValue +impl + VirtualTensorOperations for TensorValue { } -impl VirtualTensorOperations - for TensorOutput +impl + VirtualTensorOperations for TensorMask { } -impl VirtualTensorOperationsExpand - for TensorOutputExpand +impl + VirtualTensorOperations for TensorOutput +{ +} + +impl + VirtualTensorOperationsExpand for TensorOutputExpand { fn __expand_read_method( &self, @@ -298,12 +370,12 @@ impl VirtualTensorOpe } } -impl Lined - for TensorOutput +impl Lined + for TensorOutput { } -impl LinedExpand - for TensorOutputExpand +impl LinedExpand + for TensorOutputExpand { fn line_size(&self) -> u32 { let mut scope = Scope::root(false); @@ -311,8 +383,8 @@ impl LinedExpand } } -impl VirtualTensorOperationsExpand - for TensorQueryExpand +impl + VirtualTensorOperationsExpand for TensorQueryExpand { fn __expand_read_method( &self, @@ -372,12 +444,12 @@ impl VirtualTensorOpe } } -impl Lined - for TensorQuery +impl Lined + for TensorQuery { } -impl LinedExpand - for TensorQueryExpand +impl LinedExpand + for TensorQueryExpand { fn line_size(&self) -> u32 { let mut scope = Scope::root(false); @@ -385,8 +457,8 @@ impl LinedExpand } } -impl VirtualTensorOperationsExpand - for TensorKeyExpand +impl + VirtualTensorOperationsExpand for TensorKeyExpand { fn __expand_read_method( &self, @@ -446,12 +518,12 @@ impl VirtualTensorOpe } } -impl Lined - for TensorKey +impl Lined + for TensorKey { } -impl LinedExpand - for TensorKeyExpand +impl LinedExpand + for TensorKeyExpand { fn line_size(&self) -> u32 { let mut scope = Scope::root(false); @@ -459,8 +531,8 @@ impl LinedExpand } } -impl VirtualTensorOperationsExpand - for TensorValueExpand +impl + VirtualTensorOperationsExpand for TensorValueExpand { fn __expand_read_method( &self, @@ -520,12 +592,12 @@ impl VirtualTensorOpe } } -impl Lined - for TensorValue +impl Lined + for TensorValue { } -impl LinedExpand - for TensorValueExpand +impl LinedExpand + for TensorValueExpand { fn line_size(&self) -> u32 { let mut scope = Scope::root(false); @@ -533,40 +605,124 @@ impl LinedExpand } } +impl + VirtualTensorOperationsExpand for TensorMaskExpand +{ + fn __expand_read_method( + &self, + scope: &mut Scope, + index: ExpandElementTyped, + ) -> ExpandElementTyped> { + TensorMaskExpand::__expand_read_method(self.clone(), scope, index) + } + fn __expand_read_window_method( + &self, + context: &mut Scope, + start: ExpandElementTyped, + end: ExpandElementTyped, + ) -> SliceExpand, ReadOnly> { + TensorMaskExpand::__expand_read_window_method(self.clone(), context, start, end) + } + + fn __expand_write_method( + &self, + _scope: &mut Scope, + _index: ExpandElementTyped, + _value: ExpandElementTyped>, + ) { + panic!("Can't write to input tensor"); + } + + fn __expand_shape_method( + &self, + scope: &mut Scope, + axis: ExpandElementTyped, + ) -> ExpandElementTyped { + TensorMaskExpand::__expand_shape_method(self.clone(), scope, axis) + } + + fn __expand_stride_method( + &self, + 
scope: &mut Scope, + axis: ExpandElementTyped, + ) -> ExpandElementTyped { + TensorMaskExpand::__expand_stride_method(self.clone(), scope, axis) + } + + fn __expand_rank_method(&self, scope: &mut Scope) -> ExpandElementTyped { + TensorMaskExpand::__expand_rank_method(self.clone(), scope) + } + + fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped { + TensorMaskExpand::__expand_len_method(self.clone(), scope) + } + + fn __expand_buffer_len_method(&self, scope: &mut Scope) -> ExpandElementTyped { + TensorMaskExpand::__expand_buffer_len_method(self.clone(), scope) + } + + fn __expand_as_tensor_map_method(&self, scope: &mut Scope) -> CubeOptionExpand> { + TensorMaskExpand::__expand_as_tensor_map_method(self.clone(), scope) + } +} + +impl Lined + for TensorMask +{ +} +impl LinedExpand + for TensorMaskExpand +{ + fn line_size(&self) -> u32 { + let mut scope = Scope::root(false); + TensorMaskExpand::__expand_line_size_method(self.clone(), &mut scope) + } +} + /// Tensor output representation. /// /// You can use the tensor output as if it was a pointer to the actually tensor. /// /// # Warning +/// # Warning /// /// There is no mutability guarantee. -pub struct TensorOutput { - state: *mut GA::State, +pub struct TensorOutput { + state: *mut GA::State, } /// Expand type for [tensor input](TensorInput). -pub struct TensorQueryExpand { - state: as CubeType>::ExpandType, +pub struct TensorQueryExpand +{ + state: as CubeType>::ExpandType, } -pub struct TensorKeyExpand { - state: as CubeType>::ExpandType, +pub struct TensorKeyExpand { + state: as CubeType>::ExpandType, } -pub struct TensorValueExpand { - state: as CubeType>::ExpandType, +pub struct TensorValueExpand +{ + state: as CubeType>::ExpandType, +} + +pub struct TensorMaskExpand { + state: as CubeType>::ExpandType, } /// Expand type for [tensor output](TensorOutput). -pub struct TensorOutputExpand { - state: as CubeType>::ExpandType, +pub struct TensorOutputExpand +{ + state: as CubeType>::ExpandType, } #[cube] -impl TensorQuery { +impl + TensorQuery +{ /// Create a [tensor input](TensorInput) from the state and the [ident](TensorInputIdent). - pub fn new(state: &MA::State) -> TensorQuery { - TensorQuery:: { state } + pub fn new(state: &MA::State) -> TensorQuery { + TensorQuery:: { state } } //// Read the tensor at the given coordinate. @@ -617,10 +773,12 @@ impl TensorQuery TensorKey { +impl + TensorKey +{ /// Create a [tensor input](TensorInput) from the state and the [ident](TensorInputIdent). - pub fn new(state: &MA::State) -> TensorKey { - TensorKey:: { state } + pub fn new(state: &MA::State) -> TensorKey { + TensorKey:: { state } } //// Read the tensor at the given coordinate. @@ -671,10 +829,12 @@ impl TensorKey TensorValue { +impl + TensorValue +{ /// Create a [tensor input](TensorInput) from the state and the [ident](TensorInputIdent). - pub fn new(state: &MA::State) -> TensorValue { - TensorValue:: { state } + pub fn new(state: &MA::State) -> TensorValue { + TensorValue:: { state } } //// Read the tensor at the given coordinate. @@ -725,10 +885,68 @@ impl TensorValue TensorOutput { +impl + TensorMask +{ + /// Create a [tensor input](TensorInput) from the state and the [ident](TensorInputIdent). + pub fn new(state: &MA::State) -> TensorMask { + TensorMask:: { state } + } + + //// Read the tensor at the given coordinate. + pub fn read_window(&self, start: u32, end: u32) -> Slice> { + unsafe { MA::read_window_mask(&(*self.state), start, end) } + } + + /// Read the tensor at the given coordinate. 
+ pub fn read(&self, coordinate: u32) -> Line { + unsafe { MA::read_mask(&(*self.state), coordinate) } + } + + /// Get the shape of the tensor at the given axis. + pub fn shape(&self, axis: u32) -> u32 { + unsafe { MA::shape_mask(&(*self.state), axis) } + } + + /// Get the stride of the tensor at the given axis. + pub fn stride(&self, axis: u32) -> u32 { + unsafe { MA::stride_mask(&(*self.state), axis) } + } + + /// Get the rank of the tensor. + pub fn rank(&self) -> u32 { + unsafe { MA::rank_mask(&(*self.state)) } + } + + /// Get the length of the tensor. + #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> u32 { + unsafe { MA::len_mask(&(*self.state)) } + } + + /// Get the buffer length of the tensor. + pub fn buffer_len(&self) -> u32 { + unsafe { MA::buffer_len_mask(&(*self.state)) } + } + + /// Get the buffer length of the tensor. + pub fn as_tensor_map(&self) -> CubeOption> { + unsafe { MA::as_tensor_map_mask(&(*self.state)) } + } + + /// Get the line size of the tensor. + pub fn line_size(&self) -> comptime_type!(u32) { + unsafe { MA::line_size_mask(&(*self.state)) } + } +} + +#[cube] +impl + TensorOutput +{ /// Create a [tensor output](TensorOutput) from the state. - pub fn new(state: &mut GA::State) -> TensorOutput { - TensorOutput:: { state } + pub fn new(state: &mut GA::State) -> TensorOutput { + TensorOutput:: { state } } /// Write the value to tensor at the given coordinate. @@ -776,18 +994,19 @@ pub struct TensorArgs; #[derive(CubeLaunch, CubeType)] /// Input representation for [TensorArgs] implementing [AttentionArgs]. -pub struct TensorInputs { +pub struct TensorInputs { pub query: Tensor>, pub key: Tensor>, pub value: Tensor>, - // pub mask: CubeOption>>, + pub mask: CubeOption>>, } -impl ConcreteInputsFactory for TensorInputs { +impl ConcreteInputsFactory for TensorInputs { fn create<'a, R: Runtime>( query: &'a TensorHandleRef<'a, R>, key: &'a TensorHandleRef<'a, R>, value: &'a TensorHandleRef<'a, R>, + mask: &'a Option>, _selection: &AttentionSelection, _problem: &AttentionProblem, line_sizes: &AttentionLineSizes, @@ -796,7 +1015,10 @@ impl ConcreteInputsFactory for TensorInputs CubeOptionArgs::Some(mask.as_tensor_arg(line_sizes.mask)), + None => CubeOptionArgs::None, + }, ) } } @@ -813,234 +1035,328 @@ impl ConcreteOutputFactory for Tensor> { } #[derive(CubeType)] -pub struct AttentionState { +pub struct AttentionState { pub query: *const Tensor>, pub key: *const Tensor>, pub value: *const Tensor>, + pub mask: CubeOption<*const Tensor>>, pub output: *mut Tensor>, } #[cube] impl AttentionArgs for TensorArgs { - type Input = TensorInputs; + type Input = TensorInputs; type Output = Tensor>; - type State = AttentionState; + type State = AttentionState; - fn init_state( - input: &Self::Input, + fn init_state( + input: &Self::Input, output: &mut Self::Output, - ) -> Self::State { - AttentionState:: { + ) -> Self::State { + let mask = match &input.mask { + CubeOption::None => CubeOption::new_None(), + CubeOption::Some(mask) => { + let ptr: *const Tensor> = mask; + CubeOption::new_Some(ptr) + } + }; + + AttentionState:: { query: &input.query, key: &input.key, value: &input.value, + mask, output, } } - fn read_query( - state: &Self::State, + fn has_mask( + state: &Self::State, + ) -> CubeOption<()> { + match state.mask { + CubeOption::None => CubeOption::new_None(), + CubeOption::Some(_) => CubeOption::new_Some(()), + } + } + + fn read_query( + state: &Self::State, coordinate: u32, ) -> Line { unsafe { (*state.query)[coordinate] } } - fn read_key( - state: 
&Self::State, + fn read_key( + state: &Self::State, coordinate: u32, ) -> Line { unsafe { (*state.key)[coordinate] } } - fn read_value( - state: &Self::State, + fn read_value( + state: &Self::State, coordinate: u32, ) -> Line { unsafe { (*state.value)[coordinate] } } - fn read_window_query( - state: &Self::State, + fn read_mask( + state: &Self::State, + coordinate: u32, + ) -> Line { + unsafe { (*state.mask.unwrap())[coordinate] } + } + + fn read_window_query( + state: &Self::State, start: u32, end: u32, ) -> Slice> { unsafe { (*state.query).slice(start, end) } } - fn read_window_key( - state: &Self::State, + fn read_window_key( + state: &Self::State, start: u32, end: u32, ) -> Slice> { unsafe { (*state.key).slice(start, end) } } - fn read_window_value( - state: &Self::State, + fn read_window_value( + state: &Self::State, start: u32, end: u32, ) -> Slice> { unsafe { (*state.value).slice(start, end) } } - fn as_tensor_map_query( - _state: &Self::State, + fn read_window_mask( + state: &Self::State, + start: u32, + end: u32, + ) -> Slice> { + unsafe { (*state.mask.unwrap()).slice(start, end) } + } + + fn as_tensor_map_query( + _state: &Self::State, ) -> CubeOption> { CubeOption::new_None() } - fn as_tensor_map_key( - _state: &Self::State, + fn as_tensor_map_key( + _state: &Self::State, ) -> CubeOption> { CubeOption::new_None() } - fn as_tensor_map_value( - _state: &Self::State, + fn as_tensor_map_value( + _state: &Self::State, ) -> CubeOption> { CubeOption::new_None() } - fn shape_query( - state: &Self::State, + fn as_tensor_map_mask( + _state: &Self::State, + ) -> CubeOption> { + CubeOption::new_None() + } + + fn shape_query( + state: &Self::State, dim: u32, ) -> u32 { unsafe { (*state.query).shape(dim) } } - fn shape_key( - state: &Self::State, + fn shape_key( + state: &Self::State, dim: u32, ) -> u32 { unsafe { (*state.key).shape(dim) } } - fn shape_value( - state: &Self::State, + fn shape_value( + state: &Self::State, dim: u32, ) -> u32 { unsafe { (*state.value).shape(dim) } } - fn shape_out( - state: &Self::State, + fn shape_mask( + state: &Self::State, + dim: u32, + ) -> u32 { + unsafe { (*state.mask.unwrap()).shape(dim) } + } + + fn shape_out( + state: &Self::State, dim: u32, ) -> u32 { unsafe { (*state.output).shape(dim) } } - fn stride_query( - state: &Self::State, + fn stride_query( + state: &Self::State, dim: u32, ) -> u32 { unsafe { (*state.query).stride(dim) } } - fn stride_key( - state: &Self::State, + fn stride_key( + state: &Self::State, dim: u32, ) -> u32 { unsafe { (*state.key).stride(dim) } } - fn stride_value( - state: &Self::State, + fn stride_value( + state: &Self::State, dim: u32, ) -> u32 { unsafe { (*state.value).stride(dim) } } - fn stride_out( - state: &Self::State, + fn stride_mask( + state: &Self::State, + dim: u32, + ) -> u32 { + unsafe { (*state.mask.unwrap()).stride(dim) } + } + + fn stride_out( + state: &Self::State, dim: u32, ) -> u32 { unsafe { (*state.output).stride(dim) } } - fn write_out( - state: &mut Self::State, + fn write_out( + state: &mut Self::State, coordinate: u32, value: Line, ) { unsafe { (*state.output)[coordinate] = value } } - fn rank_query(state: &Self::State) -> u32 { + fn rank_query( + state: &Self::State, + ) -> u32 { unsafe { (*state.query).rank() } } - fn rank_key(state: &Self::State) -> u32 { + fn rank_key( + state: &Self::State, + ) -> u32 { unsafe { (*state.key).rank() } } - fn rank_value(state: &Self::State) -> u32 { + fn rank_value( + state: &Self::State, + ) -> u32 { unsafe { (*state.value).rank() } } - fn rank_out(state: 
&Self::State) -> u32 { + fn rank_mask( + state: &Self::State, + ) -> u32 { + unsafe { (*state.mask.unwrap()).rank() } + } + + fn rank_out( + state: &Self::State, + ) -> u32 { unsafe { (*state.output).rank() } } - fn len_query(state: &Self::State) -> u32 { + fn len_query( + state: &Self::State, + ) -> u32 { unsafe { (*state.query).len() } } - fn len_key(state: &Self::State) -> u32 { + fn len_key( + state: &Self::State, + ) -> u32 { unsafe { (*state.key).len() } } - fn len_value(state: &Self::State) -> u32 { + fn len_value( + state: &Self::State, + ) -> u32 { unsafe { (*state.value).len() } } - fn len_out(state: &Self::State) -> u32 { + fn len_mask( + state: &Self::State, + ) -> u32 { + unsafe { (*state.mask.unwrap()).len() } + } + + fn len_out( + state: &Self::State, + ) -> u32 { unsafe { (*state.output).len() } } - fn buffer_len_query( - state: &Self::State, + fn buffer_len_query( + state: &Self::State, ) -> u32 { unsafe { (*state.query).buffer_len() } } - fn buffer_len_key( - state: &Self::State, + fn buffer_len_key( + state: &Self::State, ) -> u32 { unsafe { (*state.key).buffer_len() } } - fn buffer_len_value( - state: &Self::State, + fn buffer_len_value( + state: &Self::State, ) -> u32 { unsafe { (*state.value).buffer_len() } } - fn buffer_len_out( - state: &Self::State, + fn buffer_len_mask( + state: &Self::State, + ) -> u32 { + unsafe { (*state.mask.unwrap()).buffer_len() } + } + + fn buffer_len_out( + state: &Self::State, ) -> u32 { unsafe { (*state.output).buffer_len() } } - fn line_size_query( - state: &Self::State, + fn line_size_query( + state: &Self::State, ) -> comptime_type!(u32) { unsafe { (*state.query).line_size() } } - fn line_size_key( - state: &Self::State, + fn line_size_key( + state: &Self::State, ) -> comptime_type!(u32) { unsafe { (*state.key).line_size() } } - fn line_size_value( - state: &Self::State, + fn line_size_value( + state: &Self::State, ) -> comptime_type!(u32) { unsafe { (*state.value).line_size() } } - fn line_size_out( - state: &Self::State, + fn line_size_mask( + state: &Self::State, + ) -> comptime_type!(u32) { + unsafe { (*state.mask.unwrap()).line_size() } + } + + fn line_size_out( + state: &Self::State, ) -> comptime_type!(u32) { unsafe { (*state.output).line_size() } } @@ -1049,14 +1365,14 @@ impl AttentionArgs for TensorArgs { mod __query { use super::*; - impl CubeType - for TensorQuery + impl CubeType + for TensorQuery { - type ExpandType = TensorQueryExpand; + type ExpandType = TensorQueryExpand; } - impl Clone - for TensorQueryExpand + impl Clone + for TensorQueryExpand { fn clone(&self) -> Self { Self { @@ -1065,30 +1381,30 @@ mod __query { } } - impl IntoMut - for TensorQueryExpand + impl IntoMut + for TensorQueryExpand { fn into_mut(mut self, scope: &mut Scope) -> Self { self.state = self.state.into_mut(scope); self } } - impl CubeDebug - for TensorQueryExpand + impl CubeDebug + for TensorQueryExpand { fn set_debug_name(&self, scope: &mut Scope, name: &'static str) { self.state.set_debug_name(scope, name); } } - impl Clone - for TensorQuery + impl Clone + for TensorQuery { fn clone(&self) -> Self { *self } } - impl Copy - for TensorQuery + impl Copy + for TensorQuery { } } @@ -1096,14 +1412,14 @@ mod __query { mod __key { use super::*; - impl CubeType - for TensorKey + impl CubeType + for TensorKey { - type ExpandType = TensorKeyExpand; + type ExpandType = TensorKeyExpand; } - impl Clone - for TensorKeyExpand + impl Clone + for TensorKeyExpand { fn clone(&self) -> Self { Self { @@ -1112,42 +1428,92 @@ mod __key { } } - impl IntoMut - for 
TensorKeyExpand + impl IntoMut + for TensorKeyExpand { fn into_mut(mut self, scope: &mut Scope) -> Self { self.state = self.state.into_mut(scope); self } } - impl CubeDebug - for TensorKeyExpand + impl CubeDebug + for TensorKeyExpand { fn set_debug_name(&self, scope: &mut Scope, name: &'static str) { self.state.set_debug_name(scope, name); } } - impl Clone - for TensorKey + impl Clone + for TensorKey { fn clone(&self) -> Self { *self } } - impl Copy for TensorKey {} + impl Copy + for TensorKey + { + } } mod __value { use super::*; - impl CubeType - for TensorValue + impl CubeType + for TensorValue + { + type ExpandType = TensorValueExpand; + } + + impl Clone + for TensorValueExpand + { + fn clone(&self) -> Self { + Self { + state: self.state.clone(), + } + } + } + + impl IntoMut + for TensorValueExpand + { + fn into_mut(mut self, scope: &mut Scope) -> Self { + self.state = self.state.into_mut(scope); + self + } + } + impl CubeDebug + for TensorValueExpand + { + fn set_debug_name(&self, scope: &mut Scope, name: &'static str) { + self.state.set_debug_name(scope, name); + } + } + impl Clone + for TensorValue + { + fn clone(&self) -> Self { + *self + } + } + impl Copy + for TensorValue + { + } +} + +mod __mask { + use super::*; + + impl CubeType + for TensorMask { - type ExpandType = TensorValueExpand; + type ExpandType = TensorMaskExpand; } - impl Clone - for TensorValueExpand + impl Clone + for TensorMaskExpand { fn clone(&self) -> Self { Self { @@ -1156,30 +1522,30 @@ mod __value { } } - impl IntoMut - for TensorValueExpand + impl IntoMut + for TensorMaskExpand { fn into_mut(mut self, scope: &mut Scope) -> Self { self.state = self.state.into_mut(scope); self } } - impl CubeDebug - for TensorValueExpand + impl CubeDebug + for TensorMaskExpand { fn set_debug_name(&self, scope: &mut Scope, name: &'static str) { self.state.set_debug_name(scope, name); } } - impl Clone - for TensorValue + impl Clone + for TensorMask { fn clone(&self) -> Self { *self } } - impl Copy - for TensorValue + impl Copy + for TensorMask { } } @@ -1187,22 +1553,22 @@ mod __value { mod __output { use super::*; - impl CubeType - for TensorOutput + impl CubeType + for TensorOutput { - type ExpandType = TensorOutputExpand; + type ExpandType = TensorOutputExpand; } - impl Clone - for TensorOutput + impl Clone + for TensorOutput { fn clone(&self) -> Self { *self } } - impl Clone - for TensorOutputExpand + impl Clone + for TensorOutputExpand { fn clone(&self) -> Self { Self { @@ -1211,8 +1577,8 @@ mod __output { } } - impl IntoMut - for TensorOutputExpand + impl IntoMut + for TensorOutputExpand { fn into_mut(mut self, scope: &mut Scope) -> Self { self.state = self.state.into_mut(scope); @@ -1220,16 +1586,16 @@ mod __output { } } - impl CubeDebug - for TensorOutputExpand + impl CubeDebug + for TensorOutputExpand { fn set_debug_name(&self, scope: &mut Scope, name: &'static str) { self.state.set_debug_name(scope, name); } } - impl Copy - for TensorOutput + impl Copy + for TensorOutput { } } diff --git a/crates/cubecl-attention/src/components/batch/base.rs b/crates/cubecl-attention/src/components/batch/base.rs index 50e44ec8e..ce12b4d5a 100644 --- a/crates/cubecl-attention/src/components/batch/base.rs +++ b/crates/cubecl-attention/src/components/batch/base.rs @@ -1,6 +1,6 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; -use cubecl_std::tensor::r#virtual::VirtualTensor; +use cubecl_std::{CubeOption, tensor::r#virtual::VirtualTensor}; use crate::components::{ AttentionLineSizes, AttentionPrecision, AttentionProblem, 
AttentionSelection, @@ -61,6 +61,7 @@ pub trait BatchAttention: 'static + Send + Sync { query: VirtualTensor>, key: VirtualTensor>, value: VirtualTensor>, + mask: CubeOption>>, out: VirtualTensor, ReadWrite>, cube_count_args: CubeCountInput, #[comptime] config: Self::Config, diff --git a/crates/cubecl-attention/src/components/batch/dummy/attention.rs b/crates/cubecl-attention/src/components/batch/dummy/attention.rs index 7a4e480ee..24c1678d8 100644 --- a/crates/cubecl-attention/src/components/batch/dummy/attention.rs +++ b/crates/cubecl-attention/src/components/batch/dummy/attention.rs @@ -1,6 +1,6 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; -use cubecl_std::tensor::r#virtual::VirtualTensor; +use cubecl_std::{CubeOption, tensor::r#virtual::VirtualTensor}; use std::marker::PhantomData; use crate::components::{ @@ -26,6 +26,7 @@ impl, AP: AttentionPrecision> BatchAttention query: VirtualTensor>, key: VirtualTensor>, value: VirtualTensor>, + mask: CubeOption>>, out: VirtualTensor, ReadWrite>, _cube_count_args: CubeCountInput, #[comptime] config: Self::Config, @@ -46,6 +47,7 @@ impl, AP: AttentionPrecision> BatchAttention GA::init_query_reader(q_offset, query, global_config), GA::init_key_reader(key, global_config), GA::init_value_reader(value, global_config), + GA::init_mask_reader(q_offset, mask, global_config), GA::init_writer(q_offset, out, global_config), seq_q, seq_kv, diff --git a/crates/cubecl-attention/src/components/batch/dummy/config.rs b/crates/cubecl-attention/src/components/batch/dummy/config.rs index 47741bdec..296e6ca3b 100644 --- a/crates/cubecl-attention/src/components/batch/dummy/config.rs +++ b/crates/cubecl-attention/src/components/batch/dummy/config.rs @@ -10,7 +10,7 @@ use crate::components::{ pub struct DummyBatchConfig { global_config: G, hypercube_config: HypercubeConfig, - seq_k: u32, + seq_kv: u32, } impl BatchAttentionConfig for DummyBatchConfig { @@ -30,11 +30,11 @@ impl BatchAttentionConfig for DummyBatchConfig { } impl DummyBatchConfig { - pub fn new(global_config: G, hypercube_config: HypercubeConfig, seq_k: u32) -> Self { + pub fn new(global_config: G, hypercube_config: HypercubeConfig, seq_kv: u32) -> Self { Self { global_config, hypercube_config, - seq_k, + seq_kv, } } diff --git a/crates/cubecl-attention/src/components/batch/entry_point.rs b/crates/cubecl-attention/src/components/batch/entry_point.rs index 10c1e8e6c..8a7a74d57 100644 --- a/crates/cubecl-attention/src/components/batch/entry_point.rs +++ b/crates/cubecl-attention/src/components/batch/entry_point.rs @@ -1,5 +1,6 @@ use crate::components::args::AttentionArgs; use crate::components::args::TensorKey; +use crate::components::args::TensorMask; use crate::components::args::TensorOutput; use crate::components::args::TensorQuery; use crate::components::args::TensorValue; @@ -9,8 +10,9 @@ use crate::components::batch::base::BatchAttention; use cubecl_core as cubecl; use cubecl_core::prelude::*; use cubecl_std::tensor::r#virtual::VirtualTensor; +use cubecl_std::{CubeOption, CubeOptionExpand}; -type Input = ::Input; +type Input = ::Input; type Output = ::Output; #[cube(launch_unchecked)] @@ -31,27 +33,41 @@ pub(crate) fn attention< OS: Float, BMMF: BatchAttentionFamily, >( - inputs: &Input, + inputs: &Input, output: &mut Output, cube_count_args: CubeCountInput, #[comptime] config: BMMF::Config, ) { let mut state = Args::init_state(inputs, output); - let query = TensorQuery::::new(&state); - let key = TensorKey::::new(&state); - let value = TensorValue::::new(&state); - let mut out = 
TensorOutput::::new(&mut state); + let query = TensorQuery::::new(&state); + let query = VirtualTensor::::new::>(&query); - let query = VirtualTensor::::new::>(&query); - let key = VirtualTensor::::new::>(&key); - let value = VirtualTensor::::new::>(&value); - let out = VirtualTensor::::new::>(&mut out); + let key = TensorKey::::new(&state); + let key = VirtualTensor::::new::>(&key); + + let value = TensorValue::::new(&state); + let value = VirtualTensor::::new::>(&value); + + let has_mask = Args::has_mask(&state); + let mask: CubeOption> = match has_mask { + CubeOption::Some(_) => { + let mask = TensorMask::::new(&state); + let mask = VirtualTensor::::new::>(&mask); + CubeOption::new_Some(mask) + } + CubeOption::None => CubeOption::new_None(), + }; + + let mut out = TensorOutput::::new(&mut state); + let out = + VirtualTensor::::new::>(&mut out); BMMF::Attention::<(QG, QT, KG, KS, VG, VS, KVT, SM, ACC, MSK, OG, OS)>::execute( query, key, value, + mask, out, cube_count_args, config, diff --git a/crates/cubecl-attention/src/components/global/base.rs b/crates/cubecl-attention/src/components/global/base.rs index ae1e92ad8..bdea4b700 100644 --- a/crates/cubecl-attention/src/components/global/base.rs +++ b/crates/cubecl-attention/src/components/global/base.rs @@ -1,7 +1,9 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; + +use crate::components::global::dummy::AttentionReader; use cubecl_matmul::components::{global::memory::GlobalMemoryConfig, stage::StageMemoryConfig}; -use cubecl_std::tensor::r#virtual::VirtualTensor; +use cubecl_std::{CubeOption, tensor::r#virtual::VirtualTensor}; use crate::components::{ AttentionIdent, AttentionLineSizes, AttentionPrecision, AttentionProblem, AttentionSelection, @@ -42,9 +44,11 @@ pub trait GlobalAttention: 'static + Send + Sync { type Writer: CubeType; /// Loads to SMEM transposed - type KeyReader: CubeType; + type KeyReader: AttentionReader, Self::Config>; + /// Loads to SMEM as is + type ValueReader: AttentionReader, Self::Config>; /// Loads to SMEM as is - type ValueReader: CubeType; + type MaskReader: CubeType; /// The configuration type associated with this Attention. 
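In the entry point above, Args::has_mask returns a CubeOption<()> whose variant is known when the kernel is expanded, so the masked and unmasked paths become two separately compiled kernels rather than a runtime branch. A minimal standalone sketch of that comptime-option pattern, not part of this crate's API:

    use cubecl_core as cubecl;
    use cubecl_core::prelude::*;
    use cubecl_std::CubeOption;

    #[cube]
    fn apply_optional_bias<F: Float>(value: F, bias: CubeOption<F>) -> F {
        // This match is resolved during expansion: when the option was built
        // with `new_None`, the addition is simply not emitted.
        match bias {
            CubeOption::Some(b) => value + b,
            CubeOption::None => value,
        }
    }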
type Config: GlobalAttentionConfig; @@ -53,6 +57,7 @@ pub trait GlobalAttention: 'static + Send + Sync { query_reader: QueryReader, key_reader: Self::KeyReader, value_reader: Self::ValueReader, + mask_reader: Self::MaskReader, writer: Self::Writer, seq_q: u32, seq_kv: u32, @@ -75,6 +80,12 @@ pub trait GlobalAttention: 'static + Send + Sync { #[comptime] config: Self::Config, ) -> Self::ValueReader; + fn init_mask_reader( + q_offset: u32, + mask: CubeOption>>, + #[comptime] config: Self::Config, + ) -> Self::MaskReader; + fn init_writer( q_offset: u32, out: VirtualTensor, ReadWrite>, @@ -97,4 +108,6 @@ pub trait GlobalAttentionConfig: fn global_memory_config(&self, ident: AttentionIdent) -> GlobalMemoryConfig; fn tiling_scheme(&self) -> AttentionTilingScheme; + + fn causal_mask(&self) -> bool; } diff --git a/crates/cubecl-attention/src/components/global/dummy/attention.rs b/crates/cubecl-attention/src/components/global/dummy/attention.rs index e82f996f9..e18752ef5 100644 --- a/crates/cubecl-attention/src/components/global/dummy/attention.rs +++ b/crates/cubecl-attention/src/components/global/dummy/attention.rs @@ -1,12 +1,15 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; -use cubecl_matmul::components::{global::PartitionedStage, stage::StridedStage}; +use cubecl_matmul::components::global::PartitionedStage; +use cubecl_matmul::components::stage::StridedStage; use cubecl_std::tensor::r#virtual::VirtualTensor; +use cubecl_std::{CubeOption, CubeOptionExpand}; use std::marker::PhantomData; -use crate::components::GlobalMask; use crate::components::attention_types::*; use crate::components::global::base::GlobalAttentionConfig; +use crate::components::global::dummy::MaskReader; +use crate::components::global::dummy::reader::{AttentionReader, AttentionReaderExpand}; use crate::components::global::dummy::writer::DummyWriter; use crate::components::global::{ AttentionGlobalLayout, @@ -37,6 +40,7 @@ impl< { type KeyReader = DummyKeyReader; type ValueReader = DummyValueReader; + type MaskReader = MaskReader; type Writer = DummyWriter<(OG, OS)>; @@ -46,52 +50,67 @@ impl< query_reader: QueryReader, mut key_reader: Self::KeyReader, mut value_reader: Self::ValueReader, + mut mask_reader: Self::MaskReader, mut writer: Self::Writer, seq_q: u32, seq_kv: u32, #[comptime] config: Self::Config, ) { - let key_stage = key_reader.stage(); - let value_stage = value_reader.stage(); + let mut key_stage = key_reader.init_stage(config); + let mut value_stage = value_reader.init_stage(config); - let mut stage_state = SA::init_state(config.stage_config()); + let mut query_registers = SA::init_query(config.stage_config()); + let mut key_value_registers = SA::init_key_value(config.stage_config()); + let mut mask_registers = + SA::init_mask(CubeOption::new_Some((seq_q, seq_kv)), config.stage_config()); + let mut softmax_registers = SA::init_softmax(config.stage_config()); + let mut accumulator_registers = SA::init_accumulator(config.stage_config()); - let (query, mut key_value, mut softmax, mut accumulator) = - SA::init_partitions(query_reader, config.stage_config()); + let mut stage_state = SA::init_state(config.stage_config()); let seq_kv_stage = config.tiling_scheme().elements_in_partition_seq_kv(); let num_stage_iterations = seq_kv.div_ceil(seq_kv_stage); - let mask = GlobalMask::new(seq_q, seq_kv, config.tiling_scheme()); - for i in 0..num_stage_iterations { - key_reader.read_transposed(config); - value_reader.read(config); + SA::read_query(&query_reader, &mut query_registers, config.stage_config()); + + 
for _ in 0..num_stage_iterations { + key_reader.read_global(&mut key_stage, config); + value_reader.read_global(&mut value_stage, config); + + SA::read_mask(&mask_reader, &mut mask_registers, config.stage_config()); + sync_cube(); SA::execute( + &query_registers, &key_stage, &value_stage, - &query, - &mut key_value, - &mut softmax, - mask.to_stage(CUBE_POS, i), - &mut accumulator, + &mut key_value_registers, + &mask_registers, + &mut softmax_registers, + &mut accumulator_registers, &mut stage_state, config.stage_config(), ); sync_cube(); + key_reader.advance_view(); value_reader.advance_view(); + mask_reader.advance_view(); } - SA::rescale(&mut accumulator, stage_state, config.stage_config()); + SA::rescale( + &mut accumulator_registers, + stage_state, + config.stage_config(), + ); let mut out_stage = writer.stage(); SA::write::( - &accumulator, + &accumulator_registers, &mut out_stage, &mut writer, config.stage_config(), @@ -119,7 +138,7 @@ impl< let step = reduction_step::(config); let layout = AttentionGlobalLayout::new(&key, 0, config.global_memory_config(AttentionIdent::Key)); - DummyKeyReader::new(key.view(layout), step, config) + DummyKeyReader::new(key.view(layout), step) } fn init_value_reader( @@ -132,7 +151,28 @@ impl< 0, config.global_memory_config(AttentionIdent::Value), ); - DummyValueReader::new(value.view(layout), step, config) + DummyValueReader::new(value.view(layout), step) + } + + fn init_mask_reader( + q_offset: u32, + mask: CubeOption>>, + #[comptime] config: Self::Config, + ) -> Self::MaskReader { + let step = reduction_step::(config); + + match mask { + CubeOption::Some(mask) => { + let layout = AttentionGlobalLayout::new( + &mask, + 0, + config.global_memory_config(AttentionIdent::Value), + ); + + MaskReader::new_materialized(q_offset, mask.view(layout), step) + } + CubeOption::None => MaskReader::new_logical(q_offset, step), + } } fn init_writer( diff --git a/crates/cubecl-attention/src/components/global/dummy/config.rs b/crates/cubecl-attention/src/components/global/dummy/config.rs index 6fecc75c7..98d7ee6ef 100644 --- a/crates/cubecl-attention/src/components/global/dummy/config.rs +++ b/crates/cubecl-attention/src/components/global/dummy/config.rs @@ -12,6 +12,7 @@ use crate::components::{ pub struct DummyGlobalConfig { stage_config: S, num_planes: u32, + causal_mask: bool, } impl GlobalAttentionConfig for DummyGlobalConfig { @@ -66,13 +67,22 @@ impl GlobalAttentionConfig for DummyGlobalConfig { fn tiling_scheme(&self) -> AttentionTilingScheme { self.stage_config.tiling_scheme() } + + fn causal_mask(&self) -> bool { + self.causal_mask + } } impl DummyGlobalConfig { - pub fn new(stage_config: S, num_planes: u32) -> Result { + pub fn new( + stage_config: S, + num_planes: u32, + causal_mask: bool, + ) -> Result { Self { stage_config, num_planes, + causal_mask, } .validate() } diff --git a/crates/cubecl-attention/src/components/global/dummy/mod.rs b/crates/cubecl-attention/src/components/global/dummy/mod.rs index 5514f8659..19c496256 100644 --- a/crates/cubecl-attention/src/components/global/dummy/mod.rs +++ b/crates/cubecl-attention/src/components/global/dummy/mod.rs @@ -1,12 +1,9 @@ mod attention; mod config; -mod read; +mod reader; mod setup; mod writer; pub use attention::*; -pub use read::*; +pub use reader::*; pub use setup::DummyGlobalAttentionFamily; - -// tmp -pub use config::DummyGlobalConfig; diff --git a/crates/cubecl-attention/src/components/global/dummy/read.rs b/crates/cubecl-attention/src/components/global/dummy/read.rs deleted file mode 100644 index 
870d376d2..000000000 --- a/crates/cubecl-attention/src/components/global/dummy/read.rs +++ /dev/null @@ -1,228 +0,0 @@ -use crate::components::attention_types::*; -use cubecl_core as cubecl; -use cubecl_core::prelude::*; -use cubecl_matmul::components::global::{ - memory::{GlobalIterator, ViewDirection}, - read::tiled::TiledLayout, -}; -use cubecl_matmul::components::stage::StridedStage; -use cubecl_matmul::components::tile::StridedTile; -use cubecl_matmul::components::{MatrixLayout, StageIdent}; -use cubecl_std::tensor::{View, layout::Coords2d}; -use std::marker::PhantomData; - -use crate::components::global::base::GlobalAttentionConfig; -use crate::components::stage::StageAttentionConfig; -use crate::components::tile::AttentionTilingLayout; -use crate::components::{AttentionIdent, AttentionPrecision}; - -#[derive(CubeType)] -pub struct QueryReader { - query: View>, Coords2d>, -} - -#[derive(CubeType)] -pub struct DummyKeyReader { - global_iter: GlobalIterator>>, - stage_memory: StridedStage, AttentionTilingLayout>, - - #[cube(comptime)] - _phantom: PhantomData, -} - -#[derive(CubeType)] -pub struct DummyValueReader { - global_iter: GlobalIterator>>, - stage_memory: StridedStage, AttentionTilingLayout>, - - #[cube(comptime)] - _phantom: PhantomData, -} - -#[cube] -impl QueryReader { - pub fn new(q_offset: u32, query: View>, Coords2d>) -> Self { - let query = query.slice((q_offset, 0), query.shape()); - - QueryReader:: { query } - } - - pub fn get_tile( - &self, - tile: Coords2d, - #[comptime] config: S, - ) -> StridedTile> { - let (row_in_partition, col) = tile; - let attention_tile_size = config.tiling_scheme().tile_size; - - let row = row_in_partition + UNIT_POS_Y * config.tiling_scheme().partition_size.seq_q; - - StridedTile::>::new_strided( - self.query - .slice( - ( - row * attention_tile_size.seq_q, - col * attention_tile_size.head_dim, - ), - (attention_tile_size.seq_q, attention_tile_size.head_dim).runtime(), - ) - .to_linear_slice(), - config.tiling_scheme().elements_in_partition_head_dim(), - MatrixLayout::RowMajor, - ) - } -} - -#[cube] -impl DummyKeyReader { - pub fn new(key: View>, Coords2d>, step: u32, #[comptime] config: G) -> Self { - let global_iter = GlobalIterator::new(key, step, ViewDirection::Row, false); - let stage_memory = StridedStage::new(StageIdent::Rhs, config.score_stage_memory_config()); - - DummyKeyReader:: { - global_iter, - stage_memory, - _phantom: PhantomData, - } - } - - pub fn stage(&self) -> StridedStage, AttentionTilingLayout> { - self.stage_memory - } - - pub fn read_transposed(&mut self, #[comptime] config: G) { - // TODO this reader is bad - if UNIT_POS_Y == 0 { - let memory_config = config.global_memory_config(AttentionIdent::Key); - - let mut slice = self.stage_memory.as_slice_mut(1u32); - - let tile_rows_load = memory_config.elements_in_tile_row; - let tile_cols_load = memory_config.elements_in_tile_col; - let partition_rows_load = memory_config.elements_in_stage_row / tile_rows_load; - let partition_cols_load = memory_config.elements_in_stage_col / tile_cols_load; - - let units_per_tile_row = comptime!(config.plane_dim() / tile_rows_load); - let tile_cols_per_unit = comptime!(div_ceil(tile_cols_load, units_per_tile_row)); - - let row_load_in_tile = UNIT_POS_X / units_per_tile_row; - let col_load_in_tile_start = (UNIT_POS_X % units_per_tile_row) * tile_cols_per_unit; - - // Assumes row tiling order - let num_elements_per_tile = tile_rows_load * tile_cols_load; - let tile_row_stride_store = partition_rows_load * num_elements_per_tile; - let 
tile_col_stride_store = num_elements_per_tile; - - let layout = TiledLayout::new(memory_config); - let view = self.global_iter.view().view(layout); - - #[unroll] - for tile_row_load in 0..partition_rows_load { - #[unroll] - for tile_col_load in 0..partition_cols_load { - if row_load_in_tile < tile_rows_load { - #[unroll] - for i in 0..tile_cols_per_unit { - let col_load = col_load_in_tile_start + i; - - if col_load < tile_cols_load { - let tile_row_store = tile_col_load; - let tile_col_store = tile_row_load; - let tile_row_store_offset = tile_row_store * tile_row_stride_store; - let tile_col_store_offset = tile_col_store * tile_col_stride_store; - let store_offset = tile_row_store_offset + tile_col_store_offset; - - let index_load = row_load_in_tile * tile_cols_load + col_load; - let index_store = col_load * tile_rows_load + row_load_in_tile; - - slice[index_store + store_offset] = - Line::cast_from(view.read_checked(( - (tile_row_load, tile_col_load).runtime(), - index_load, - ))); - } - } - } - } - } - } - } - - pub fn advance_view(&mut self) { - self.global_iter.advance(); - } -} - -#[cube] -impl DummyValueReader { - pub fn new(value: View>, Coords2d>, step: u32, #[comptime] config: G) -> Self { - let global_iter = GlobalIterator::new(value, step, ViewDirection::Row, false); - let stage_memory = StridedStage::new(StageIdent::Rhs, config.value_stage_memory_config()); - - DummyValueReader:: { - global_iter, - stage_memory, - _phantom: PhantomData, - } - } - - pub fn stage(&self) -> StridedStage, AttentionTilingLayout> { - self.stage_memory - } - - pub fn read(&mut self, #[comptime] config: G) { - if UNIT_POS_Y == 0 { - // TODO this reader is bad, it's not coalesced - let memory_config = config.global_memory_config(AttentionIdent::Value); - let mut slice = self.stage_memory.as_slice_mut(1u32); - - let tile_rows = memory_config.elements_in_tile_row; - let tile_cols = memory_config.elements_in_tile_col; - let partition_rows = memory_config.elements_in_stage_row / tile_rows; - let partition_cols = memory_config.elements_in_stage_col / tile_cols; - - let units_per_tile_row = comptime!(config.plane_dim() / tile_rows); - let tile_cols_per_unit = comptime!(div_ceil(tile_cols, units_per_tile_row)); - - let row_in_tile = UNIT_POS_X / units_per_tile_row; - let col_in_tile_start = (UNIT_POS_X % units_per_tile_row) * tile_cols_per_unit; - - // Assumes row tiling order - let num_elements_per_tile = tile_rows * tile_cols; - let tile_row_stride = partition_cols * num_elements_per_tile; - let tile_col_stride = num_elements_per_tile; - - let layout = TiledLayout::new(memory_config); - let view = self.global_iter.view().view(layout); - - #[unroll] - for tile_row in 0..partition_rows { - #[unroll] - for tile_col in 0..partition_cols { - if row_in_tile < tile_rows { - #[unroll] - for i in 0..tile_cols_per_unit { - let col = col_in_tile_start + i; - - if col < tile_cols { - let tile_row_offset = tile_row * tile_row_stride; - let tile_col_offset = tile_col * tile_col_stride; - let offset = tile_row_offset + tile_col_offset; - - let index = row_in_tile * tile_cols + col; - - slice[index + offset] = Line::cast_from( - view.read_checked(((tile_row, tile_col).runtime(), index)), - ); - } - } - } - } - } - } - } - - pub fn advance_view(&mut self) { - self.global_iter.advance(); - } -} diff --git a/crates/cubecl-attention/src/components/global/dummy/reader/base.rs b/crates/cubecl-attention/src/components/global/dummy/reader/base.rs new file mode 100644 index 000000000..4981a876c --- /dev/null +++ 
b/crates/cubecl-attention/src/components/global/dummy/reader/base.rs @@ -0,0 +1,15 @@ +use cubecl_core as cubecl; +use cubecl_core::prelude::*; + +use crate::components::global::GlobalAttentionConfig; + +#[cube] +pub trait AttentionReader { + type Stage: CubeType; + + fn init_stage(&mut self, #[comptime] config: G) -> Self::Stage; + + fn read_global(&mut self, stage: &mut Self::Stage, #[comptime] config: G); + + fn advance_view(&mut self); +} diff --git a/crates/cubecl-attention/src/components/global/dummy/reader/key.rs b/crates/cubecl-attention/src/components/global/dummy/reader/key.rs new file mode 100644 index 000000000..10fe0a980 --- /dev/null +++ b/crates/cubecl-attention/src/components/global/dummy/reader/key.rs @@ -0,0 +1,109 @@ +use crate::components::attention_types::*; +use crate::components::global::dummy::reader::{AttentionReader, AttentionReaderExpand}; +use cubecl_core as cubecl; +use cubecl_core::prelude::*; +use cubecl_matmul::components::StageIdent; +use cubecl_matmul::components::global::{ + memory::{GlobalIterator, ViewDirection}, + read::tiled::TiledLayout, +}; +use cubecl_matmul::components::stage::StridedStage; +use cubecl_std::tensor::{View, layout::Coords2d}; +use std::marker::PhantomData; + +use crate::components::global::base::GlobalAttentionConfig; +use crate::components::tile::AttentionTilingLayout; +use crate::components::{AttentionIdent, AttentionPrecision}; + +#[derive(CubeType)] +pub struct DummyKeyReader { + global_iter: GlobalIterator>>, + + #[cube(comptime)] + _phantom: PhantomData, +} + +#[cube] +impl DummyKeyReader { + pub fn new(key: View>, Coords2d>, step: u32) -> Self { + let global_iter = GlobalIterator::new(key, step, ViewDirection::Row, false); + + DummyKeyReader:: { + global_iter, + _phantom: PhantomData, + } + } +} + +#[cube] +impl AttentionReader, G> + for DummyKeyReader +{ + type Stage = StridedStage, AttentionTilingLayout>; + + fn init_stage(&mut self, #[comptime] config: G) -> Self::Stage { + StridedStage::new(StageIdent::Rhs, config.score_stage_memory_config()) + } + + fn read_global(&mut self, stage: &mut Self::Stage, #[comptime] config: G) { + // TODO this reader is bad + if UNIT_POS_Y == 0 { + let memory_config = config.global_memory_config(AttentionIdent::Key); + + let mut slice = stage.as_slice_mut(1u32); + + let tile_rows_load = memory_config.elements_in_tile_row; + let tile_cols_load = memory_config.elements_in_tile_col; + let partition_rows_load = memory_config.elements_in_stage_row / tile_rows_load; + let partition_cols_load = memory_config.elements_in_stage_col / tile_cols_load; + + let units_per_tile_row = comptime!(config.plane_dim() / tile_rows_load); + let tile_cols_per_unit = comptime!(div_ceil(tile_cols_load, units_per_tile_row)); + + let row_load_in_tile = UNIT_POS_X / units_per_tile_row; + let col_load_in_tile_start = (UNIT_POS_X % units_per_tile_row) * tile_cols_per_unit; + + // Assumes row tiling order + let num_elements_per_tile = tile_rows_load * tile_cols_load; + let tile_row_stride_store = partition_rows_load * num_elements_per_tile; + let tile_col_stride_store = num_elements_per_tile; + + let layout = TiledLayout::new(memory_config); + let view = self.global_iter.view().view(layout); + + #[unroll] + for tile_row_load in 0..partition_rows_load { + #[unroll] + for tile_col_load in 0..partition_cols_load { + if row_load_in_tile < tile_rows_load { + #[unroll] + for i in 0..tile_cols_per_unit { + let col_load = col_load_in_tile_start + i; + + if col_load < tile_cols_load { + let tile_row_store = tile_col_load; + let 
tile_col_store = tile_row_load; + let tile_row_store_offset = tile_row_store * tile_row_stride_store; + let tile_col_store_offset = tile_col_store * tile_col_stride_store; + let store_offset = tile_row_store_offset + tile_col_store_offset; + + let index_load = row_load_in_tile * tile_cols_load + col_load; + let index_store = col_load * tile_rows_load + row_load_in_tile; + + slice[index_store + store_offset] = + Line::cast_from(view.read_checked(( + (tile_row_load, tile_col_load).runtime(), + index_load, + ))); + } + } + } + } + } + } + } + + fn advance_view(&mut self) { + self.global_iter.advance(); + } +} diff --git a/crates/cubecl-attention/src/components/global/dummy/reader/mask.rs b/crates/cubecl-attention/src/components/global/dummy/reader/mask.rs new file mode 100644 index 000000000..c19bde262 --- /dev/null +++ b/crates/cubecl-attention/src/components/global/dummy/reader/mask.rs @@ -0,0 +1,111 @@ +use crate::components::attention_types::*; +use cubecl_core as cubecl; +use cubecl_core::prelude::*; +use cubecl_matmul::components::MatrixLayout; +use cubecl_matmul::components::global::memory::{GlobalIterator, ViewDirection}; +use cubecl_matmul::components::tile::StridedTile; +use cubecl_std::tensor::{View, layout::Coords2d}; + +use crate::components::AttentionPrecision; +use crate::components::stage::StageAttentionConfig; +use cubecl_std::CubeOption; + +#[derive(CubeType)] +pub struct LogicalIterator { + row: u32, + col: RuntimeCell, + step_col: u32, +} + +#[cube] +impl LogicalIterator { + fn init(q_offset: u32, step_col: u32) -> LogicalIterator { + LogicalIterator { + row: q_offset, + col: RuntimeCell::new(0), + step_col, + } + } + + fn read(&self) -> Coords2d { + (self.row, self.col.read()) + } + + fn advance(&mut self) { + self.col.store(self.col.read() + self.step_col); + } +} + +#[derive(CubeType)] +pub enum MaskReader { + Materialized(GlobalIterator>>, LogicalIterator), + Logical(LogicalIterator), +} + +#[cube] +impl MaskReader { + pub fn new_logical(q_offset: u32, step: u32) -> Self { + MaskReader::::new_Logical(LogicalIterator::init(q_offset, step)) + } + + pub fn new_materialized(q_offset: u32, mask: View>, Coords2d>, step: u32) -> Self { + let mask = mask.slice((q_offset, 0), mask.shape()); + let global_iter = GlobalIterator::new(mask, step, ViewDirection::Col, false); + + MaskReader::::new_Materialized(global_iter, LogicalIterator::init(q_offset, step)) + } + + // TODO read tile too + pub fn read( + &self, + #[comptime] pos_in_partition: Coords2d, + #[comptime] config: S, + ) -> (Coords2d, CubeOption>>) { + match self { + MaskReader::Materialized(global_iterator, logical_iterator) => ( + logical_iterator.read(), + CubeOption::new_Some(get_tile::(global_iterator, pos_in_partition, config)), + ), + MaskReader::Logical(logical_iterator) => { + (logical_iterator.read(), CubeOption::new_None()) + } + } + } + + pub fn advance_view(&mut self) { + match self { + MaskReader::Logical(logical_iter) => logical_iter.advance(), + MaskReader::Materialized(global_iter, logical_iter) => { + global_iter.advance(); + logical_iter.advance() + } + } + } +} + +#[cube] +pub fn get_tile( + global_iter: &GlobalIterator>>, + #[comptime] tile: Coords2d, + #[comptime] config: S, +) -> StridedTile> { + let (row_in_partition, col) = tile; + let attention_tile_size = config.tiling_scheme().tile_size; + + let row = row_in_partition + UNIT_POS_Y * config.tiling_scheme().partition_size.seq_q; + + StridedTile::>::new_strided( + global_iter + .view() + .slice( + ( + row * attention_tile_size.seq_q, + 
col.runtime() * attention_tile_size.seq_kv, + ), + (attention_tile_size.seq_q, attention_tile_size.seq_kv).runtime(), + ) + .to_linear_slice(), + config.tiling_scheme().elements_in_partition_seq_kv(), + MatrixLayout::RowMajor, + ) +} diff --git a/crates/cubecl-attention/src/components/global/dummy/reader/mod.rs b/crates/cubecl-attention/src/components/global/dummy/reader/mod.rs new file mode 100644 index 000000000..577a6f2e5 --- /dev/null +++ b/crates/cubecl-attention/src/components/global/dummy/reader/mod.rs @@ -0,0 +1,11 @@ +mod base; +mod key; +mod mask; +mod query; +mod value; + +pub use base::*; +pub use key::*; +pub use mask::*; +pub use query::*; +pub use value::*; diff --git a/crates/cubecl-attention/src/components/global/dummy/reader/query.rs b/crates/cubecl-attention/src/components/global/dummy/reader/query.rs new file mode 100644 index 000000000..25b4fb765 --- /dev/null +++ b/crates/cubecl-attention/src/components/global/dummy/reader/query.rs @@ -0,0 +1,48 @@ +use crate::components::attention_types::*; +use cubecl_core as cubecl; +use cubecl_core::prelude::*; +use cubecl_matmul::components::MatrixLayout; +use cubecl_matmul::components::tile::StridedTile; +use cubecl_std::tensor::{View, layout::Coords2d}; + +use crate::components::AttentionPrecision; +use crate::components::stage::StageAttentionConfig; + +#[derive(CubeType)] +pub struct QueryReader { + query: View>, Coords2d>, +} + +#[cube] +impl QueryReader { + pub fn new(q_offset: u32, query: View>, Coords2d>) -> Self { + let query = query.slice((q_offset, 0), query.shape()); + + QueryReader:: { query } + } + + pub fn get_tile( + &self, + tile: Coords2d, + #[comptime] config: S, + ) -> StridedTile> { + let (row_in_partition, col) = tile; + let attention_tile_size = config.tiling_scheme().tile_size; + + let row = row_in_partition + UNIT_POS_Y * config.tiling_scheme().partition_size.seq_q; + + StridedTile::>::new_strided( + self.query + .slice( + ( + row * attention_tile_size.seq_q, + col * attention_tile_size.head_dim, + ), + (attention_tile_size.seq_q, attention_tile_size.head_dim).runtime(), + ) + .to_linear_slice(), + config.tiling_scheme().elements_in_partition_head_dim(), + MatrixLayout::RowMajor, + ) + } +} diff --git a/crates/cubecl-attention/src/components/global/dummy/reader/value.rs b/crates/cubecl-attention/src/components/global/dummy/reader/value.rs new file mode 100644 index 000000000..790acbbc2 --- /dev/null +++ b/crates/cubecl-attention/src/components/global/dummy/reader/value.rs @@ -0,0 +1,103 @@ +use crate::components::attention_types::*; +use crate::components::global::dummy::reader::{AttentionReader, AttentionReaderExpand}; +use cubecl_core as cubecl; +use cubecl_core::prelude::*; +use cubecl_matmul::components::StageIdent; +use cubecl_matmul::components::global::{ + memory::{GlobalIterator, ViewDirection}, + read::tiled::TiledLayout, +}; +use cubecl_matmul::components::stage::StridedStage; +use cubecl_std::tensor::{View, layout::Coords2d}; +use std::marker::PhantomData; + +use crate::components::global::base::GlobalAttentionConfig; +use crate::components::tile::AttentionTilingLayout; +use crate::components::{AttentionIdent, AttentionPrecision}; + +#[derive(CubeType)] +pub struct DummyValueReader { + global_iter: GlobalIterator>>, + + #[cube(comptime)] + _phantom: PhantomData, +} + +#[cube] +impl DummyValueReader { + pub fn new(value: View>, Coords2d>, step: u32) -> Self { + let global_iter = GlobalIterator::new(value, step, ViewDirection::Row, false); + + DummyValueReader:: { + global_iter, + _phantom: 
PhantomData, + } + } +} + +#[cube] +impl AttentionReader, G> + for DummyValueReader +{ + type Stage = StridedStage, AttentionTilingLayout>; + + fn init_stage(&mut self, #[comptime] config: G) -> Self::Stage { + StridedStage::new(StageIdent::Rhs, config.value_stage_memory_config()) + } + + fn read_global(&mut self, stage: &mut Self::Stage, #[comptime] config: G) { + if UNIT_POS_Y == 0 { + // TODO this reader is bad, it's not coalesced + let memory_config = config.global_memory_config(AttentionIdent::Value); + let mut slice = stage.as_slice_mut(1u32); + + let tile_rows = memory_config.elements_in_tile_row; + let tile_cols = memory_config.elements_in_tile_col; + let partition_rows = memory_config.elements_in_stage_row / tile_rows; + let partition_cols = memory_config.elements_in_stage_col / tile_cols; + + let units_per_tile_row = comptime!(config.plane_dim() / tile_rows); + let tile_cols_per_unit = comptime!(div_ceil(tile_cols, units_per_tile_row)); + + let row_in_tile = UNIT_POS_X / units_per_tile_row; + let col_in_tile_start = (UNIT_POS_X % units_per_tile_row) * tile_cols_per_unit; + + // Assumes row tiling order + let num_elements_per_tile = tile_rows * tile_cols; + let tile_row_stride = partition_cols * num_elements_per_tile; + let tile_col_stride = num_elements_per_tile; + + let layout = TiledLayout::new(memory_config); + let view = self.global_iter.view().view(layout); + + #[unroll] + for tile_row in 0..partition_rows { + #[unroll] + for tile_col in 0..partition_cols { + if row_in_tile < tile_rows { + #[unroll] + for i in 0..tile_cols_per_unit { + let col = col_in_tile_start + i; + + if col < tile_cols { + let tile_row_offset = tile_row * tile_row_stride; + let tile_col_offset = tile_col * tile_col_stride; + let offset = tile_row_offset + tile_col_offset; + + let index = row_in_tile * tile_cols + col; + + slice[index + offset] = Line::cast_from( + view.read_checked(((tile_row, tile_col).runtime(), index)), + ); + } + } + } + } + } + } + } + + fn advance_view(&mut self) { + self.global_iter.advance(); + } +} diff --git a/crates/cubecl-attention/src/components/global/dummy/setup.rs b/crates/cubecl-attention/src/components/global/dummy/setup.rs index af70dba7c..b126e9943 100644 --- a/crates/cubecl-attention/src/components/global/dummy/setup.rs +++ b/crates/cubecl-attention/src/components/global/dummy/setup.rs @@ -37,6 +37,6 @@ impl< ) -> Result { let stage_config = SA::setup::(client, problem, selection, line_sizes)?; - DummyGlobalConfig::new(stage_config, stage_config.num_planes()) + DummyGlobalConfig::new(stage_config, stage_config.num_planes(), problem.causal) } } diff --git a/crates/cubecl-attention/src/components/mask.rs b/crates/cubecl-attention/src/components/mask.rs deleted file mode 100644 index ee65b37e1..000000000 --- a/crates/cubecl-attention/src/components/mask.rs +++ /dev/null @@ -1,95 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; -use cubecl_std::tensor::layout::Coords2d; - -use crate::components::AttentionTilingScheme; - -#[derive(CubeType, Copy, Clone)] -pub struct GlobalMask { - q_bound: u32, - kv_bound: u32, - #[cube(comptime)] - tiling_scheme: AttentionTilingScheme, -} - -#[derive(CubeType, Copy, Clone)] -pub struct StageMask { - q_bound: u32, - kv_bound: u32, - #[cube(comptime)] - tiling_scheme: AttentionTilingScheme, -} - -#[derive(CubeType, Copy, Clone)] -pub struct PartitionMask { - q_bound: u32, - kv_bound: u32, - #[cube(comptime)] - tiling_scheme: AttentionTilingScheme, -} - -#[derive(CubeType, Copy, Clone)] -pub struct TileMask { - q_bound: 
u32, - kv_bound: u32, -} - -#[cube] -impl GlobalMask { - pub fn new( - q_bound: u32, - kv_bound: u32, - #[comptime] tiling_scheme: AttentionTilingScheme, - ) -> GlobalMask { - GlobalMask { - q_bound, - kv_bound, - tiling_scheme, - } - } - - pub fn to_stage(&self, row: u32, col: u32) -> StageMask { - let q_factor = comptime!(self.tiling_scheme.elements_in_stage_seq_q()); - let kv_factor = comptime!(self.tiling_scheme.elements_in_stage_seq_kv()); - - StageMask { - q_bound: self.q_bound.saturating_sub(row * q_factor), - kv_bound: self.kv_bound.saturating_sub(col * kv_factor), - tiling_scheme: self.tiling_scheme, - } - } -} - -#[cube] -impl StageMask { - pub fn to_partition(&self, row: u32) -> PartitionMask { - let q_factor = comptime!(self.tiling_scheme.elements_in_partition_seq_q()); - - PartitionMask { - q_bound: self.q_bound.saturating_sub(row * q_factor), - kv_bound: self.kv_bound, - tiling_scheme: self.tiling_scheme, - } - } -} - -#[cube] -impl PartitionMask { - pub fn to_tile(self, row: u32, col: u32) -> TileMask { - let q_factor = comptime!(self.tiling_scheme.elements_in_tile_seq_q()); - let kv_factor = comptime!(self.tiling_scheme.elements_in_tile_seq_kv()); - - TileMask { - q_bound: self.q_bound.saturating_sub(row * q_factor), - kv_bound: self.kv_bound.saturating_sub(col * kv_factor), - } - } -} - -#[cube] -impl TileMask { - pub fn apply(&self, pos: Coords2d) -> E { - let should_mask = E::cast_from(pos.0 >= self.q_bound || pos.1 >= self.kv_bound); - should_mask * E::min_value() - } -} diff --git a/crates/cubecl-attention/src/components/mod.rs b/crates/cubecl-attention/src/components/mod.rs index 8a6d26741..e0cccec2b 100644 --- a/crates/cubecl-attention/src/components/mod.rs +++ b/crates/cubecl-attention/src/components/mod.rs @@ -7,7 +7,6 @@ pub mod tile; mod error; mod ident; mod line_size; -mod mask; mod problem; mod selection; mod spec; @@ -16,7 +15,6 @@ mod tiling_scheme; pub use error::*; pub use ident::*; pub use line_size::*; -pub use mask::*; pub use problem::*; pub use selection::*; pub use spec::*; diff --git a/crates/cubecl-attention/src/components/problem.rs b/crates/cubecl-attention/src/components/problem.rs index d5e66f488..7f457085f 100644 --- a/crates/cubecl-attention/src/components/problem.rs +++ b/crates/cubecl-attention/src/components/problem.rs @@ -16,6 +16,8 @@ pub struct AttentionProblem { /// Usually equal to `head_dim`, but may differ in some variants pub val_dim: usize, - /// Whether a mask is applied (shape is always [batch, seq_q, heads, seq_k]) + /// Whether a mask is supplied (shape is always [batch, seq_q, heads, seq_kv]) pub masked: bool, + /// Whether there is a causal mask + pub causal: bool, } diff --git a/crates/cubecl-attention/src/components/spec.rs b/crates/cubecl-attention/src/components/spec.rs index d599c9717..53469a290 100644 --- a/crates/cubecl-attention/src/components/spec.rs +++ b/crates/cubecl-attention/src/components/spec.rs @@ -205,7 +205,7 @@ impl< } /// Input argument -pub type InputArg = as AttentionArgs>::Input, KG, VG>; +pub type InputArg = as AttentionArgs>::Input, KG, VG, MSK>; /// Output argument pub type OutputArg = as AttentionArgs>::Output>; diff --git a/crates/cubecl-attention/src/components/stage/base.rs b/crates/cubecl-attention/src/components/stage/base.rs index d3b7befb2..a175105b3 100644 --- a/crates/cubecl-attention/src/components/stage/base.rs +++ b/crates/cubecl-attention/src/components/stage/base.rs @@ -7,7 +7,9 @@ use cubecl_matmul::components::{ use std::{fmt::Debug, hash::Hash}; use 
crate::components::attention_types::*; +use crate::components::global::dummy::MaskReader; use crate::components::stage::dummy::AttentionStageMemoryConfig; +use crate::components::tile::RunningState; use crate::components::{ AttentionLineSizes, AttentionPrecision, AttentionProblem, AttentionSelection, AttentionSetupError, AvailableLineSizes, @@ -15,7 +17,8 @@ use crate::components::{ tile::{AttentionTilingLayout, dummy::AttentionMatmulConfig}, }; use crate::components::{AttentionTilingScheme, global::dummy::QueryReader}; -use crate::components::{StageMask, tile::RunningState}; +use cubecl_std::CubeOption; +use cubecl_std::tensor::layout::Coords2d; /// A family of [TileAttention] implementations that operate with any [precision](AttentionPrecision). pub trait StageAttentionFamily: Send + Sync + 'static { @@ -62,46 +65,57 @@ pub trait StageAttention: 'static + Send + Sync { /// The configuration type associated with this Attention. type Config: StageAttentionConfig; - type QueryPartition: CubeType; - type KeyValuePartition: CubeType; - type SoftmaxPartition: CubeType; - type AccumulatorPartition: CubeType; + type QueryRegisters: CubeType; + type KeyValueRegisters: CubeType; + type SoftmaxRegisters: CubeType; + type AccumulatorRegisters: CubeType; + type MaskRegisters: CubeType; fn init_state(#[comptime] config: Self::Config) -> Sequence>>; fn execute( - key_reader: &Self::KeyStage, - value_reader: &Self::ValueStage, - query: &Self::QueryPartition, - key_value: &mut Self::KeyValuePartition, - score: &mut Self::SoftmaxPartition, - mask: StageMask, - accumulator: &mut Self::AccumulatorPartition, + query: &Self::QueryRegisters, + key_stage: &Self::KeyStage, + value_stage: &Self::ValueStage, + key_value: &mut Self::KeyValueRegisters, + mask_partition: &Self::MaskRegisters, + score: &mut Self::SoftmaxRegisters, + accumulator: &mut Self::AccumulatorRegisters, prev_state: &mut Sequence>>, #[comptime] config: Self::Config, ); fn rescale( - acc: &mut Self::AccumulatorPartition, + acc: &mut Self::AccumulatorRegisters, state: Sequence>>, #[comptime] config: Self::Config, ); fn write( - acc: &Self::AccumulatorPartition, + acc: &Self::AccumulatorRegisters, stage: &mut Self::OutStage, writer: &mut W, #[comptime] tile_config: Self::Config, ); - fn init_partitions( - query_loader: QueryReader, + fn init_query(#[comptime] config: Self::Config) -> Self::QueryRegisters; + fn init_key_value(#[comptime] config: Self::Config) -> Self::KeyValueRegisters; + fn init_mask( + out_of_bounds: CubeOption, + #[comptime] config: Self::Config, + ) -> Self::MaskRegisters; + fn init_softmax(#[comptime] config: Self::Config) -> Self::SoftmaxRegisters; + fn init_accumulator(#[comptime] config: Self::Config) -> Self::AccumulatorRegisters; + + fn read_query( + reader: &QueryReader, + registers: &mut Self::QueryRegisters, + #[comptime] config: Self::Config, + ); + fn read_mask( + reader: &MaskReader, + registers: &mut Self::MaskRegisters, #[comptime] config: Self::Config, - ) -> ( - Self::QueryPartition, - Self::KeyValuePartition, - Self::SoftmaxPartition, - Self::AccumulatorPartition, ); } diff --git a/crates/cubecl-attention/src/components/stage/dummy/attention.rs b/crates/cubecl-attention/src/components/stage/dummy/attention.rs index b7c2ccf00..cb5bf0dc8 100644 --- a/crates/cubecl-attention/src/components/stage/dummy/attention.rs +++ b/crates/cubecl-attention/src/components/stage/dummy/attention.rs @@ -8,14 +8,20 @@ use cubecl_matmul::components::{ use std::marker::PhantomData; use crate::components::attention_types::*; +use 
crate::components::global::dummy::MaskReader; use crate::components::global::dummy::QueryReader; +use crate::components::stage::dummy::MaskPartition; use crate::components::stage::dummy::SoftmaxPartition; -use crate::components::stage::dummy::{Accumulators, DummyStageConfig, KeyValues, Queries}; +use crate::components::stage::dummy::{Accumulators, DummyStageConfig, KeyValues, QueryPartition}; use crate::components::stage::{StageAttention, StageAttentionConfig}; use crate::components::tile::RowWise; +use crate::components::tile::RunningState; use crate::components::tile::TileAttention; +use crate::components::tile::{MaskTile, MaskTileExpand}; +use crate::components::tile::{QueryTile, QueryTileExpand}; use crate::components::{AttentionPrecision, global::GlobalAttentionConfig}; -use crate::components::{StageMask, tile::RunningState}; +use cubecl_std::CubeOption; +use cubecl_std::tensor::layout::Coords2d; pub struct DummyStageAttention> { _phantom: PhantomData<(AP, SK, SV, SO, TA)>, @@ -36,36 +42,33 @@ impl< type ValueStage = SV; type OutStage = SO; - type QueryPartition = Queries; - type KeyValuePartition = KeyValues; - type SoftmaxPartition = SoftmaxPartition; - type AccumulatorPartition = Accumulators; + type QueryRegisters = QueryPartition; + type KeyValueRegisters = KeyValues; + type SoftmaxRegisters = SoftmaxPartition; + type AccumulatorRegisters = Accumulators; + type MaskRegisters = MaskPartition; fn execute( - key_reader: &Self::KeyStage, - value_reader: &Self::ValueStage, - query_partition: &Self::QueryPartition, - key_value_partition: &mut Self::KeyValuePartition, - softmax_partition: &mut Self::SoftmaxPartition, - mask: StageMask, - accumulator_partition: &mut Self::AccumulatorPartition, + query_partition: &Self::QueryRegisters, + key_stage: &Self::KeyStage, + value_stage: &Self::ValueStage, + key_value_partition: &mut Self::KeyValueRegisters, + mask_partition: &Self::MaskRegisters, + softmax_partition: &mut Self::SoftmaxRegisters, + accumulator_partition: &mut Self::AccumulatorRegisters, state: &mut Sequence>>, #[comptime] config: Self::Config, ) { - let partition_mask = mask.to_partition(UNIT_POS_Y); - let p = config.tiling_scheme().partition_size; let mut max_placeholder = TA::init_max_placeholder(config.num_rows_per_unit()); let mut sum_placeholder = TA::init_sum_placeholder(config.num_rows_per_unit()); #[unroll] - #[allow(clippy::explicit_counter_loop)] for kv in 0..p.seq_kv { #[unroll] - #[allow(clippy::explicit_counter_loop)] for hd in 0..p.head_dim { - let key_tile = SK::tile(key_reader, (hd, kv).runtime()); + let key_tile = SK::tile(key_stage, (hd, kv).runtime()); TA::fill_key( &key_tile, @@ -77,13 +80,13 @@ impl< let mut scales = Sequence::>>::new(); #[unroll] - #[allow(clippy::explicit_counter_loop)] for q in 0..p.seq_q { let softmax_tile = softmax_partition.get_at_mut(q, kv, config); TA::zero_softmax(softmax_tile, config.tile_config()); + let mask_tile = mask_partition.get_at(q, kv, config.tiling_scheme()); + #[unroll] - #[allow(clippy::explicit_counter_loop)] for hd in 0..p.head_dim { let query_tile = query_partition.get_at(q, hd, config); let key_tile = key_value_partition.get_key_at(hd, kv, config); @@ -95,7 +98,7 @@ impl< scales.push(TA::softmax( softmax_tile, - partition_mask.to_tile(q, kv), + mask_tile, state_q, &mut max_placeholder, &mut sum_placeholder, @@ -105,9 +108,8 @@ impl< } #[unroll] - #[allow(clippy::explicit_counter_loop)] for vd in 0..p.val_dim { - let value_tile = SV::tile(value_reader, (kv, vd).runtime()); + let value_tile = SV::tile(value_stage, 
(kv, vd).runtime()); TA::fill_value( &value_tile, @@ -117,12 +119,10 @@ impl< } #[unroll] - #[allow(clippy::explicit_counter_loop)] for q in 0..p.seq_q { let softmax_tile = softmax_partition.get_at(q, kv, config); #[unroll] - #[allow(clippy::explicit_counter_loop)] for vd in 0..p.val_dim { TA::accumulate_value( softmax_tile, @@ -137,32 +137,22 @@ impl< } fn rescale( - acc: &mut Self::AccumulatorPartition, + acc: &mut Self::AccumulatorRegisters, state: Sequence>>, #[comptime] config: Self::Config, ) { let p = config.tiling_scheme().partition_size; - let mut q = comptime!(0u32); - #[unroll] - #[allow(clippy::explicit_counter_loop)] - for _ in 0..p.seq_q { - let mut vd = comptime!(0u32); - + for q in 0..p.seq_q { #[unroll] - #[allow(clippy::explicit_counter_loop)] - for _ in 0..p.val_dim { + for vd in 0..p.val_dim { TA::rescale( - Self::AccumulatorPartition::get_at_mut(acc, q, vd, config), + Self::AccumulatorRegisters::get_at_mut(acc, q, vd, config), state.index(q), config.tile_config(), ); - - comptime![vd += 1]; } - - comptime![q += 1]; } } @@ -179,58 +169,93 @@ impl< } fn write( - acc: &Self::AccumulatorPartition, + acc: &Self::AccumulatorRegisters, stage: &mut Self::OutStage, writer: &mut W, #[comptime] stage_config: Self::Config, ) { let p = stage_config.tiling_scheme().partition_size; - let mut q = comptime!(0u32); W::on_event(writer, WriteEvent::new_Begin()); #[unroll] - #[allow(clippy::explicit_counter_loop)] - for _ in 0..p.seq_q { - let mut kv = comptime!(0u32); - + for q in 0..p.seq_q { #[unroll] - #[allow(clippy::explicit_counter_loop)] - for _ in 0..p.val_dim { - let tile_pos = (q + UNIT_POS_Y * p.seq_q, kv.runtime()); + for vd in 0..p.val_dim { + let tile_pos = (q + UNIT_POS_Y * p.seq_q, vd.runtime()); let mut tile = Self::OutStage::tile(stage, tile_pos); TA::write_results( &mut tile, - Self::AccumulatorPartition::get_at(acc, q, kv, stage_config), + Self::AccumulatorRegisters::get_at(acc, q, vd, stage_config), stage_config.tile_config(), ); W::on_event(writer, WriteEvent::new_TileStored(tile_pos)); - - comptime![kv += 1]; } - - comptime![q += 1]; } W::on_event(writer, WriteEvent::new_Finish()); } - fn init_partitions( - query_loader: QueryReader, + fn init_query(#[comptime] config: Self::Config) -> Self::QueryRegisters { + Self::QueryRegisters::new(config) + } + + fn init_key_value(#[comptime] config: Self::Config) -> Self::KeyValueRegisters { + Self::KeyValueRegisters::new(config) + } + + fn init_softmax(#[comptime] config: Self::Config) -> Self::SoftmaxRegisters { + Self::SoftmaxRegisters::new(config) + } + + fn init_accumulator(#[comptime] config: Self::Config) -> Self::AccumulatorRegisters { + Self::AccumulatorRegisters::new(config) + } + + fn init_mask( + out_of_bounds: CubeOption, + #[comptime] config: Self::Config, + ) -> Self::MaskRegisters { + Self::MaskRegisters::new(out_of_bounds, config) + } + + fn read_query( + reader: &QueryReader, + registers: &mut Self::QueryRegisters, + #[comptime] config: Self::Config, + ) { + let p = config.tiling_scheme().partition_size; + + #[unroll] + for q in 0..p.seq_q { + #[unroll] + for hd in 0..p.head_dim { + let tile_to_write = registers.get_at_mut(q, hd, config); + let tile_read = reader.get_tile::((q, hd).runtime(), config); + + tile_to_write.update(tile_read); + } + } + } + + fn read_mask( + reader: &MaskReader, + registers: &mut Self::MaskRegisters, #[comptime] config: Self::Config, - ) -> ( - Self::QueryPartition, - Self::KeyValuePartition, - Self::SoftmaxPartition, - Self::AccumulatorPartition, ) { - ( - 
Self::QueryPartition::new(query_loader, config), - Self::KeyValuePartition::new(config), - Self::SoftmaxPartition::new(config), - Self::AccumulatorPartition::new(config), - ) + let p = config.tiling_scheme().partition_size; + + #[unroll] + for q in 0..p.seq_q { + #[unroll] + for kv in 0..p.seq_kv { + let mask_tile = registers.get_at_mut(q, kv, config.tiling_scheme()); + + let (new_origin, tile) = reader.read::((q, kv), config); + mask_tile.update(new_origin, tile); + } + } } } diff --git a/crates/cubecl-attention/src/components/stage/dummy/tile_partitions.rs b/crates/cubecl-attention/src/components/stage/dummy/tile_partitions.rs index 45aaf82a1..c698fbaab 100644 --- a/crates/cubecl-attention/src/components/stage/dummy/tile_partitions.rs +++ b/crates/cubecl-attention/src/components/stage/dummy/tile_partitions.rs @@ -4,8 +4,10 @@ use std::marker::PhantomData; use cubecl::prelude::*; use cubecl_core as cubecl; -use crate::components::global::dummy::QueryReader; +use crate::components::AttentionTilingScheme; use crate::components::{AttentionPrecision, stage::StageAttentionConfig, tile::TileAttention}; +use cubecl_std::CubeOption; +use cubecl_std::tensor::layout::Coords2d; #[derive(CubeType)] pub struct Accumulators< @@ -62,7 +64,7 @@ impl< } #[derive(CubeType)] -pub struct Queries< +pub struct QueryPartition< AP: AttentionPrecision, TA: TileAttention, S: StageAttentionConfig, @@ -77,32 +79,18 @@ impl< AP: AttentionPrecision, TA: TileAttention, S: StageAttentionConfig, -> Queries +> QueryPartition { - pub fn new(query_loader: QueryReader, #[comptime] config: S) -> Queries { + pub fn new(#[comptime] config: S) -> QueryPartition { let p = config.tiling_scheme().partition_size; let mut sequence = Sequence::new(); - let mut q = comptime!(0u32); - #[unroll] - #[allow(clippy::explicit_counter_loop)] - for _ in 0..comptime!(p.seq_q) { - let mut hd = comptime!(0u32); - - #[unroll] - #[allow(clippy::explicit_counter_loop)] - for _ in 0..comptime!(p.head_dim) { - let tile = query_loader.get_tile::((q, hd).runtime(), config); - sequence.push(TA::init_query(&tile, config.tile_config())); - - comptime![hd += 1]; - } - - comptime![q += 1]; + for _ in 0..comptime!(p.seq_q * p.head_dim) { + sequence.push(TA::init_query(config.tile_config())); } - Queries:: { + QueryPartition:: { sequence, _phantom: PhantomData, } @@ -304,3 +292,71 @@ impl< self.sequence.index_mut(index) } } + +#[derive(CubeType)] +pub struct MaskPartition< + AP: AttentionPrecision, + TA: TileAttention, + S: StageAttentionConfig, +> { + sequence: Sequence, + #[cube(comptime)] + _phantom: PhantomData, +} + +#[cube] +impl< + AP: AttentionPrecision, + TA: TileAttention, + S: StageAttentionConfig, +> MaskPartition +{ + pub fn new( + out_of_bounds: CubeOption, + #[comptime] config: S, + ) -> MaskPartition { + let p = config.tiling_scheme().partition_size; + let mut sequence = Sequence::new(); + + let mut q = comptime![0]; + + #[unroll] + for _ in 0..p.seq_q { + let mut kv = comptime![0]; + + #[unroll] + for _ in 0..p.seq_kv { + sequence.push(TA::init_mask(out_of_bounds, (q, kv), config.tile_config())); + + comptime![kv += 1]; + } + + comptime![q += 1]; + } + + MaskPartition:: { + sequence, + _phantom: PhantomData, + } + } + + pub fn get_at( + &self, + #[comptime] q: u32, + #[comptime] kv: u32, + #[comptime] tiling_scheme: AttentionTilingScheme, + ) -> &TA::MaskTile { + let p = tiling_scheme.partition_size; + self.sequence.index(comptime!(q * p.seq_kv + kv)) + } + + pub fn get_at_mut( + &mut self, + #[comptime] q: u32, + #[comptime] kv: u32, + 
#[comptime] tiling_scheme: AttentionTilingScheme, + ) -> &mut TA::MaskTile { + let p = tiling_scheme.partition_size; + self.sequence.index_mut(comptime!(q * p.seq_kv + kv)) + } +} diff --git a/crates/cubecl-attention/src/components/tile/base.rs b/crates/cubecl-attention/src/components/tile/base.rs index aaabb4672..67054b05f 100644 --- a/crates/cubecl-attention/src/components/tile/base.rs +++ b/crates/cubecl-attention/src/components/tile/base.rs @@ -6,6 +6,8 @@ use cubecl_matmul::components::{ tile::StridedTile, }; +use crate::components::tile::MaskTile; +use crate::components::tile::SoftmaxTile; use crate::components::{ AttentionLineSizes, AttentionPrecision, AttentionProblem, AttentionSelection, AttentionSetupError, AvailableLineSizes, @@ -13,7 +15,8 @@ use crate::components::{ tile::{KeyValueTile, QueryTile, RowWise, RunningState, dummy::AttentionMatmulConfig}, }; use crate::components::{InvalidConfigError, tile::AccumulatorTile}; -use crate::components::{TileMask, tile::SoftmaxTile}; +use cubecl_std::CubeOption; +use cubecl_std::tensor::layout::Coords2d; pub type AttentionTilingLayout = ContiguousTilingLayout; @@ -51,10 +54,11 @@ pub trait TileAttention: 'static + Send + Sync { /// The configuration type associated with this Attention. type Config: AttentionMatmulConfig; - type QueryTile: QueryTile>; + type QueryTile: QueryTile; type KeyValueTile: KeyValueTile>; type SoftmaxTile: SoftmaxTile; type AccumulatorTile: AccumulatorTile; + type MaskTile: MaskTile>; fn rescale( acc: &mut Self::AccumulatorTile, @@ -70,25 +74,38 @@ pub trait TileAttention: 'static + Send + Sync { fn init_accumulator(#[comptime] config: Self::Config) -> Self::AccumulatorTile; - fn init_query(tile: &StridedTile>, #[comptime] config: Self::Config) -> Self::QueryTile; + fn init_query(#[comptime] config: Self::Config) -> Self::QueryTile; fn init_key_value(#[comptime] config: Self::Config) -> Self::KeyValueTile; fn init_key(#[comptime] config: Self::Config) -> Self::KeyValueTile; fn init_value(#[comptime] config: Self::Config) -> Self::KeyValueTile; + fn init_mask( + out_of_bounds: CubeOption, + #[comptime] partition_pos: Coords2d, + #[comptime] config: Self::Config, + ) -> Self::MaskTile; fn init_softmax(#[comptime] config: Self::Config) -> Self::SoftmaxTile; fn init_state(#[comptime] config: Self::Config) -> RunningState>; + fn fill_query(tile: &StridedTile, registers: &mut Self::QueryTile); + fn fill_key( tile: &StridedTile, - rhs: &mut Self::KeyValueTile, + registers: &mut Self::KeyValueTile, #[comptime] config: Self::Config, ); fn fill_value( tile: &StridedTile, - rhs: &mut Self::KeyValueTile, + registers: &mut Self::KeyValueTile, + #[comptime] config: Self::Config, + ); + + fn fill_mask( + tile: &StridedTile, + registers: &mut Self::MaskTile, #[comptime] config: Self::Config, ); @@ -103,7 +120,7 @@ pub trait TileAttention: 'static + Send + Sync { fn softmax( softmax: &mut Self::SoftmaxTile, - mask: TileMask, + mask: &Self::MaskTile, state: &mut RunningState>, max_placeholder: &mut RowWise>, sum_placeholder: &mut RowWise>, diff --git a/crates/cubecl-attention/src/components/tile/dummy/attention.rs b/crates/cubecl-attention/src/components/tile/dummy/attention.rs index 1a6e1da60..6f1f8ec45 100644 --- a/crates/cubecl-attention/src/components/tile/dummy/attention.rs +++ b/crates/cubecl-attention/src/components/tile/dummy/attention.rs @@ -3,20 +3,24 @@ use cubecl_core::prelude::*; use cubecl_matmul::components::tile::StridedTile; use std::marker::PhantomData; -use crate::components::TileMask; use 
crate::components::attention_types::*; use crate::components::tile::AccumulatorTile as _; use crate::components::tile::AccumulatorTileExpand; use crate::components::tile::SoftmaxTileExpand; use crate::components::tile::dummy::DummyAccumulator; +use crate::components::tile::dummy::MaskFragment; use crate::components::tile::dummy::attention_matmul::AttentionMatmulConfig; use crate::components::tile::dummy::{AttentionMatmul, DummySoftmax}; use crate::components::tile::tiles::{KeyValueTile, KeyValueTileExpand}; +use crate::components::tile::tiles::{MaskTile, MaskTileExpand}; +use crate::components::tile::tiles::{QueryTile, QueryTileExpand}; use crate::components::tile::{RowWise, RunningState, SoftmaxTile, TileAttention}; use crate::components::{ AttentionPrecision, tile::dummy::{KeyValueFragment, QueryFragment}, }; +use cubecl_std::CubeOption; +use cubecl_std::tensor::layout::Coords2d; pub struct DummyTileAttention> { _phantom: PhantomData<(AP, AM)>, @@ -32,6 +36,7 @@ impl> TileAttention type KeyValueTile = KeyValueFragment; type SoftmaxTile = DummySoftmax; type AccumulatorTile = DummyAccumulator; + type MaskTile = MaskFragment; fn rescale( acc: &mut Self::AccumulatorTile, @@ -53,8 +58,8 @@ impl> TileAttention Self::AccumulatorTile::new(config) } - fn init_query(tile: &StridedTile>, #[comptime] config: Self::Config) -> Self::QueryTile { - Self::QueryTile::new(tile, config) + fn init_query(#[comptime] config: Self::Config) -> Self::QueryTile { + Self::QueryTile::new(config) } fn init_key_value(#[comptime] config: Self::Config) -> Self::KeyValueTile { @@ -69,6 +74,14 @@ impl> TileAttention Self::KeyValueTile::new_value(config) } + fn init_mask( + out_of_bounds: CubeOption, + #[comptime] partition_pos: Coords2d, + #[comptime] config: Self::Config, + ) -> Self::MaskTile { + Self::MaskTile::new(out_of_bounds, partition_pos, config) + } + fn init_softmax(#[comptime] config: Self::Config) -> Self::SoftmaxTile { Self::SoftmaxTile::new(config) } @@ -77,20 +90,32 @@ impl> TileAttention RunningState::>::init(config.num_rows_per_unit()) } + fn fill_query(tile: &StridedTile, registers: &mut Self::QueryTile) { + AM::fill_query(tile, registers.fragment_mut()); + } + fn fill_key( tile: &StridedTile, - rhs: &mut Self::KeyValueTile, + registers: &mut Self::KeyValueTile, #[comptime] config: Self::Config, ) { - AM::fill_key_value(tile, rhs.key_mut(), config); + AM::fill_key_value(tile, registers.key_mut(), config); } fn fill_value( tile: &StridedTile, - rhs: &mut Self::KeyValueTile, + registers: &mut Self::KeyValueTile, #[comptime] config: Self::Config, ) { - AM::fill_key_value(tile, rhs.value_mut(), config); + AM::fill_key_value(tile, registers.value_mut(), config); + } + + fn fill_mask( + tile: &StridedTile, + mask: &mut Self::MaskTile, + #[comptime] config: Self::Config, + ) { + AM::fill_mask(tile, mask.fragment_mut(), config) } fn zero_softmax(score: &mut Self::SoftmaxTile, #[comptime] config: Self::Config) { @@ -113,14 +138,18 @@ impl> TileAttention fn softmax( softmax: &mut Self::SoftmaxTile, - mask: TileMask, + mask: &Self::MaskTile, state: &mut RunningState>, max_placeholder: &mut RowWise>, sum_placeholder: &mut RowWise>, #[comptime] dk: u32, #[comptime] config: Self::Config, ) -> RowWise> { - softmax.scale_and_mask(SM::::new(comptime!(1.0 / (dk as f32).sqrt())), mask); + Self::SoftmaxTile::scale_and_mask::( + softmax, + SM::::new(comptime!(1.0 / (dk as f32).sqrt())), + mask, + ); softmax.row_max::(max_placeholder, &state.m, config); diff --git 
a/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/accelerated/config.rs b/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/accelerated/config.rs index 95114f9da..9d9354242 100644 --- a/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/accelerated/config.rs +++ b/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/accelerated/config.rs @@ -53,6 +53,14 @@ impl AttentionMatmulConfig for AcceleratedAttentionMatmulConfig { fn num_rows_per_unit(&self) -> u32 { todo!() } + + fn causal_mask(&self) -> bool { + todo!() + } + + fn materialized_mask(&self) -> bool { + todo!() + } } impl AcceleratedAttentionMatmulConfig { diff --git a/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/accelerated/matmul.rs b/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/accelerated/matmul.rs index 210f8794c..5db0b2992 100644 --- a/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/accelerated/matmul.rs +++ b/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/accelerated/matmul.rs @@ -3,29 +3,35 @@ use cubecl_core::{cmma, prelude::*}; use cubecl_matmul::components::tile::StridedTile; use crate::components::AttentionPrecision; -use crate::components::TileMask; use crate::components::attention_types::*; +use crate::components::tile::MaskTile; use crate::components::tile::RowWise; use crate::components::tile::dummy::accelerated::AcceleratedAttentionMatmulConfig; use crate::components::tile::dummy::{AttentionMatmul, AttentionMatmulConfig as _}; -use crate::components::tile::{PlaneLayout, PlaneLayoutExpand}; +use crate::components::tile::{FragmentLayout, FragmentLayoutExpand}; +use crate::components::tile::{FragmentMask, FragmentMaskExpand}; +use crate::components::tile::{FragmentOps, FragmentOpsExpand}; +use cubecl_std::tensor::layout::Coords2d; /// Performs two matmuls with fragment reuse for key/value and score/prob pub struct AcceleratedAttentionMatmul; -#[cube] -impl PlaneLayout for cmma::Matrix { - fn num_local_rows(&self) -> comptime_type!(u32) { - todo!() - } +#[derive(CubeType)] +pub struct TODO; - fn num_local_cols(&self) -> comptime_type!(u32) { +#[cube] +impl FragmentLayout for TODO { + fn absolute_pos(&self, _local_pos: Coords2d) -> Coords2d { todo!() } - fn num_units_per_row(&self) -> comptime_type!(u32) { todo!() } +} + +#[cube] +impl FragmentOps for cmma::Matrix { + type Layout = TODO; fn rowwise_max(&self) -> RowWise { todo!() @@ -39,13 +45,24 @@ impl PlaneLayout for cmma::Matrix { todo!() } - fn scale_and_mask(&mut self, _scale: E, _mask: TileMask) { + fn scale_and_mask(_this: &mut Self, _scale: E, _mask: &M) { todo!() } fn exp_m_diff(&mut self, _val: &RowWise) { todo!() } + + fn layout(&self) -> Self::Layout { + todo!() + } +} + +#[cube] +impl FragmentMask for cmma::Matrix { + fn should_mask(&self, _local_pos: Coords2d) -> bool { + todo!() + } } #[cube] @@ -53,8 +70,10 @@ impl AttentionMatmul for AcceleratedAttentionMatmul type Config = AcceleratedAttentionMatmulConfig; type Query = cmma::Matrix>; type KeyValue = cmma::Matrix>; + type Mask = cmma::Matrix>; type Softmax = cmma::Matrix>; type Accumulator = cmma::Matrix>; + type FragmentLayout = TODO; fn score_matmul( lhs: &Self::Query, @@ -74,42 +93,26 @@ impl AttentionMatmul for AcceleratedAttentionMatmul cmma::execute::, KVT, ACC, ACC>(lhs, rhs, out, out); } - fn allocate_fill_query( - tile: &StridedTile, - #[comptime] config: Self::Config, - ) -> Self::Query { - let (slice, stride) = tile.as_unlined(); + fn 
allocate_query(#[comptime] config: Self::Config) -> Self::Query { let size = config.attention_tile_size().to_score_matmul_tile_size(); - if config.cast_query() { - let query = unsafe { - cmma::Matrix::>::uninitialized( - cmma::MatrixIdent::A, - size.m(), - size.n(), - size.k(), - cmma::MatrixLayout::RowMajor, - ) - }; - - cmma::load(&query, &slice, stride); - query - } else { - let tmp = unsafe { - cmma::Matrix::::uninitialized( - cmma::MatrixIdent::A, - size.m(), - size.n(), - size.k(), - cmma::MatrixLayout::RowMajor, - ) - }; - - cmma::load(&tmp, &slice, stride); - cmma::cast::>(&tmp) + unsafe { + cmma::Matrix::>::uninitialized( + cmma::MatrixIdent::A, + size.m(), + size.n(), + size.k(), + cmma::MatrixLayout::RowMajor, + ) } } + fn fill_query(tile: &StridedTile, fragment: &mut Self::Query) { + let (slice, stride) = tile.as_unlined(); + + cmma::load(fragment, &slice, stride); + } + fn allocate_key(#[comptime] config: Self::Config) -> Self::KeyValue { let size = config.attention_tile_size(); unsafe { @@ -153,6 +156,19 @@ impl AttentionMatmul for AcceleratedAttentionMatmul } } + fn allocate_mask(#[comptime] config: Self::Config) -> Self::Mask { + let size = config.attention_tile_size(); + unsafe { + cmma::Matrix::>::uninitialized( + cmma::MatrixIdent::Accumulator, + size.seq_q, + size.seq_kv, + size.head_dim, + cmma::MatrixLayout::RowMajor, + ) + } + } + fn fill_key_value( tile: &StridedTile, rhs: &mut Self::KeyValue, @@ -162,6 +178,14 @@ impl AttentionMatmul for AcceleratedAttentionMatmul cmma::load(rhs, &slice, stride); } + fn fill_mask( + _tile: &StridedTile, + _mask: &mut Self::Mask, + #[comptime] _config: Self::Config, + ) { + todo!() + } + fn allocate_softmax(#[comptime] config: Self::Config) -> Self::Softmax { let size = config.attention_tile_size(); unsafe { @@ -209,4 +233,8 @@ impl AttentionMatmul for AcceleratedAttentionMatmul cmma::MatrixLayout::RowMajor, ); } + + fn softmax_layout(#[comptime] _config: Self::Config) -> Self::FragmentLayout { + todo!() + } } diff --git a/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/base.rs b/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/base.rs index 90a775e48..f5bf76a31 100644 --- a/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/base.rs +++ b/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/base.rs @@ -4,7 +4,9 @@ use cubecl_matmul::components::ComputeResources; use cubecl_matmul::components::tile::StridedTile; use crate::components::attention_types::*; -use crate::components::tile::PlaneLayout; +use crate::components::tile::FragmentLayout; +use crate::components::tile::FragmentMask; +use crate::components::tile::FragmentOps; use crate::components::{ AttentionIdent, AttentionLineSizes, AttentionPrecision, AttentionProblem, AttentionSelection, AttentionSetupError, AttentionTileSize, AvailableLineSizes, InvalidConfigError, @@ -17,8 +19,10 @@ pub trait AttentionMatmul: Send + Sync + 'static { type Config: AttentionMatmulConfig; type Query: CubeType; type KeyValue: CubeType; - type Softmax: PlaneLayout>; - type Accumulator: PlaneLayout>; + type Mask: FragmentMask; + type Softmax: FragmentOps, Layout = Self::FragmentLayout>; + type Accumulator: FragmentOps, Layout = Self::FragmentLayout>; + type FragmentLayout: FragmentLayout; fn score_matmul( lhs: &Self::Query, @@ -34,14 +38,10 @@ pub trait AttentionMatmul: Send + Sync + 'static { #[comptime] config: Self::Config, ); - fn allocate_fill_query( - tile: &StridedTile, - #[comptime] config: Self::Config, - ) -> Self::Query; - 
fn allocate_key(#[comptime] config: Self::Config) -> Self::KeyValue; fn allocate_value(#[comptime] config: Self::Config) -> Self::KeyValue; fn allocate_key_value(#[comptime] config: Self::Config) -> Self::KeyValue; + fn allocate_mask(#[comptime] config: Self::Config) -> Self::Mask; fn fill_key_value( tile: &StridedTile, @@ -49,6 +49,15 @@ pub trait AttentionMatmul: Send + Sync + 'static { #[comptime] config: Self::Config, ); + fn fill_mask( + tile: &StridedTile, + fragment: &mut Self::Mask, + #[comptime] config: Self::Config, + ); + + fn allocate_query(#[comptime] config: Self::Config) -> Self::Query; + fn fill_query(tile: &StridedTile, fragment: &mut Self::Query); + fn allocate_softmax(#[comptime] config: Self::Config) -> Self::Softmax; fn zero_softmax(softmax: &mut Self::Softmax, #[comptime] config: Self::Config); @@ -60,6 +69,8 @@ pub trait AttentionMatmul: Send + Sync + 'static { slice: &mut SliceMut>, #[comptime] config: Self::Config, ); + + fn softmax_layout(#[comptime] config: Self::Config) -> Self::FragmentLayout; } /// Configuration for the Tile Attention level @@ -78,6 +89,9 @@ pub trait AttentionMatmulConfig: fn check_bounds(&self) -> bool; fn num_rows_per_unit(&self) -> u32; + + fn causal_mask(&self) -> bool; + fn materialized_mask(&self) -> bool; } pub trait AttentionMatmulFamily: Send + Sync + 'static { diff --git a/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/dummy_register/config.rs b/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/dummy_register/config.rs index 2e35dacbd..bcc5f4027 100644 --- a/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/dummy_register/config.rs +++ b/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/dummy_register/config.rs @@ -19,6 +19,8 @@ pub struct DummyRegisterAttentionMatmulConfig { cast_query: bool, check_bounds: bool, inner_layout: InnerLayout, + causal_mask: bool, + materialized_mask: bool, } impl AttentionMatmulConfig for DummyRegisterAttentionMatmulConfig { @@ -59,9 +61,18 @@ impl AttentionMatmulConfig for DummyRegisterAttentionMatmulConfig { InnerLayout::SplitRows => 2u32, } } + + fn causal_mask(&self) -> bool { + self.causal_mask + } + + fn materialized_mask(&self) -> bool { + self.materialized_mask + } } impl DummyRegisterAttentionMatmulConfig { + #[allow(clippy::too_many_arguments)] pub fn new( plane_dim: u32, attention_tile_size: AttentionTileSize, @@ -70,6 +81,8 @@ impl DummyRegisterAttentionMatmulConfig { key_value_stage_line_size: u32, check_bounds: bool, two_rows_in_array_tile: bool, + causal_mask: bool, + materialized_mask: bool, ) -> Result { Self { plane_dim, @@ -85,6 +98,8 @@ impl DummyRegisterAttentionMatmulConfig { } else { InnerLayout::Contiguous }, + causal_mask, + materialized_mask, } .validate() } diff --git a/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/dummy_register/matmul.rs b/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/dummy_register/matmul.rs index 0a8a6f494..25cdd07d0 100644 --- a/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/dummy_register/matmul.rs +++ b/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/dummy_register/matmul.rs @@ -7,12 +7,14 @@ use cubecl_std::tensor::layout::Coords2d; use crate::components::AttentionPrecision; use crate::components::attention_types::*; +use crate::components::tile::MaskTile; +use crate::components::tile::{FragmentMask, FragmentMaskExpand}; use crate::components::tile::{RowVal, RowWise}; -use crate::components::TileMask; 
use crate::components::tile::dummy::dummy_register::DummyRegisterAttentionMatmulConfig; use crate::components::tile::dummy::{AttentionMatmul, AttentionMatmulConfig as _}; -use crate::components::tile::{PlaneLayout, PlaneLayoutExpand}; +use crate::components::tile::{FragmentLayout, FragmentLayoutExpand}; +use crate::components::tile::{FragmentOps, FragmentOpsExpand}; pub struct DummyRegisterAttentionMatmul; @@ -21,16 +23,9 @@ pub struct DummyRegisterAttentionMatmul; /// - All elements of a unit are contiguous /// - unit_size * plane_dim = total_size (not dim wise but in total count) /// - There is never more than one row for one unit -pub struct ArrayTile { +pub struct ArrayTile { array: Array, - #[cube(comptime)] - total_size: Coords2d, - #[cube(comptime)] - unit_size: Coords2d, - #[cube(comptime)] - num_units_per_row: u32, - #[cube(comptime)] - plane_dim: u32, + layout: ArrayTileLayout, } #[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)] @@ -62,12 +57,38 @@ pub enum InnerLayout { } #[cube] -impl ArrayTile { +impl ArrayTile { + pub fn new(layout: ArrayTileLayout) -> ArrayTile { + let array = Array::::new(comptime!(layout.unit_size.0 * layout.unit_size.1)); + ArrayTile:: { array, layout } + } + + pub fn zero(&mut self) { + for i in 0..self.layout.unit_size.0 * self.layout.unit_size.1 { + self.array[i] = E::from_int(0); + } + } +} + +#[derive(CubeType, Copy, Clone)] +pub struct ArrayTileLayout { + #[cube(comptime)] + total_size: Coords2d, + #[cube(comptime)] + unit_size: Coords2d, + #[cube(comptime)] + num_units_per_row: u32, + #[cube(comptime)] + plane_dim: u32, +} + +#[cube] +impl ArrayTileLayout { pub fn new( #[comptime] total_size: Coords2d, #[comptime] plane_dim: u32, #[comptime] inner_layout: InnerLayout, - ) -> ArrayTile { + ) -> ArrayTileLayout { let total_elements = total_size.0 * total_size.1; let elements_per_unit = total_elements.div_ceil(plane_dim); @@ -77,67 +98,51 @@ impl ArrayTile { }; let unit_size = (num_rows_per_unit, num_cols_per_unit); - let array = Array::::new(comptime!(unit_size.0 * unit_size.1)); let num_units_per_row = comptime!(total_size.1 / unit_size.1); - ArrayTile:: { - array, + ArrayTileLayout { total_size, unit_size, num_units_per_row, plane_dim, } } +} - pub fn zero(&mut self) { - for i in 0..self.unit_size.0 * self.unit_size.1 { - self.array[i] = E::from_int(0); - } - } +#[cube] +impl FragmentLayout for ArrayTileLayout { + fn absolute_pos(&self, local_pos: Coords2d) -> Coords2d { + let abs_row_index = { + let row_0 = UNIT_POS_X / self.num_units_per_row; + let row_jump = comptime!(self.plane_dim / self.num_units_per_row); - fn abs_row_index(&self, r: u32) -> u32 { - let row_0 = UNIT_POS_X / self.num_units_per_row; - let row_jump = comptime!(self.plane_dim / self.num_units_per_row); + local_pos.0 * row_jump + row_0 + }; - r * row_jump + row_0 - } + let abs_col_index = self.unit_size.1 * (UNIT_POS_X % self.num_units_per_row) + local_pos.1; - fn abs_col_index(&self, c: u32) -> u32 { - self.unit_size.1 * (UNIT_POS_X % self.num_units_per_row) + c + (abs_row_index, abs_col_index) } - fn abs_pos(&self, #[comptime] local_pos: Coords2d) -> Coords2d { - ( - self.abs_row_index(local_pos.0), - self.abs_col_index(local_pos.1), - ) + fn num_units_per_row(&self) -> comptime_type!(u32) { + comptime!(self.total_size.1 / self.unit_size.1) } } #[cube] -impl PlaneLayout for ArrayTile { - fn num_local_rows(&self) -> comptime_type!(u32) { - self.unit_size.0 - } - - fn num_local_cols(&self) -> comptime_type!(u32) { - self.unit_size.1 - } - - fn num_units_per_row(&self) 
-> comptime_type!(u32) { - comptime!(self.total_size.1 / self.unit_size.1) - } +impl FragmentOps for ArrayTile { + type Layout = ArrayTileLayout; fn rowwise_max(&self) -> RowWise { let mut vals = Sequence::new(); #[unroll] - for r in 0..self.unit_size.0 { - let row_offset = r * self.unit_size.1; + for r in 0..self.layout.unit_size.0 { + let row_offset = r * self.layout.unit_size.1; let mut val = E::min_value(); #[unroll] - for c in 0..self.unit_size.1 { + for c in 0..self.layout.unit_size.1 { let index = row_offset + c; val = Max::max(val, self.array[index]); } @@ -146,7 +151,7 @@ impl PlaneLayout for ArrayTile { } RowWise:: { - num_rows: self.unit_size.0, + num_rows: self.layout.unit_size.0, vals, } } @@ -155,12 +160,12 @@ impl PlaneLayout for ArrayTile { let mut vals = Sequence::new(); #[unroll] - for r in 0..self.unit_size.0 { - let row_offset = r * self.unit_size.1; + for r in 0..self.layout.unit_size.0 { + let row_offset = r * self.layout.unit_size.1; let mut val = E::from_int(0); #[unroll] - for c in 0..self.unit_size.1 { + for c in 0..self.layout.unit_size.1 { let index = row_offset + c; val += self.array[index]; } @@ -169,55 +174,66 @@ impl PlaneLayout for ArrayTile { } RowWise:: { - num_rows: self.unit_size.0, + num_rows: self.layout.unit_size.0, vals, } } fn scale(&mut self, scale: &RowWise) { #[unroll] - for r in 0..self.unit_size.0 { - let row_offset = r * self.unit_size.1; + for r in 0..self.layout.unit_size.0 { + let row_offset = r * self.layout.unit_size.1; #[unroll] - for c in 0..self.unit_size.1 { + for c in 0..self.layout.unit_size.1 { let index = row_offset + c; self.array[index] = self.array[index] * scale.index(r); } } } - fn scale_and_mask(&mut self, scale: E, mask: TileMask) { + fn scale_and_mask(this: &mut Self, scale: E, mask: &M) { #[unroll] - for r in 0..self.unit_size.0 { - let row_offset = r * self.unit_size.1; + for r in 0..this.layout.unit_size.0 { + let row_offset = r * this.layout.unit_size.1; #[unroll] - for c in 0..self.unit_size.1 { + for c in 0..this.layout.unit_size.1 { let index = row_offset + c; - self.array[index] = - self.array[index] * scale + mask.apply::(self.abs_pos((r, c))); + this.array[index] = + this.array[index] * scale + M::apply::(mask, (r, c).runtime()); } } } fn exp_m_diff(&mut self, val: &RowWise) { #[unroll] - for r in 0..self.unit_size.0 { - let row_offset = r * self.unit_size.1; + for r in 0..self.layout.unit_size.0 { + let row_offset = r * self.layout.unit_size.1; #[unroll] - for c in 0..self.unit_size.1 { + for c in 0..self.layout.unit_size.1 { let index = row_offset + c; self.array[index] = Exp::exp(self.array[index] - val.index(r)); } } } + + fn layout(&self) -> Self::Layout { + self.layout + } } #[cube] -fn array_tile_to_tmp_smem( +impl FragmentMask for ArrayTile { + fn should_mask(&self, local_pos: Coords2d) -> bool { + bool::cast_from(self.array[local_pos.0 * self.layout.unit_size.1 + local_pos.1]) + } +} + +#[cube] +fn array_tile_to_tmp_smem( array_tile: &ArrayTile, #[comptime] num_planes: u32, ) -> SliceMut { - let tile_size = comptime!(array_tile.total_size.0 * array_tile.total_size.1); + let tile_size = comptime!(array_tile.layout.total_size.0 * array_tile.layout.total_size.1); let mut tmp_smem = SharedMemory::::new(comptime!(num_planes * tile_size)); let start = UNIT_POS_Y * tile_size; @@ -231,11 +247,11 @@ fn array_tile_to_tmp_smem( } sync_cube(); - for r in 0..array_tile.unit_size.0 { - for c in 0..array_tile.unit_size.1 { - let index = - array_tile.abs_row_index(r) * array_tile.total_size.1 + 
array_tile.abs_col_index(c); - tmp_smem_slice[index] = array_tile.array[r * array_tile.unit_size.1 + c]; + for r in 0..array_tile.layout.unit_size.0 { + for c in 0..array_tile.layout.unit_size.1 { + let (row, col) = array_tile.layout.absolute_pos((r, c)); + let index = row * array_tile.layout.total_size.1 + col; + tmp_smem_slice[index] = array_tile.array[r * array_tile.layout.unit_size.1 + c]; } } @@ -243,40 +259,40 @@ fn array_tile_to_tmp_smem( } #[cube] -fn tmp_smem_to_array_tile(tmp_smem_slice: &SliceMut, array_tile: &mut ArrayTile) { - for r in 0..array_tile.unit_size.0 { - for c in 0..array_tile.unit_size.1 { - array_tile.array[r * array_tile.unit_size.1 + c] = - tmp_smem_slice[array_tile.abs_row_index(r) * array_tile.total_size.1 - + array_tile.abs_col_index(c)]; +fn tmp_smem_to_array_tile(tmp_smem_slice: &SliceMut, array_tile: &mut ArrayTile) { + for r in 0..array_tile.layout.unit_size.0 { + for c in 0..array_tile.layout.unit_size.1 { + let (row, col) = array_tile.layout.absolute_pos((r, c)); + let index = row * array_tile.layout.total_size.1 + col; + array_tile.array[r * array_tile.layout.unit_size.1 + c] = tmp_smem_slice[index]; } } } #[cube] -fn strided_tile_to_array_tile( +fn strided_tile_to_array_tile( strided_tile: &StridedTile, array_tile: &mut ArrayTile, ) { - for r in 0..array_tile.unit_size.0 { - for c in 0..array_tile.unit_size.1 { - array_tile.array[r * array_tile.unit_size.1 + c] = E2::cast_from( - strided_tile.get_line(array_tile.abs_row_index(r), array_tile.abs_col_index(c)), - ) + for r in 0..array_tile.layout.unit_size.0 { + for c in 0..array_tile.layout.unit_size.1 { + let (row, col) = array_tile.layout.absolute_pos((r, c)); + array_tile.array[r * array_tile.layout.unit_size.1 + c] = + E2::cast_from(strided_tile.get_line(row, col)) } } } #[cube] -fn array_tile_to_slice( +fn array_tile_to_slice( array_tile: &ArrayTile, slice: &mut SliceMut>, ) { - for r in 0..array_tile.unit_size.0 { - for c in 0..array_tile.unit_size.1 { - let index = - array_tile.abs_row_index(r) * array_tile.total_size.1 + array_tile.abs_col_index(c); - slice[index] = Line::cast_from(array_tile.array[r * array_tile.unit_size.1 + c]); + for r in 0..array_tile.layout.unit_size.0 { + for c in 0..array_tile.layout.unit_size.1 { + let (row, col) = array_tile.layout.absolute_pos((r, c)); + let index = row * array_tile.layout.total_size.1 + col; + slice[index] = Line::cast_from(array_tile.array[r * array_tile.layout.unit_size.1 + c]); } } } @@ -287,8 +303,21 @@ impl AttentionMatmul for DummyRegisterAttentionMatmu type Query = ArrayTile>; type KeyValue = ArrayTile>; + type Mask = ArrayTile>; type Softmax = ArrayTile>; type Accumulator = ArrayTile>; + type FragmentLayout = ArrayTileLayout; + + fn softmax_layout(#[comptime] config: Self::Config) -> ArrayTileLayout { + ArrayTileLayout::new( + ( + config.attention_tile_size().seq_q, + config.attention_tile_size().seq_kv, + ), + config.plane_dim(), + config.inner_layout(), + ) + } fn score_matmul( lhs: &Self::Query, @@ -356,7 +385,7 @@ impl AttentionMatmul for DummyRegisterAttentionMatmu } fn allocate_key_value(#[comptime] config: Self::Config) -> Self::KeyValue { - ArrayTile::new( + ArrayTile::new(ArrayTileLayout::new( ( comptime!(max( config.attention_tile_size().head_dim, @@ -369,67 +398,65 @@ impl AttentionMatmul for DummyRegisterAttentionMatmu ), config.plane_dim(), config.inner_layout(), - ) + )) } fn allocate_key(#[comptime] config: Self::Config) -> Self::KeyValue { - ArrayTile::new( + ArrayTile::new(ArrayTileLayout::new( ( 
config.attention_tile_size().head_dim, config.attention_tile_size().seq_kv, ), config.plane_dim(), config.inner_layout(), - ) + )) } fn allocate_value(#[comptime] config: Self::Config) -> Self::KeyValue { - ArrayTile::new( + ArrayTile::new(ArrayTileLayout::new( ( config.attention_tile_size().seq_kv, config.attention_tile_size().val_dim, ), config.plane_dim(), config.inner_layout(), - ) + )) + } + + fn allocate_mask(#[comptime] config: Self::Config) -> Self::Mask { + ArrayTile::new(>::softmax_layout(config)) } fn allocate_softmax(#[comptime] config: Self::Config) -> Self::Softmax { - ArrayTile::new( - ( - config.attention_tile_size().seq_q, - config.attention_tile_size().seq_kv, - ), - config.plane_dim(), - config.inner_layout(), - ) + ArrayTile::new(>::softmax_layout(config)) } fn allocate_accumulator(#[comptime] config: Self::Config) -> Self::Accumulator { - ArrayTile::new( + ArrayTile::new(ArrayTileLayout::new( ( config.attention_tile_size().seq_q, config.attention_tile_size().val_dim, ), config.plane_dim(), config.inner_layout(), - ) + )) } - fn allocate_fill_query( - tile: &StridedTile, - #[comptime] config: Self::Config, - ) -> Self::Query { + fn allocate_query(#[comptime] config: Self::Config) -> Self::Query { let seq_q = config.attention_tile_size().seq_q; let head_dim = config.attention_tile_size().head_dim; - let mut query = - ArrayTile::new((seq_q, head_dim), config.plane_dim(), config.inner_layout()); + ArrayTile::new(ArrayTileLayout::new( + (seq_q, head_dim), + config.plane_dim(), + config.inner_layout(), + )) + } - strided_tile_to_array_tile(tile, &mut query); + fn fill_query(tile: &StridedTile, fragment: &mut Self::Query) { + strided_tile_to_array_tile(tile, fragment); sync_cube(); - query } fn fill_key_value( @@ -442,6 +469,16 @@ impl AttentionMatmul for DummyRegisterAttentionMatmu sync_cube(); } + fn fill_mask( + tile: &StridedTile, + mask: &mut Self::Mask, + #[comptime] _config: Self::Config, + ) { + strided_tile_to_array_tile(tile, mask); + + sync_cube(); + } + fn zero_softmax(softmax: &mut Self::Softmax, #[comptime] _config: Self::Config) { softmax.zero(); sync_cube(); diff --git a/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/dummy_register/setup.rs b/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/dummy_register/setup.rs index 55eaac89f..807f37af6 100644 --- a/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/dummy_register/setup.rs +++ b/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/dummy_register/setup.rs @@ -38,6 +38,8 @@ impl AttentionMatmulFamily for DummyRegisterAttentionMatmul { line_sizes.key as u32, !(problem.seq_kv as u32).is_multiple_of(selection.tiling_scheme.tile_size.seq_kv), selection.two_rows_in_array_tile, + problem.causal, + problem.masked, ) } } diff --git a/crates/cubecl-attention/src/components/tile/dummy/fragment/accumulator.rs b/crates/cubecl-attention/src/components/tile/dummy/fragment/accumulator.rs index 0ea14a0c2..960c56e1d 100644 --- a/crates/cubecl-attention/src/components/tile/dummy/fragment/accumulator.rs +++ b/crates/cubecl-attention/src/components/tile/dummy/fragment/accumulator.rs @@ -7,7 +7,7 @@ use crate::components::tile::AccumulatorTile; use crate::components::tile::AccumulatorTileExpand; use crate::components::tile::RowWise; use crate::components::tile::dummy::AttentionMatmul; -use crate::components::tile::row::{PlaneLayout, PlaneLayoutExpand}; +use crate::components::tile::row::{FragmentOps, FragmentOpsExpand}; #[derive(CubeType)] pub struct 
DummyAccumulator> { diff --git a/crates/cubecl-attention/src/components/tile/dummy/fragment/mask.rs b/crates/cubecl-attention/src/components/tile/dummy/fragment/mask.rs new file mode 100644 index 000000000..7ac92b569 --- /dev/null +++ b/crates/cubecl-attention/src/components/tile/dummy/fragment/mask.rs @@ -0,0 +1,172 @@ +use cubecl_core as cubecl; +use cubecl_core::prelude::*; +use cubecl_std::tensor::layout::Coords2d; +use cubecl_std::{CubeOption, CubeOptionExpand}; + +use crate::components::AttentionPrecision; +use crate::components::attention_types::MSK; +use crate::components::tile::dummy::AttentionMatmul; +use crate::components::tile::dummy::attention_matmul::AttentionMatmulConfig; +use crate::components::tile::row::{FragmentMask, FragmentMaskExpand}; +use crate::components::tile::{FragmentLayout, FragmentLayoutExpand, MaskTile, MaskTileExpand}; +use cubecl_matmul::components::tile::StridedTile; + +use cubecl_std::tensor::layout::Coordinates; + +#[derive(CubeType)] +pub struct LogicalIterOrigin { + row: RuntimeCell, + col: RuntimeCell, +} + +#[cube] +impl LogicalIterOrigin { + fn dummy() -> LogicalIterOrigin { + LogicalIterOrigin { + row: RuntimeCell::new(0), + col: RuntimeCell::new(0), + } + } + + fn read(&self) -> Coords2d { + (self.row.read(), self.col.read()) + } + + fn update(&mut self, new: Coords2d) { + self.row.store(new.0); + self.col.store(new.1); + } +} + +#[derive(CubeType)] +pub struct LogicalTileMask { + logical_iter_origin: LogicalIterOrigin, + #[cube(comptime)] + partition_pos: Coords2d, + #[cube(comptime)] + causal: bool, + out_of_bounds: CubeOption, + fragment_layout: F, +} + +#[cube] +impl LogicalTileMask { + pub fn should_mask(&self, local_pos: Coords2d) -> bool { + let pos_in_tile = self.fragment_layout.absolute_pos(local_pos); + + let pos = Coords2d::add( + self.logical_iter_origin.read(), + Coords2d::add(self.partition_pos.runtime(), pos_in_tile), + ); + + let causal_masked = self.causal && pos.0 < pos.1; + + let oob_masked = match self.out_of_bounds { + CubeOption::Some(bounds) => !Coords2d::is_in_bounds(&pos, &bounds), + CubeOption::None => false, + }; + + causal_masked || oob_masked + } + + pub fn update_origin(&mut self, new_origin: Coords2d) { + self.logical_iter_origin.update(new_origin); + } +} + +#[derive(CubeType)] +pub struct MaterializedTileMask> { + fragment: AM::Mask, + logical_mask: LogicalTileMask, + #[cube(comptime)] + config: AM::Config, +} + +#[cube] +impl> MaterializedTileMask { + pub fn should_mask(&self, local_pos: Coords2d) -> bool { + let logical_masked = self.logical_mask.should_mask(local_pos); + let materialized_masked = self.fragment.should_mask(local_pos); + + logical_masked || materialized_masked + } + + pub fn update_tile(&mut self, tile: StridedTile>) { + AM::fill_mask(&tile, &mut self.fragment, self.config); + } +} + +#[derive(CubeType)] +pub enum MaskFragment> { + Materialized(MaterializedTileMask), + Logical(LogicalTileMask), +} + +#[cube] +impl> MaskFragment { + pub fn new( + out_of_bounds: CubeOption, + #[comptime] partition_pos: Coords2d, + #[comptime] config: AM::Config, + ) -> MaskFragment { + let logical_mask = LogicalTileMask:: { + logical_iter_origin: LogicalIterOrigin::dummy(), + partition_pos, + causal: config.causal_mask(), + out_of_bounds, + fragment_layout: AM::softmax_layout(config), + }; + + if config.materialized_mask() { + MaskFragment::new_Materialized(MaterializedTileMask:: { + fragment: AM::allocate_mask(config), + logical_mask, + config, + }) + } else { + MaskFragment::new_Logical(logical_mask) + } + } +} + 
+#[cube] +impl> MaskTile for MaskFragment { + type Fragment = AM::Mask; + type MaskPrecision = MSK; + + fn apply(this: &Self, local_pos: Coords2d) -> E { + let should_mask = match this { + MaskFragment::Materialized(materialized_tile_mask) => { + materialized_tile_mask.should_mask(local_pos) + } + MaskFragment::Logical(logical_tile_mask) => logical_tile_mask.should_mask(local_pos), + }; + + E::cast_from(should_mask) * E::min_value() + } + + fn fragment_mut(&mut self) -> &mut Self::Fragment { + match self { + MaskFragment::Materialized(materialized_tile_mask) => { + &mut materialized_tile_mask.fragment + } + MaskFragment::Logical(_) => { + panic!("Tried to get fragment of logical mask") + } + } + } + + fn update(&mut self, new_origin: Coords2d, tile: CubeOption>) { + match self { + MaskFragment::Materialized(materialized_tile_mask) => { + // Keep the logical origin in sync, then fill the fragment from the new tile + materialized_tile_mask + .logical_mask + .update_origin(new_origin); + + materialized_tile_mask.update_tile(tile.unwrap()) + } + MaskFragment::Logical(logical_tile_mask) => logical_tile_mask.update_origin(new_origin), + } + } +} diff --git a/crates/cubecl-attention/src/components/tile/dummy/fragment/mod.rs b/crates/cubecl-attention/src/components/tile/dummy/fragment/mod.rs index 3a0d94be5..d4d205494 100644 --- a/crates/cubecl-attention/src/components/tile/dummy/fragment/mod.rs +++ b/crates/cubecl-attention/src/components/tile/dummy/fragment/mod.rs @@ -1,9 +1,11 @@ mod accumulator; mod key_value; +mod mask; mod query; mod softmax; pub use accumulator::*; pub use key_value::*; +pub use mask::*; pub use query::*; pub use softmax::*; diff --git a/crates/cubecl-attention/src/components/tile/dummy/fragment/query.rs b/crates/cubecl-attention/src/components/tile/dummy/fragment/query.rs index 50f96b2f8..574b0f8e1 100644 --- a/crates/cubecl-attention/src/components/tile/dummy/fragment/query.rs +++ b/crates/cubecl-attention/src/components/tile/dummy/fragment/query.rs @@ -1,11 +1,11 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; -use cubecl_matmul::components::tile::StridedTile; use crate::components::AttentionPrecision; -use crate::components::attention_types::QT; -use crate::components::tile::QueryTile; +use crate::components::attention_types::*; use crate::components::tile::dummy::AttentionMatmul; +use crate::components::tile::{QueryTile, QueryTileExpand}; +use cubecl_matmul::components::tile::StridedTile; #[derive(CubeType)] pub struct QueryFragment> { @@ -14,14 +14,22 @@ pub struct QueryFragment> { #[cube] impl> QueryFragment { - pub fn new( - tile: &StridedTile, - #[comptime] config: AM::Config, - ) -> QueryFragment { + pub fn new(#[comptime] config: AM::Config) -> QueryFragment { QueryFragment:: { - fragment: AM::allocate_fill_query(tile, config), + fragment: AM::allocate_query(config), } } } -impl> QueryTile> for QueryFragment {} +#[cube] +impl> QueryTile for QueryFragment { + type Fragment = AM::Query; + + fn fragment_mut(&mut self) -> &mut Self::Fragment { + &mut self.fragment + } + + fn update(&mut self, tile: StridedTile>) { + AM::fill_query(&tile, &mut self.fragment) + } +} diff --git a/crates/cubecl-attention/src/components/tile/dummy/fragment/softmax.rs b/crates/cubecl-attention/src/components/tile/dummy/fragment/softmax.rs index 6e2cfd020..fcada3e70 100644 --- a/crates/cubecl-attention/src/components/tile/dummy/fragment/softmax.rs +++ b/crates/cubecl-attention/src/components/tile/dummy/fragment/softmax.rs @@ -4,14 +4,14 @@ use cubecl_core::prelude::*; use crate::components::AttentionPrecision; use
crate::components::attention_types::*; use crate::components::tile::BroadcastReducer; +use crate::components::tile::MaskTile; use crate::components::tile::RowWise; use crate::components::tile::dummy::AttentionMatmulConfig; -use crate::components::tile::{PlaneLayout, PlaneLayoutExpand}; -use crate::components::tile::{row_max, row_sum}; -use crate::components::{ - TileMask, - tile::{RunningState, SoftmaxTile, SoftmaxTileExpand, dummy::AttentionMatmul}, +use crate::components::tile::{FragmentOps, FragmentOpsExpand}; +use crate::components::tile::{ + RunningState, SoftmaxTile, SoftmaxTileExpand, dummy::AttentionMatmul, }; +use crate::components::tile::{row_max, row_sum}; #[derive(CubeType)] pub struct DummySoftmax> { @@ -33,7 +33,7 @@ impl> DummySoftmax { #[cube] impl> SoftmaxTile for DummySoftmax { - type PlaneLayout = AM::Softmax; + type FragmentOps = AM::Softmax; fn init_state(#[comptime] num_rows: u32) -> RunningState> { RunningState::>::init(num_rows) @@ -47,8 +47,8 @@ impl> SoftmaxTile for DummyS AM::zero_softmax(&mut self.fragment, self.config); } - fn scale_and_mask(&mut self, scale: SM, mask: TileMask) { - self.fragment.scale_and_mask(scale, mask); + fn scale_and_mask(this: &mut Self, scale: SM, mask: &M) { + Self::FragmentOps::scale_and_mask::(&mut this.fragment, scale, mask); } fn row_max( @@ -57,7 +57,7 @@ impl> SoftmaxTile for DummyS base: &RowWise>, #[comptime] config: TC, ) { - row_max::, Self::PlaneLayout, BroadcastReducer, TC>( + row_max::, Self::FragmentOps, BroadcastReducer, TC>( placeholder, base, &self.fragment, @@ -74,7 +74,7 @@ impl> SoftmaxTile for DummyS ) -> RowWise> { self.fragment.exp_m_diff(new_m); - row_sum::, Self::PlaneLayout, BroadcastReducer, TC>( + row_sum::, Self::FragmentOps, BroadcastReducer, TC>( rowsum_placeholder, &self.fragment, config, diff --git a/crates/cubecl-attention/src/components/tile/row/base.rs b/crates/cubecl-attention/src/components/tile/row/base.rs index 0f3f263e0..07112bc80 100644 --- a/crates/cubecl-attention/src/components/tile/row/base.rs +++ b/crates/cubecl-attention/src/components/tile/row/base.rs @@ -1,19 +1,31 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; -use crate::components::TileMask; +use crate::components::tile::MaskTile; use crate::components::tile::RowWise; +use cubecl_std::tensor::layout::Coords2d; #[cube] -pub trait PlaneLayout: CubeType { - fn num_local_rows(&self) -> comptime_type!(u32); - fn num_local_cols(&self) -> comptime_type!(u32); +pub trait FragmentLayout: CubeType { + fn absolute_pos(&self, local_pos: Coords2d) -> Coords2d; fn num_units_per_row(&self) -> comptime_type!(u32); +} + +#[cube] +pub trait FragmentOps { + type Layout: FragmentLayout; fn rowwise_max(&self) -> RowWise; fn rowwise_sum(&self) -> RowWise; fn scale(&mut self, val: &RowWise); - fn scale_and_mask(&mut self, scale: E, mask: TileMask); + fn scale_and_mask(this: &mut Self, scale: E, mask: &M); fn exp_m_diff(&mut self, m: &RowWise); + + fn layout(&self) -> Self::Layout; +} + +#[cube] +pub trait FragmentMask: CubeType { + fn should_mask(&self, local_pos: Coords2d) -> bool; } diff --git a/crates/cubecl-attention/src/components/tile/row/reduce/base.rs b/crates/cubecl-attention/src/components/tile/row/reduce/base.rs index 68386058b..d1c0587ce 100644 --- a/crates/cubecl-attention/src/components/tile/row/reduce/base.rs +++ b/crates/cubecl-attention/src/components/tile/row/reduce/base.rs @@ -1,46 +1,46 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; -use crate::components::tile::PlaneLayout; +use 
crate::components::tile::FragmentOps; use crate::components::tile::RowMax; use crate::components::tile::RowSum; use crate::components::tile::RowWise; use crate::components::tile::dummy::AttentionMatmulConfig; #[cube] -pub fn row_sum, R: Reducer, TC: AttentionMatmulConfig>( +pub fn row_sum, R: Reducer, TC: AttentionMatmulConfig>( vals: &mut RowWise, - data: &PL, + data: &F, #[comptime] config: TC, ) { vals.copy_from(&RowWise::new_zero(vals.num_rows)); - R::reduce::(vals, data, config) + R::reduce::(vals, data, config) } #[cube] -pub fn row_max, R: Reducer, TC: AttentionMatmulConfig>( +pub fn row_max, R: Reducer, TC: AttentionMatmulConfig>( vals: &mut RowWise, base: &RowWise, - data: &PL, + data: &F, #[comptime] config: TC, ) { vals.copy_from(base); - R::reduce::(vals, data, config) + R::reduce::(vals, data, config) } #[cube] pub trait Reducer: CubeType { - fn reduce, RO: ReduceOp, TC: AttentionMatmulConfig>( + fn reduce, RO: ReduceOp, TC: AttentionMatmulConfig>( vals: &mut RowWise, - data: &PL, + data: &F, #[comptime] config: TC, ); } #[cube] pub trait ReduceOp { - fn reduce_local>(data: &PL) -> RowWise; - fn reduce_local_store>(data: &PL, acc: &mut RowWise); + fn reduce_local>(data: &F) -> RowWise; + fn reduce_local_store>(data: &F, acc: &mut RowWise); fn reduce_step_rowwise(acc: &mut RowWise, elem: &RowWise, mask: bool); fn reduce_step_scalar(a: E, b: E) -> E; } diff --git a/crates/cubecl-attention/src/components/tile/row/reduce/broadcast_reducer.rs b/crates/cubecl-attention/src/components/tile/row/reduce/broadcast_reducer.rs index 8eb8ba08f..a5abdcc82 100644 --- a/crates/cubecl-attention/src/components/tile/row/reduce/broadcast_reducer.rs +++ b/crates/cubecl-attention/src/components/tile/row/reduce/broadcast_reducer.rs @@ -4,7 +4,8 @@ use cubecl_core::prelude::*; use crate::components::tile::ReduceOp; use crate::components::tile::Reducer; use crate::components::tile::dummy::AttentionMatmulConfig; -use crate::components::tile::{PlaneLayout, PlaneLayoutExpand}; +use crate::components::tile::row::base::FragmentLayout; +use crate::components::tile::{FragmentLayoutExpand, FragmentOps, FragmentOpsExpand}; use crate::components::tile::{RowVal, RowWise}; #[derive(CubeType)] @@ -12,12 +13,12 @@ pub struct BroadcastReducer {} #[cube] impl Reducer for BroadcastReducer { - fn reduce, RO: ReduceOp, TC: AttentionMatmulConfig>( + fn reduce, RO: ReduceOp, TC: AttentionMatmulConfig>( vals: &mut RowWise, - data: &PL, + data: &F, #[comptime] config: TC, ) { - let num_units_per_row = data.num_units_per_row(); + let num_units_per_row = data.layout().num_units_per_row(); let num_shares_within_plane = comptime!((num_units_per_row as f32).log2().ceil() as u32); let unit_pos = UNIT_POS_X; @@ -25,7 +26,7 @@ impl Reducer for BroadcastReducer { let mut fpb = FakePlaneBroadcast::::new(config.plane_dim(), config.num_planes()); - RO::reduce_local_store::(data, vals); + RO::reduce_local_store::(data, vals); for i in 0..num_shares_within_plane { let offset = num_units_per_row >> (i + 1); diff --git a/crates/cubecl-attention/src/components/tile/row/reduce/dummy_reducer.rs b/crates/cubecl-attention/src/components/tile/row/reduce/dummy_reducer.rs index a5d9b000e..99629bcc3 100644 --- a/crates/cubecl-attention/src/components/tile/row/reduce/dummy_reducer.rs +++ b/crates/cubecl-attention/src/components/tile/row/reduce/dummy_reducer.rs @@ -5,22 +5,23 @@ use crate::components::tile::ReduceOp; use crate::components::tile::Reducer; use crate::components::tile::RowWise; use crate::components::tile::dummy::AttentionMatmulConfig; 
-use crate::components::tile::{PlaneLayout, PlaneLayoutExpand}; +use crate::components::tile::row::base::FragmentLayout; +use crate::components::tile::{FragmentLayoutExpand, FragmentOps, FragmentOpsExpand}; #[derive(CubeType)] pub struct DummyReducer {} #[cube] impl Reducer for DummyReducer { - fn reduce, RO: ReduceOp, TC: AttentionMatmulConfig>( + fn reduce, RO: ReduceOp, TC: AttentionMatmulConfig>( vals: &mut RowWise, - data: &PL, + data: &F, #[comptime] config: TC, ) { let num_vals_in_plane = config.num_rows_per_unit() * config.plane_dim(); let mut smem = SharedMemory::::new(num_vals_in_plane * config.num_planes()); - let local_vals = RO::reduce_local::(data); + let local_vals = RO::reduce_local::(data); let plane_offset = UNIT_POS_Y * num_vals_in_plane; let unit_offset = UNIT_POS_X; @@ -35,15 +36,16 @@ impl Reducer for DummyReducer { sync_cube(); + let num_units_per_row = data.layout().num_units_per_row(); + #[unroll] for r in 0..config.num_rows_per_unit() { let mut val = vals.index(r); let row_offset = r * config.plane_dim(); - for c in 0..data.num_units_per_row() { - let unit_offset = - (UNIT_POS_X / data.num_units_per_row()) * data.num_units_per_row(); + for c in 0..num_units_per_row { + let unit_offset = (UNIT_POS_X / num_units_per_row) * num_units_per_row; let offset = plane_offset + row_offset + unit_offset; val = RO::reduce_step_scalar(val, smem[offset + c]); diff --git a/crates/cubecl-attention/src/components/tile/row/reduce/reduce_op.rs b/crates/cubecl-attention/src/components/tile/row/reduce/reduce_op.rs index 6d88b9712..fbc1567f7 100644 --- a/crates/cubecl-attention/src/components/tile/row/reduce/reduce_op.rs +++ b/crates/cubecl-attention/src/components/tile/row/reduce/reduce_op.rs @@ -3,7 +3,7 @@ use cubecl_core::prelude::*; use crate::components::tile::ReduceOp; use crate::components::tile::RowWise; -use crate::components::tile::{PlaneLayout, PlaneLayoutExpand}; +use crate::components::tile::{FragmentOps, FragmentOpsExpand}; #[derive(CubeType)] pub struct RowMax {} @@ -13,12 +13,12 @@ pub struct RowSum {} #[cube] impl ReduceOp for RowMax { - fn reduce_local>(data: &PL) -> RowWise { + fn reduce_local>(data: &F) -> RowWise { data.rowwise_max() } - fn reduce_local_store>(data: &PL, acc: &mut RowWise) { - acc.max_inplace(&Self::reduce_local::(data)) + fn reduce_local_store>(data: &F, acc: &mut RowWise) { + acc.max_inplace(&Self::reduce_local::(data)) } fn reduce_step_rowwise(acc: &mut RowWise, elem: &RowWise, mask: bool) { @@ -35,12 +35,12 @@ impl ReduceOp for RowMax { #[cube] impl ReduceOp for RowSum { - fn reduce_local>(data: &PL) -> RowWise { + fn reduce_local>(data: &F) -> RowWise { data.rowwise_sum() } - fn reduce_local_store>(data: &PL, acc: &mut RowWise) { - acc.add_inplace(&Self::reduce_local::(data)) + fn reduce_local_store>(data: &F, acc: &mut RowWise) { + acc.add_inplace(&Self::reduce_local::(data)) } fn reduce_step_rowwise(acc: &mut RowWise, elem: &RowWise, mask: bool) { diff --git a/crates/cubecl-attention/src/components/tile/row/rowwise.rs b/crates/cubecl-attention/src/components/tile/row/rowwise.rs index 604dcb950..96e3a6bdd 100644 --- a/crates/cubecl-attention/src/components/tile/row/rowwise.rs +++ b/crates/cubecl-attention/src/components/tile/row/rowwise.rs @@ -2,19 +2,19 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; #[derive(CubeType)] -pub struct RowWise { +pub struct RowWise { #[cube(comptime)] pub num_rows: u32, pub vals: Sequence>, } #[derive(CubeType)] -pub struct RowVal { +pub struct RowVal { pub val: E, } #[cube] -impl RowWise { +impl 
RowWise { pub fn new_filled(#[comptime] num_rows: u32, val: E) -> RowWise { let mut vals = Sequence::new(); #[unroll] @@ -68,14 +68,6 @@ impl RowWise { } } - pub fn recip_inplace(&mut self) { - #[unroll] - for i in 0..self.num_rows { - let row_val = self.vals.index_mut(i); - row_val.val = Recip::recip(row_val.val); - } - } - pub fn max_inplace(&mut self, other: &RowWise) { #[unroll] for i in 0..self.num_rows { @@ -104,12 +96,12 @@ impl RowWise { } } - pub fn exp_m_diff(&self, other: &RowWise) -> RowWise { + pub fn mul(&self, other: &RowWise) -> RowWise { let mut vals = Sequence::new(); #[unroll] for i in 0..self.num_rows { - let val = Exp::exp(self.index(i) - other.index(i)); + let val = self.index(i) * other.index(i); vals.push(RowVal:: { val }); } @@ -119,12 +111,12 @@ impl RowWise { } } - pub fn mul(&self, other: &RowWise) -> RowWise { + pub fn add(&self, other: &RowWise) -> RowWise { let mut vals = Sequence::new(); #[unroll] for i in 0..self.num_rows { - let val = self.index(i) * other.index(i); + let val = self.index(i) + other.index(i); vals.push(RowVal:: { val }); } @@ -133,14 +125,20 @@ impl RowWise { vals, } } +} - pub fn add(&self, other: &RowWise) -> RowWise { +#[cube] +impl RowWise { + pub fn exp_m_diff(&self, other: &RowWise) -> RowWise { let mut vals = Sequence::new(); + let mut i = comptime![0u32]; #[unroll] - for i in 0..self.num_rows { - let val = self.index(i) + other.index(i); + for _ in 0..self.num_rows { + let val = Exp::exp(self.index(i) - other.index(i)); vals.push(RowVal:: { val }); + + comptime![i += 1]; } RowWise:: { @@ -148,4 +146,15 @@ impl RowWise { vals, } } + + pub fn recip_inplace(&mut self) { + let mut i = comptime![0u32]; + #[unroll] + for _ in 0..self.num_rows { + let row_val = self.vals.index_mut(i); + row_val.val = Recip::recip(row_val.val); + + comptime![i += 1]; + } + } } diff --git a/crates/cubecl-attention/src/components/tile/tiles.rs b/crates/cubecl-attention/src/components/tile/tiles.rs index 623bf7040..8725a2189 100644 --- a/crates/cubecl-attention/src/components/tile/tiles.rs +++ b/crates/cubecl-attention/src/components/tile/tiles.rs @@ -2,13 +2,20 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; use crate::components::AttentionPrecision; -use crate::components::TileMask; use crate::components::attention_types::*; use crate::components::tile::dummy::AttentionMatmulConfig; -use crate::components::tile::{PlaneLayout, RowWise, RunningState}; +use crate::components::tile::{FragmentOps, RowWise, RunningState}; +use cubecl_matmul::components::tile::StridedTile; +use cubecl_std::CubeOption; +use cubecl_std::tensor::layout::Coords2d; #[cube] -pub trait QueryTile: CubeType {} +pub trait QueryTile: CubeType { + type Fragment: CubeType; + + fn fragment_mut(&mut self) -> &mut Self::Fragment; + fn update(&mut self, tile: StridedTile>); +} #[cube] pub trait KeyValueTile: CubeType { @@ -24,14 +31,14 @@ pub trait KeyValueTile: CubeType { #[cube] pub trait SoftmaxTile: CubeType { - type PlaneLayout: PlaneLayout>; + type FragmentOps: FragmentOps>; fn init_state(#[comptime] num_rows: u32) -> RunningState>; fn init_placeholder(#[comptime] num_rows: u32) -> RowWise>; fn zero(&mut self); - fn scale_and_mask(&mut self, scale: SM, mask: TileMask); + fn scale_and_mask(this: &mut Self, scale: SM, mask: &M); fn row_max( &self, @@ -56,3 +63,13 @@ pub trait AccumulatorTile: CubeType { fn scale_mul(&mut self, scale: &RowWise>); fn scale_div(&mut self, scale: &RowWise>); } + +#[cube] +pub trait MaskTile: CubeType { + type Fragment: CubeType; + type MaskPrecision: 
Numeric; + + fn apply(this: &Self, local_pos: Coords2d) -> E; + fn fragment_mut(&mut self) -> &mut Self::Fragment; + fn update(&mut self, new_origin: Coords2d, tile: CubeOption>); +} diff --git a/crates/cubecl-attention/src/tests/attention_test_launcher.rs b/crates/cubecl-attention/src/tests/attention_test_launcher.rs index fac1c682f..438e39200 100644 --- a/crates/cubecl-attention/src/tests/attention_test_launcher.rs +++ b/crates/cubecl-attention/src/tests/attention_test_launcher.rs @@ -1,6 +1,7 @@ use cubecl_core::prelude::*; use cubecl_core::server::Allocation; use cubecl_core::{CubeElement, server}; +use cubecl_std::CubeOptionArgs; use crate::components::args::TensorInputsLaunch; use crate::components::batch::BatchAttentionConfig; @@ -45,7 +46,15 @@ pub fn test_attention_algorithm( let query = tensor_raw_parts_input::(&client, &problem, AttentionIdent::Query, 12); let key = tensor_raw_parts_input::(&client, &problem, AttentionIdent::Key, 34); let value = tensor_raw_parts_input::(&client, &problem, AttentionIdent::Value, 56); - // let mask = tensor_raw_parts_input::(&client, &problem, Ident::Mask, 78); + let mask = match problem.masked { + true => Some(tensor_raw_parts_input::( + &client, + &problem, + AttentionIdent::Mask, + 78, + )), + false => None, + }; let out = tensor_raw_parts_output::(&client, &problem); let line_sizes = AvailableLineSizes::from_elem_types::( @@ -58,7 +67,6 @@ pub fn test_attention_algorithm( .filter_with_tensor(AttentionIdent::Query, &query.strides, &query.shape) .filter_with_tensor(AttentionIdent::Key, &key.strides, &key.shape) .filter_with_tensor(AttentionIdent::Value, &value.strides, &value.shape) - // .filter_with_tensor(Ident::Mask, &mask.strides, &mask.shape) .filter_with_tensor(AttentionIdent::Out, &out.strides, &out.shape) .pick_max() .unwrap(); @@ -104,12 +112,15 @@ pub fn test_attention_algorithm( &value.shape, line_sizes.value, ), - // TensorArg::::from_raw_parts::( - // &mask.handle, - // &mask.strides, - // &mask.shape, - // line_sizes.mask, - // ), + match mask.as_ref() { + Some(m) => CubeOptionArgs::Some(TensorArg::::from_raw_parts::( + &m.handle, + &m.strides, + &m.shape, + line_sizes.mask, + )), + None => CubeOptionArgs::None, + }, ), TensorArg::::from_raw_parts::( &out.handle, @@ -126,7 +137,8 @@ pub fn test_attention_algorithm( &query.original_data.unwrap(), &key.original_data.unwrap(), &value.original_data.unwrap(), - None, + mask.as_ref() + .map(|m| m.original_data.as_ref().unwrap().as_slice()), &problem, &client, out.handle, diff --git a/crates/cubecl-attention/src/tests/macros/mod.rs b/crates/cubecl-attention/src/tests/macros/mod.rs index aa19491d0..376aafd0a 100644 --- a/crates/cubecl-attention/src/tests/macros/mod.rs +++ b/crates/cubecl-attention/src/tests/macros/mod.rs @@ -75,6 +75,7 @@ macro_rules! testgen_attention { head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, masked: false, + causal: false, }; attention_test_launch::( client, @@ -115,6 +116,7 @@ macro_rules! testgen_attention { head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, masked: false, + causal: false, }; attention_test_launch::( client, @@ -125,6 +127,7 @@ macro_rules! testgen_attention { } #[test] + #[ignore = "Disabled, should only work for unit attention"] fn attention_9_9_9_9() { let client = TestRuntime::client(&Default::default()); let tile_size = AttentionTileSize { @@ -153,6 +156,7 @@ macro_rules! 
testgen_attention { head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, masked: false, + causal: false, }; attention_test_launch::( client, @@ -163,6 +167,7 @@ macro_rules! testgen_attention { } #[test] + #[ignore = "Disabled, should only work for unit attention"] fn attention_7_3_10_10() { let client = TestRuntime::client(&Default::default()); let tile_size = AttentionTileSize { @@ -191,6 +196,7 @@ macro_rules! testgen_attention { head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, masked: false, + causal: false, }; attention_test_launch::( client, @@ -229,6 +235,7 @@ macro_rules! testgen_attention { head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, masked: false, + causal: false, }; attention_test_launch::( client, @@ -270,6 +277,7 @@ macro_rules! testgen_attention { head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, masked: false, + causal: false, }; attention_test_launch::( client, @@ -308,6 +316,7 @@ macro_rules! testgen_attention { head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, masked: false, + causal: false, }; attention_test_launch::( client, @@ -346,6 +355,7 @@ macro_rules! testgen_attention { head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, masked: false, + causal: false, }; attention_test_launch::( client, @@ -384,6 +394,7 @@ macro_rules! testgen_attention { head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, masked: false, + causal: false, }; attention_test_launch::( client, @@ -422,6 +433,7 @@ macro_rules! testgen_attention { head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, masked: false, + causal: false, }; attention_test_launch::( client, @@ -460,6 +472,7 @@ macro_rules! testgen_attention { head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, masked: false, + causal: false, }; attention_test_launch::( client, @@ -498,6 +511,7 @@ macro_rules! testgen_attention { head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, masked: false, + causal: false, }; attention_test_launch::( client, @@ -538,6 +552,7 @@ macro_rules! testgen_attention { head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, masked: false, + causal: false, }; attention_test_launch::( client, @@ -578,6 +593,7 @@ macro_rules! testgen_attention { head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, masked: false, + causal: false, }; attention_test_launch::( client, @@ -617,6 +633,7 @@ macro_rules! testgen_attention { head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, masked: false, + causal: false, }; attention_test_launch::( client, @@ -656,6 +673,7 @@ macro_rules! 
testgen_attention { head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, masked: false, + causal: false, }; attention_test_launch::( client, @@ -694,6 +712,7 @@ macro_rules! testgen_attention { head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, masked: false, + causal: false, }; attention_test_launch::( client, @@ -732,6 +751,7 @@ macro_rules! testgen_attention { head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, masked: false, + causal: false, }; attention_test_launch::( client, @@ -770,6 +790,7 @@ macro_rules! testgen_attention { head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, masked: false, + causal: false, }; attention_test_launch::( client, @@ -804,11 +825,11 @@ macro_rules! testgen_attention { batch: 1, num_heads: 1, seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, - // seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize + 9, - seq_kv: 8, + seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize + 9, head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, masked: false, + causal: false, }; attention_test_launch::( client, @@ -847,6 +868,7 @@ macro_rules! testgen_attention { head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, masked: false, + causal: false, }; attention_test_launch::( client, @@ -885,6 +907,7 @@ macro_rules! testgen_attention { head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, masked: false, + causal: false, }; attention_test_launch::( client, @@ -923,6 +946,7 @@ macro_rules! testgen_attention { head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, masked: false, + causal: false, }; attention_test_launch::( client, @@ -961,6 +985,7 @@ macro_rules! testgen_attention { head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, masked: false, + causal: false, }; attention_test_launch::( client, @@ -999,6 +1024,7 @@ macro_rules! testgen_attention { head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, masked: false, + causal: false, }; attention_test_launch::( client, @@ -1011,43 +1037,245 @@ macro_rules! 
testgen_attention { ); } - // #[test] - // fn attention_double_row_wise() { - // let client = TestRuntime::client(&Default::default()); - // let tile_size = AttentionTileSize { - // seq_q: 16, - // seq_kv: 16, - // head_dim: 16, - // val_dim: 16, - // }; - // let partition_size = AttentionPartitionSize { - // seq_q: 2, - // seq_kv: 2, - // head_dim: 2, - // val_dim: 2, - // }; - // let stage_size = AttentionStageSize { seq_q: 2 }; - // let tiling_scheme = AttentionTilingScheme { - // tile_size, - // partition_size, - // stage_size, - // }; - // let problem = AttentionProblem { - // batch: 1, - // num_heads: 1, - // seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, - // seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize, - // head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, - // val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, - // masked: false, - // }; - // attention_test_launch::( - // client, - // tiling_scheme, - // problem, - // true, - // ); - // } + #[test] + #[ignore = "TODO debug"] + fn attention_double_row_wise() { + let client = TestRuntime::client(&Default::default()); + let tile_size = AttentionTileSize { + seq_q: 16, + seq_kv: 16, + head_dim: 16, + val_dim: 16, + }; + let partition_size = AttentionPartitionSize { + seq_q: 2, + seq_kv: 2, + head_dim: 2, + val_dim: 2, + }; + let stage_size = AttentionStageSize { seq_q: 2 }; + let tiling_scheme = AttentionTilingScheme { + tile_size, + partition_size, + stage_size, + }; + let problem = AttentionProblem { + batch: 1, + num_heads: 1, + seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, + seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize, + head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, + val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, + masked: false, + causal: false, + }; + attention_test_launch::( + client, + tiling_scheme, + problem, + TestOptions { + two_rows_in_array_tile: true, + ..Default::default() + }, + ); + } + + #[test] + fn attention_8_8_8_8_masked() { + let client = TestRuntime::client(&Default::default()); + let tile_size = AttentionTileSize { + seq_q: 8, + seq_kv: 8, + head_dim: 8, + val_dim: 8, + }; + let partition_size = AttentionPartitionSize { + seq_q: 1, + seq_kv: 1, + head_dim: 1, + val_dim: 1, + }; + let stage_size = AttentionStageSize { seq_q: 1 }; + let tiling_scheme = AttentionTilingScheme { + tile_size, + partition_size, + stage_size, + }; + let problem = AttentionProblem { + batch: 1, + num_heads: 1, + seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, + seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize, + head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, + val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, + masked: true, + causal: false, + }; + attention_test_launch::( + client, + tiling_scheme, + problem, + Default::default(), + ) + } + + #[test] + fn attention_8_8_8_8_causal() { + let client = TestRuntime::client(&Default::default()); + let tile_size = AttentionTileSize { + seq_q: 8, + seq_kv: 8, + head_dim: 8, + val_dim: 8, + }; + let partition_size = AttentionPartitionSize { + seq_q: 1, + seq_kv: 1, + head_dim: 1, + val_dim: 1, + }; + let stage_size = AttentionStageSize { seq_q: 1 }; + let tiling_scheme = AttentionTilingScheme { + tile_size, + partition_size, + stage_size, + }; + let problem = AttentionProblem { + batch: 1, + num_heads: 1, + seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, + seq_kv: 
tiling_scheme.elements_in_partition_seq_kv() as usize, + head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, + val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, + masked: false, + causal: true, + }; + attention_test_launch::( + client, + tiling_scheme, + problem, + Default::default(), + ) + } + + #[test] + fn attention_8_8_8_8_masked_causal() { + let client = TestRuntime::client(&Default::default()); + let tile_size = AttentionTileSize { + seq_q: 8, + seq_kv: 8, + head_dim: 8, + val_dim: 8, + }; + let partition_size = AttentionPartitionSize { + seq_q: 1, + seq_kv: 1, + head_dim: 1, + val_dim: 1, + }; + let stage_size = AttentionStageSize { seq_q: 1 }; + let tiling_scheme = AttentionTilingScheme { + tile_size, + partition_size, + stage_size, + }; + let problem = AttentionProblem { + batch: 1, + num_heads: 1, + seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, + seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize, + head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, + val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, + masked: true, + causal: true, + }; + attention_test_launch::( + client, + tiling_scheme, + problem, + Default::default(), + ) + } + + #[test] + #[ignore = "TODO debug"] + fn attention_masked_oob() { + let client = TestRuntime::client(&Default::default()); + let tile_size = AttentionTileSize { + seq_q: 8, + seq_kv: 8, + head_dim: 8, + val_dim: 8, + }; + let partition_size = AttentionPartitionSize { + seq_q: 1, + seq_kv: 1, + head_dim: 1, + val_dim: 1, + }; + let stage_size = AttentionStageSize { seq_q: 1 }; + let tiling_scheme = AttentionTilingScheme { + tile_size, + partition_size, + stage_size, + }; + let problem = AttentionProblem { + batch: 1, + num_heads: 1, + seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, + seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize - 1, + head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, + val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, + masked: true, + causal: false, + }; + attention_test_launch::( + client, + tiling_scheme, + problem, + Default::default(), + ) + } + + #[test] + #[ignore = "TODO debug"] + fn attention_masked_larger() { + let client = TestRuntime::client(&Default::default()); + let tile_size = AttentionTileSize { + seq_q: 8, + seq_kv: 8, + head_dim: 8, + val_dim: 8, + }; + let partition_size = AttentionPartitionSize { + seq_q: 1, + seq_kv: 1, + head_dim: 1, + val_dim: 1, + }; + let stage_size = AttentionStageSize { seq_q: 1 }; + let tiling_scheme = AttentionTilingScheme { + tile_size, + partition_size, + stage_size, + }; + let problem = AttentionProblem { + batch: 1, + num_heads: 1, + seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, + seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize * 2, + head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, + val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, + masked: true, + causal: false, + }; + attention_test_launch::( + client, + tiling_scheme, + problem, + Default::default(), + ) + } } }; } diff --git a/crates/cubecl-attention/src/tests/test_utils.rs b/crates/cubecl-attention/src/tests/test_utils.rs index 8f30b6db5..c67d296e8 100644 --- a/crates/cubecl-attention/src/tests/test_utils.rs +++ b/crates/cubecl-attention/src/tests/test_utils.rs @@ -56,7 +56,7 @@ where query: &[EG], key: &[EG], value: &[EG], - mask: Option<&[u8]>, + mask: Option<&[Self::EM]>, problem: &AttentionProblem, client: &ComputeClient, out: 
server::Handle, @@ -266,7 +266,6 @@ sample_float!(half::f16); sample_float!(half::bf16); sample_float!(f32); sample_float!(f64); -sample_float!(u8); impl Sampleable for flex32 { fn sample( @@ -323,6 +322,21 @@ impl Sampleable for bool { } } +impl Sampleable for u8 { + fn sample( + client: &ComputeClient, + shape: &[usize], + seed: u64, + ) -> TensorHandle { + cubecl_random::seed(seed); + let output = TensorHandle::::empty(client, shape.to_vec()); + + cubecl_random::random_bernoulli::(client, 0.5, output.as_ref()); + + output + } +} + pub(crate) fn flash_attention_v2_cpu( query: &[P::EG], key: &[P::EG], @@ -334,11 +348,13 @@ where { let batch = problem.batch; let seq_q = problem.seq_q; - let seq_k = problem.seq_kv; + let seq_kv = problem.seq_kv; let num_heads = problem.num_heads; let head_dim = problem.head_dim; let val_dim = problem.val_dim; + let masked = mask.is_some(); + assert!(problem.masked == masked); // Precompute strides for indexing let query_strides = strides(problem, AttentionIdent::Query); @@ -364,8 +382,8 @@ where // For each K/V block let mut k_block_start = 0usize; - while k_block_start < seq_k { - let k_block_end = std::cmp::min(seq_k, k_block_start + seq_k); + while k_block_start < seq_kv { + let k_block_end = std::cmp::min(seq_kv, k_block_start + seq_kv); let cur_block_len = k_block_end - k_block_start; // Step A: compute S_block[j'] = Q_i · K_{j'} for j' in block @@ -390,13 +408,15 @@ where // apply scale (1/sqrt(dk)) dot *= scale; - // apply mask (for masked positions set -inf) - let s_val = if masked { + let s_val = if problem.causal && j > i { + P::EA::new(f32::NEG_INFINITY) + } else if masked { let m_idx = b * mask_strides[0] + i * mask_strides[1] + h * mask_strides[2] + j * mask_strides[3]; let m_val = mask.unwrap()[m_idx].cast_into(); + if m_val != P::EM::from_int(0) { P::EA::new(f32::NEG_INFINITY) } else { diff --git a/crates/cubecl-convolution/src/components/global/multi_stage/tma/convolution.rs b/crates/cubecl-convolution/src/components/global/multi_stage/tma/convolution.rs index fee6221b1..67b806c0e 100644 --- a/crates/cubecl-convolution/src/components/global/multi_stage/tma/convolution.rs +++ b/crates/cubecl-convolution/src/components/global/multi_stage/tma/convolution.rs @@ -113,7 +113,6 @@ where // Create barriers and prefetch each stage #[unroll] - #[allow(clippy::explicit_counter_loop)] for stage in 0..num_stages { let barrier = Barrier::new_with_tma_proxy(BarrierLevel::cube_coop(0u32)); @@ -133,7 +132,6 @@ where // Loop through all stages #[unroll] - #[allow(clippy::explicit_counter_loop)] for stage in 0..num_stages { let k = k + stage; let next_k = k + num_stages; diff --git a/crates/cubecl-matmul/src/components/global/read/reader/sync_full_reader.rs b/crates/cubecl-matmul/src/components/global/read/reader/sync_full_reader.rs index 2baa50b3d..f56e87fe7 100644 --- a/crates/cubecl-matmul/src/components/global/read/reader/sync_full_reader.rs +++ b/crates/cubecl-matmul/src/components/global/read/reader/sync_full_reader.rs @@ -113,7 +113,6 @@ impl let len = L::Job::task_count(&loading_job); - #[allow(clippy::explicit_counter_loop)] #[unroll] for task_id in 0..len { L::Job::::execute_task::( diff --git a/crates/cubecl-matmul/src/components/global/read/reader/sync_partial_reader.rs b/crates/cubecl-matmul/src/components/global/read/reader/sync_partial_reader.rs index 841e43595..3baa39f57 100644 --- a/crates/cubecl-matmul/src/components/global/read/reader/sync_partial_reader.rs +++
b/crates/cubecl-matmul/src/components/global/read/reader/sync_partial_reader.rs @@ -122,7 +122,6 @@ impl let len = L::Job::task_count(&loading_job); - #[allow(clippy::explicit_counter_loop)] #[unroll] for task_id in 0..len { L::Job::::execute_task::( diff --git a/crates/cubecl-matmul/src/components/stage/matmul/partition/matmul.rs b/crates/cubecl-matmul/src/components/stage/matmul/partition/matmul.rs index 9e6eb3d02..941cce204 100644 --- a/crates/cubecl-matmul/src/components/stage/matmul/partition/matmul.rs +++ b/crates/cubecl-matmul/src/components/stage/matmul/partition/matmul.rs @@ -153,12 +153,10 @@ where let rhs_load_total = comptime!(n_iterations * k_iterations); let execute_total = comptime!(m_iterations * n_iterations * k_iterations); - #[allow(clippy::explicit_counter_loop)] #[unroll] for k_iter in 0..k_iterations { let k_load_iter = partition_scheduler.map_k(k_iter); - #[allow(clippy::explicit_counter_loop)] #[unroll] for m_iter in 0..m_iterations { let m_load_iter = partition_scheduler.map_m(m_iter); @@ -181,7 +179,6 @@ where } #[unroll] - #[allow(clippy::explicit_counter_loop)] for n_iter in 0..n_iterations { let n_load_iter = partition_scheduler.map_n(n_iter); @@ -197,7 +194,6 @@ where ); comptime!(rhs_load_counter += 1); - #[allow(clippy::explicit_counter_loop)] #[unroll] for m_iter in 0..m_iterations { let accumulator = @@ -254,12 +250,10 @@ where let rhs_load_total = comptime!(n_iterations * k_iterations); let execute_total = comptime!(m_iterations * n_iterations * k_iterations); - #[allow(clippy::explicit_counter_loop)] #[unroll] for k_iter in 0..k_iterations { let k_load_iter = partition_scheduler.map_k(k_iter); - #[allow(clippy::explicit_counter_loop)] #[unroll] for m_iter in 0..m_iterations { let m_load_iter = partition_scheduler.map_m(m_iter); @@ -318,7 +312,6 @@ where ); comptime!(rhs_load_counter += 1); - #[allow(clippy::explicit_counter_loop)] #[unroll] for m_iter in 0..m_iterations { let accumulator = @@ -350,7 +343,6 @@ where &mut rhs_fragments.1 }; - #[allow(clippy::explicit_counter_loop)] #[unroll] for m_iter in 0..m_iterations { let accumulator = diff --git a/crates/cubecl-matmul/src/components/stage/matmul/partitioned_matmul.rs b/crates/cubecl-matmul/src/components/stage/matmul/partitioned_matmul.rs index d88deb50a..dc8834ba9 100644 --- a/crates/cubecl-matmul/src/components/stage/matmul/partitioned_matmul.rs +++ b/crates/cubecl-matmul/src/components/stage/matmul/partitioned_matmul.rs @@ -177,12 +177,10 @@ where // Iterate over each tile in the partition #[unroll] - #[allow(clippy::explicit_counter_loop)] for m_iter in 0..m_iterations { let m_load_iter = partition_scheduler.map_m(m_iter); #[unroll] - #[allow(clippy::explicit_counter_loop)] for n_iter in 0..n_iterations { let n_load_iter = partition_scheduler.map_n(n_iter);
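Returning to the attention changes above: the CPU reference in test_utils.rs now forces a score to negative infinity when a position is causally hidden or when the optional mask holds a nonzero value. An illustrative condensation of that rule (not part of the patch; plain f32/u8 types are assumed in place of the generic precisions, and the function name is hypothetical):

// Illustration only: how the reference oracle decides whether a (query i, key j)
// score survives, mirroring the `problem.causal && j > i` and `m_val != 0` checks.
fn reference_score(dot: f32, scale: f32, causal: bool, i: usize, j: usize, mask_val: Option<u8>) -> f32 {
    let causally_hidden = causal && j > i;
    let mask_hidden = mask_val.map_or(false, |m| m != 0);
    if causally_hidden || mask_hidden {
        f32::NEG_INFINITY
    } else {
        dot * scale
    }
}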