Skip to content

Commit 68dca01

Browse files
authored
feat: Manual mma tile (#935)
1 parent c47fc84 commit 68dca01

File tree

26 files changed

+900
-81
lines changed

26 files changed

+900
-81
lines changed

crates/cubecl-attention/src/components/tile/dummy/attention_matmul/accelerated/matmul.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ use cubecl_core as cubecl;
22
use cubecl_core::{cmma, prelude::*};
33
use cubecl_matmul::components::tile::StridedTile;
44

5+
use crate::components::AttentionPrecision;
56
use crate::components::attention_types::*;
67
use crate::components::tile::dummy::accelerated::AcceleratedAttentionMatmulConfig;
78
use crate::components::tile::dummy::{AttentionMatmul, AttentionMatmulConfig as _};
8-
use crate::components::{AttentionIdent, AttentionPrecision};
99

1010
/// Performs two matmuls with fragment reuse for key/value and score/prob
1111
pub struct AcceleratedAttentionMatmul;
@@ -40,7 +40,7 @@ impl<AP: AttentionPrecision> AttentionMatmul<AP> for AcceleratedAttentionMatmul
4040
tile: &StridedTile<EI>,
4141
#[comptime] config: Self::Config,
4242
) -> Self::Query {
43-
let (slice, stride) = tile.as_unlined(config.stage_line_size(AttentionIdent::Query));
43+
let (slice, stride) = tile.as_unlined();
4444
let size = config.attention_tile_size().to_score_matmul_tile_size();
4545

4646
if config.cast_query() {
@@ -118,9 +118,9 @@ impl<AP: AttentionPrecision> AttentionMatmul<AP> for AcceleratedAttentionMatmul
118118
fn fill_key_value<E: Numeric>(
119119
tile: &StridedTile<E>,
120120
rhs: &mut Self::KeyValue,
121-
#[comptime] config: Self::Config,
121+
#[comptime] _config: Self::Config,
122122
) {
123-
let (slice, stride) = tile.as_unlined(config.stage_line_size(AttentionIdent::Key));
123+
let (slice, stride) = tile.as_unlined();
124124
cmma::load(rhs, &slice, stride);
125125
}
126126

@@ -175,17 +175,17 @@ impl<AP: AttentionPrecision> AttentionMatmul<AP> for AcceleratedAttentionMatmul
175175
fn tmp_fill_accumulator(
176176
tile: &StridedTile<ACC<AP>>,
177177
acc: &mut Self::Accumulator,
178-
#[comptime] config: Self::Config,
178+
#[comptime] _config: Self::Config,
179179
) {
180-
let (slice, stride) = tile.as_unlined(config.stage_line_size(AttentionIdent::Out));
180+
let (slice, stride) = tile.as_unlined();
181181
cmma::load_with_layout(acc, &slice, stride, cmma::MatrixLayout::RowMajor);
182182
}
183183
fn tmp_fill_prob(
184184
tile: &StridedTile<SM<AP>>,
185185
prob: &mut Self::Softmax,
186186
#[comptime] _config: Self::Config,
187187
) {
188-
let (slice, stride) = tile.as_unlined(1u32);
188+
let (slice, stride) = tile.as_unlined();
189189
cmma::load_with_layout(prob, &slice, stride, cmma::MatrixLayout::RowMajor);
190190
}
191191
fn tmp_write_softmax(

crates/cubecl-attention/src/components/tile/dummy/attention_matmul/dummy_register/matmul.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ impl<AP: AttentionPrecision> AttentionMatmul<AP> for DummyRegisterAttentionMatmu
117117
if UNIT_POS_X == 0 {
118118
let size = config.attention_tile_size().key_size();
119119
for i in 0..size {
120-
rhs[i] = KVT::<AP>::cast_from(tile.as_unlined(1u32).0[i]);
120+
rhs[i] = KVT::<AP>::cast_from(tile.as_unlined().0[i]);
121121
}
122122
}
123123

@@ -176,7 +176,7 @@ impl<AP: AttentionPrecision> AttentionMatmul<AP> for DummyRegisterAttentionMatmu
176176
if UNIT_POS_X == 0 {
177177
let size = config.attention_tile_size().accumulator_size();
178178
for i in 0..size {
179-
acc[i] = tile.as_unlined(1u32).0[i];
179+
acc[i] = tile.as_unlined().0[i];
180180
}
181181
}
182182

@@ -191,7 +191,7 @@ impl<AP: AttentionPrecision> AttentionMatmul<AP> for DummyRegisterAttentionMatmu
191191
if UNIT_POS_X == 0 {
192192
let len = config.attention_tile_size().softmax_size();
193193
for i in 0..len {
194-
prob[i] = tile.as_unlined(1u32).0[i];
194+
prob[i] = tile.as_unlined().0[i];
195195
}
196196
}
197197

crates/cubecl-core/src/frontend/cmma.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,20 @@ impl<A: CubePrimitive, B: CubePrimitive, CD: CubePrimitive> MmaDefinition<A, B,
384384
})
385385
}
386386

387+
/// Returns the number of lines of size `line_size` with layout `line_layout` per lane.
388+
///
389+
/// # Note
390+
/// "Lane" here refers to the unit relative to a plane, to distinguish it from a unit relative
391+
/// to a cube.
392+
#[allow(unused)]
393+
pub fn lines_per_lane(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(u32) {
394+
intrinsic!(|scope| {
395+
let elems = self.clone().__expand_elems_per_lane_method(scope, ident);
396+
let line_size = self.__expand_line_size_method(scope, ident);
397+
elems / line_size
398+
})
399+
}
400+
387401
/// The layout of each line in this matrix (row major or column major)
388402
#[allow(unused)]
389403
pub fn line_layout(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(MatrixLayout) {
@@ -419,7 +433,7 @@ impl<A: CubePrimitive, B: CubePrimitive, CD: CubePrimitive> MmaDefinition<A, B,
419433
/// "Lane" here refers to the unit relative to a plane, to distinguish it from a unit relative
420434
/// to a cube.
421435
#[allow(unused_variables)]
422-
pub fn indices_of_nth(
436+
pub fn position_of_nth(
423437
&self,
424438
lane_id: u32,
425439
elem_idx: u32,

crates/cubecl-core/src/runtime_tests/cmma.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -804,7 +804,7 @@ pub fn kernel_manual<A: Numeric, B: Numeric, CD: Numeric>(
804804
#[unroll]
805805
for k in 0..line_size_a {
806806
let n_elem = i * line_size_a + k;
807-
let (row, col) = def.indices_of_nth(lane_id, n_elem, MatrixIdent::A);
807+
let (row, col) = def.position_of_nth(lane_id, n_elem, MatrixIdent::A);
808808
let value = a[row * size_k + col];
809809
reg[k] = value;
810810
}
@@ -818,7 +818,7 @@ pub fn kernel_manual<A: Numeric, B: Numeric, CD: Numeric>(
818818
#[unroll]
819819
for k in 0..line_size_b {
820820
let n_elem = i * line_size_b + k;
821-
let (row, col) = def.indices_of_nth(lane_id, n_elem, MatrixIdent::B);
821+
let (row, col) = def.position_of_nth(lane_id, n_elem, MatrixIdent::B);
822822
let value = b[row * size_n + col];
823823
reg[k] = value;
824824
}
@@ -832,7 +832,7 @@ pub fn kernel_manual<A: Numeric, B: Numeric, CD: Numeric>(
832832
#[unroll]
833833
for k in 0..line_size_c {
834834
let n_elem = i * line_size_c + k;
835-
let (row, col) = def.indices_of_nth(lane_id, n_elem, MatrixIdent::Accumulator);
835+
let (row, col) = def.position_of_nth(lane_id, n_elem, MatrixIdent::Accumulator);
836836
let value = c[row * size_n + col];
837837
reg[k] = value;
838838
}
@@ -848,7 +848,7 @@ pub fn kernel_manual<A: Numeric, B: Numeric, CD: Numeric>(
848848
#[unroll]
849849
for k in 0..line_size_d {
850850
let n_elem = i * line_size_d + k;
851-
let (row, col) = def.indices_of_nth(lane_id, n_elem, MatrixIdent::Accumulator);
851+
let (row, col) = def.position_of_nth(lane_id, n_elem, MatrixIdent::Accumulator);
852852
out[row * size_n + col] = reg[k];
853853
}
854854
}
@@ -996,7 +996,7 @@ pub fn kernel_scaled<A: CubePrimitive, B: CubePrimitive, CD: Numeric, S: Numeric
996996
#[unroll]
997997
for i in 0..line_count_a {
998998
let n_elem = i * line_size_a * a_pack;
999-
let (row, col) = def.indices_of_nth(lane_id, n_elem, MatrixIdent::A);
999+
let (row, col) = def.position_of_nth(lane_id, n_elem, MatrixIdent::A);
10001000
let idx = row * size_k + col;
10011001
let idx = idx / (a.line_size() * a_pack);
10021002
let value = a[idx];
@@ -1014,7 +1014,7 @@ pub fn kernel_scaled<A: CubePrimitive, B: CubePrimitive, CD: Numeric, S: Numeric
10141014
#[unroll]
10151015
for i in 0..line_count_b {
10161016
let n_elem = i * line_size_b * b_pack;
1017-
let (row, col) = def.indices_of_nth(lane_id, n_elem, MatrixIdent::B);
1017+
let (row, col) = def.position_of_nth(lane_id, n_elem, MatrixIdent::B);
10181018
let idx = col * size_k + row;
10191019
let idx = idx / (b.line_size() * b_pack);
10201020
let value = b[idx];
@@ -1032,7 +1032,7 @@ pub fn kernel_scaled<A: CubePrimitive, B: CubePrimitive, CD: Numeric, S: Numeric
10321032
#[unroll]
10331033
for i in 0..line_count_c {
10341034
let n_elem = i * line_size_c;
1035-
let (row, col) = def.indices_of_nth(lane_id, n_elem, MatrixIdent::Accumulator);
1035+
let (row, col) = def.position_of_nth(lane_id, n_elem, MatrixIdent::Accumulator);
10361036
let idx = row * size_n + col;
10371037
let value = c[idx / c.line_size()];
10381038
registers_c.push(value)
@@ -1050,7 +1050,7 @@ pub fn kernel_scaled<A: CubePrimitive, B: CubePrimitive, CD: Numeric, S: Numeric
10501050
#[unroll]
10511051
for i in 0..line_count_d {
10521052
let n_elem = i * line_size_d;
1053-
let (row, col) = def.indices_of_nth(lane_id, n_elem, MatrixIdent::Accumulator);
1053+
let (row, col) = def.position_of_nth(lane_id, n_elem, MatrixIdent::Accumulator);
10541054
let idx = row * size_n + col;
10551055
out[idx / out.line_size()] = registers_d[i];
10561056
}

crates/cubecl-cuda/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ mod tests {
7878
cubecl_std::testgen!();
7979

8080
cubecl_matmul::testgen_matmul_plane_accelerated!();
81+
cubecl_matmul::testgen_matmul_plane_mma!();
8182
cubecl_matmul::testgen_matmul_plane_vecmat!();
8283
cubecl_matmul::testgen_matmul_unit!();
8384
cubecl_matmul::testgen_matmul_tma!();

crates/cubecl-matmul/src/components/error.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use cubecl_core::{CubeCount, CubeDim, LineSizeError, ir::StorageType};
22
use std::fmt::{Debug, Display};
33

4-
use crate::components::TileSize;
4+
use crate::components::{MatrixLayout, TileSize};
55

66
/// Errors that can occur during the setup phase of a matmul operation.
77
pub enum MatmulSetupError {
@@ -41,6 +41,12 @@ pub enum MatmulAvailabilityError {
4141
size: Option<TileSize>,
4242
},
4343

44+
/// The layout of the matmul is unsupported
45+
LayoutUnsupported {
46+
lhs: MatrixLayout,
47+
rhs: MatrixLayout,
48+
},
49+
4450
/// Barrier synchronization is not available in the runtime.
4551
BarrierUnavailable,
4652

@@ -139,6 +145,12 @@ impl Debug for MatmulAvailabilityError {
139145
size.n(),
140146
size.k()
141147
),
148+
MatmulAvailabilityError::LayoutUnsupported { lhs, rhs } => {
149+
writeln!(
150+
f,
151+
"Cmma with layouts lhs {lhs:?} and rhs {rhs:?} not supported."
152+
)
153+
}
142154
MatmulAvailabilityError::CmmaInstructionUnavailable {
143155
lhs,
144156
rhs,

crates/cubecl-matmul/src/components/tile/accelerated/config.rs

Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
use cubecl_core::Runtime;
2-
use cubecl_core::ir::{ElemType, FloatKind};
2+
use cubecl_core::client::ComputeClient;
33
use cubecl_core::prelude::Numeric;
4-
use cubecl_core::{client::ComputeClient, ir::StorageType};
54
use cubecl_runtime::MmaConfig;
65

76
use crate::components::error::{MatmulAvailabilityError, MatmulSetupError};
@@ -100,31 +99,11 @@ impl AcceleratedConfig {
10099
let rhs = Rhs::as_type_native_unchecked();
101100
let acc = Acc::as_type_native_unchecked();
102101

103-
let lhs = match lhs {
104-
StorageType::Scalar(ElemType::Float(FloatKind::Flex32)) => {
105-
ElemType::Float(FloatKind::F32).into()
106-
}
107-
_ => lhs,
108-
};
109-
let rhs = match rhs {
110-
StorageType::Scalar(ElemType::Float(FloatKind::Flex32)) => {
111-
ElemType::Float(FloatKind::F32).into()
112-
}
113-
_ => rhs,
114-
};
115-
116-
let ea = match acc {
117-
StorageType::Scalar(ElemType::Float(FloatKind::Flex32)) => {
118-
ElemType::Float(FloatKind::F32).into()
119-
}
120-
_ => acc,
121-
};
122-
123102
let size = self.tile_size();
124103
if !client.properties().features.cmma.contains(&MmaConfig {
125104
a_type: lhs,
126105
b_type: rhs,
127-
cd_type: ea,
106+
cd_type: acc,
128107
m: size.m(),
129108
k: size.k(),
130109
n: size.n(),
@@ -133,7 +112,7 @@ impl AcceleratedConfig {
133112
MatmulAvailabilityError::CmmaInstructionUnavailable {
134113
lhs,
135114
rhs,
136-
output: ea,
115+
output: acc,
137116
size: Some(TileSize::new(size.m(), size.n(), size.k())),
138117
},
139118
));

crates/cubecl-matmul/src/components/tile/accelerated/matmul.rs

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -72,27 +72,17 @@ where
7272
fn load_lhs<E: Numeric>(
7373
tile: &StridedTile<E>,
7474
lhs: &mut Self::LhsFragment,
75-
#[comptime] config: Self::Config,
75+
#[comptime] _config: Self::Config,
7676
) {
77-
CmmaStageReader::<Self::LhsTile>::load_fragment(
78-
tile,
79-
lhs,
80-
CubeOption::new_None(),
81-
config.stage_line_size(StageIdent::Lhs),
82-
);
77+
CmmaStageReader::<Self::LhsTile>::load_fragment(tile, lhs, CubeOption::new_None());
8378
}
8479

8580
fn load_rhs<E: Numeric>(
8681
tile: &StridedTile<E>,
8782
rhs: &mut Self::RhsFragment,
88-
#[comptime] config: Self::Config,
83+
#[comptime] _config: Self::Config,
8984
) {
90-
CmmaStageReader::<Self::RhsTile>::load_fragment(
91-
tile,
92-
rhs,
93-
CubeOption::new_None(),
94-
config.stage_line_size(StageIdent::Rhs),
95-
);
85+
CmmaStageReader::<Self::RhsTile>::load_fragment(tile, rhs, CubeOption::new_None());
9686
}
9787

9888
fn load_acc<E: Numeric>(
@@ -101,12 +91,7 @@ where
10191
#[comptime] config: Self::Config,
10292
) {
10393
let layout = comptime!(as_cmma_layout(config.matrix_layout(StageIdent::Acc)));
104-
CmmaStageReader::<Self::AccTile>::load_fragment(
105-
tile,
106-
acc,
107-
CubeOption::new_Some(layout),
108-
config.stage_line_size(StageIdent::Acc),
109-
);
94+
CmmaStageReader::<Self::AccTile>::load_fragment(tile, acc, CubeOption::new_Some(layout));
11095
}
11196

11297
fn write_results<E: Numeric>(
@@ -115,8 +100,7 @@ where
115100
#[comptime] _config: Self::Config,
116101
) {
117102
let out = cmma::cast::<A, E>(out);
118-
let line_size = tile.slice.line_size();
119-
CmmaStageWriter::store_fragment(tile, &out, line_size);
103+
CmmaStageWriter::store_fragment(tile, &out);
120104
}
121105

122106
fn allocate_acc(#[comptime] config: Self::Config) -> Self::AccFragment {

crates/cubecl-matmul/src/components/tile/accelerated/reader.rs

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ pub(crate) trait CmmaFragmentReader {
1919
tile: &<Self::TileKind as TileKind>::Tile<V>,
2020
fragment: &mut cmma::Matrix<E>,
2121
layout: CubeOption<cmma::MatrixLayout>,
22-
#[comptime] line_size: u32,
2322
);
2423
}
2524

@@ -38,9 +37,8 @@ impl CmmaFragmentReader for CmmaStageReader<Strided> {
3837
tile: &StridedTile<V>,
3938
fragment: &mut cmma::Matrix<E>,
4039
layout: CubeOption<cmma::MatrixLayout>,
41-
#[comptime] line_size: u32,
4240
) {
43-
let (slice, stride) = tile.as_unlined(line_size);
41+
let (slice, stride) = tile.as_unlined();
4442
match layout {
4543
CubeOption::None => cmma::load(fragment, &slice, stride),
4644
CubeOption::Some(layout) => cmma::load_with_layout(fragment, &slice, stride, layout),
@@ -56,7 +54,6 @@ impl CmmaFragmentReader for CmmaStageReader<Filled> {
5654
value: &V,
5755
fragment: &mut cmma::Matrix<E>,
5856
_layout: CubeOption<cmma::MatrixLayout>,
59-
#[comptime] _line_size: u32,
6057
) {
6158
cmma::fill(fragment, E::cast_from(*value));
6259
}
@@ -73,18 +70,14 @@ where
7370
tile: &CubeOption<Inner::Tile<V>>,
7471
fragment: &mut cmma::Matrix<E>,
7572
layout: CubeOption<cmma::MatrixLayout>,
76-
#[comptime] line_size: u32,
7773
) {
7874
match tile {
7975
CubeOption::Some(tile) => {
80-
CmmaStageReader::<Inner>::load_fragment(tile, fragment, layout, line_size)
76+
CmmaStageReader::<Inner>::load_fragment(tile, fragment, layout)
77+
}
78+
CubeOption::None => {
79+
CmmaStageReader::<Filled>::load_fragment::<E, V>(&V::from_int(0), fragment, layout)
8180
}
82-
CubeOption::None => CmmaStageReader::<Filled>::load_fragment::<E, V>(
83-
&V::from_int(0),
84-
fragment,
85-
layout,
86-
line_size,
87-
),
8881
}
8982
}
9083
}

crates/cubecl-matmul/src/components/tile/accelerated/writer.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,9 @@ impl CmmaStageWriter {
1212
pub fn store_fragment<E: Numeric, V: Numeric>(
1313
tile: &mut StridedTile<V, ReadWrite>,
1414
fragment: &cmma::Matrix<E>,
15-
#[comptime] line_size: u32,
1615
) {
1716
let layout = as_cmma_layout(tile.layout);
18-
let (mut slice, stride) = tile.as_unlined(line_size);
17+
let (mut slice, stride) = tile.as_unlined();
1918
cmma::store(&mut slice, fragment, stride, layout);
2019
}
2120
}

0 commit comments

Comments
 (0)