Skip to content

Commit da0e292

Browse files
authored
feat: TMA views (#943)
1 parent 31818e5 commit da0e292

40 files changed

+1826
-1117
lines changed

crates/cubecl-attention/src/components/global/dummy/read.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ pub struct QueryReader<AP: AttentionPrecision> {
2323

2424
#[derive(CubeType)]
2525
pub struct DummyKeyReader<AP: AttentionPrecision, G: GlobalAttentionConfig> {
26-
global_iter: GlobalIterator<KG<AP>>,
26+
global_iter: GlobalIterator<Line<KG<AP>>>,
2727
stage_memory: StridedStage<KS<AP>, AttentionTilingLayout>,
2828

2929
#[cube(comptime)]
@@ -32,7 +32,7 @@ pub struct DummyKeyReader<AP: AttentionPrecision, G: GlobalAttentionConfig> {
3232

3333
#[derive(CubeType)]
3434
pub struct DummyValueReader<AP: AttentionPrecision, G: GlobalAttentionConfig> {
35-
global_iter: GlobalIterator<VG<AP>>,
35+
global_iter: GlobalIterator<Line<VG<AP>>>,
3636
stage_memory: StridedStage<VS<AP>, AttentionTilingLayout>,
3737

3838
#[cube(comptime)]

crates/cubecl-convolution/src/components/global/multi_stage/tma/convolution.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use cubecl_matmul::components::{
1515
};
1616
use cubecl_std::{
1717
CubeOption,
18-
tensor::{layout::Coords2d, r#virtual::VirtualTensor},
18+
tensor::{AsTensorView, AsTensorViewExpand, layout::Coords2d, r#virtual::VirtualTensor},
1919
};
2020

2121
use crate::{
@@ -27,6 +27,7 @@ use crate::{
2727
read::{
2828
bias::{BiasGlobalReader, BiasStage},
2929
im2col_tma::{TmaIm2colGlobalReader, TmaIm2colTiling},
30+
layout::TmaWeightLayout,
3031
weight_tma::{TmaWeightGlobalReader, TmaWeightTiling},
3132
},
3233
},
@@ -69,7 +70,7 @@ where
6970
type Config = ConvolutionConfig<SimpleTmaConfig<SMM::Config>>;
7071

7172
type LhsGlobalReader = TmaIm2colGlobalReader<MP::Lhs, Self::Config>;
72-
type RhsGlobalReader = TmaWeightGlobalReader<MP::Rhs, SMM::Config>;
73+
type RhsGlobalReader = TmaWeightGlobalReader<MP::Rhs>;
7374
type AccGlobalReader = BiasGlobalReader<MP::Acc>;
7475
type GlobalWriter = PlaneWriter<MP::Acc>;
7576

@@ -121,12 +122,12 @@ where
121122
let barrier = Barrier::new_with_tma_proxy(BarrierLevel::cube_coop(0u32));
122123

123124
lhs_reader.fill_stage(&barrier, stage);
124-
rhs_reader.fill_stage(&barrier, stage, stage_config);
125+
rhs_reader.fill_stage(&barrier, stage);
125126

126127
arrive_tma(&barrier, stages_bytes);
127128

128129
lhs_reader.advance_view(k_step);
129-
rhs_reader.advance_view(k_step);
130+
rhs_reader.advance_view();
130131

131132
barriers.push(barrier);
132133

@@ -168,12 +169,12 @@ where
168169

169170
// Refill stage and advance view
170171
lhs_reader.fill_stage(barrier, stage);
171-
rhs_reader.fill_stage(barrier, stage, stage_config);
172+
rhs_reader.fill_stage(barrier, stage);
172173

173174
arrive_tma(barrier, stages_bytes);
174175

175176
lhs_reader.advance_view(k_step);
176-
rhs_reader.advance_view(k_step);
177+
rhs_reader.advance_view();
177178
}
178179
}
179180

@@ -216,16 +217,15 @@ where
216217
fn init_rhs_global_reader(
217218
rhs: VirtualTensor<RhsG<MP>>,
218219
offset: Coords2d,
219-
_slice_size: Coords2d,
220+
slice_size: Coords2d,
220221
runtime_args: &RuntimeArgs,
221222
#[comptime] config: Self::Config,
222223
) -> Self::RhsGlobalReader {
223-
let (x_offset, y_offset) = offset;
224+
let layout = TmaWeightLayout::new(runtime_args.padded_channels);
225+
let rhs = rhs.as_tensor_map().unwrap().view_3d(layout);
224226
Self::RhsGlobalReader::new(
225-
rhs.as_tensor_map().unwrap(),
226-
x_offset,
227-
y_offset,
228-
runtime_args,
227+
rhs.slice(offset, slice_size),
228+
config.k_step,
229229
config.num_stages(MatmulIdent::Rhs),
230230
config.stage_memory_config(MatmulIdent::Rhs),
231231
)
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
use cubecl::prelude::*;
use cubecl_core as cubecl;
use cubecl_std::{
    FastDivmod,
    tensor::layout::{Coords2d, Coords3d, Layout, LayoutExpand},
};

/// Layout mapping 2D weight coordinates `(k, n)` onto the 3D coordinates of a
/// TMA tensor map, by splitting `k` into a kernel index and an input-channel
/// index using a division/modulo by the padded channel count.
#[derive(CubeType)]
pub struct TmaWeightLayout {
    // Divisor for splitting the combined `k` coordinate; `FastDivmod` provides
    // a fast combined div/mod operation.
    padded_channels: FastDivmod,
}

#[cube]
impl TmaWeightLayout {
    /// Creates a new weight layout with the given padded channel count.
    pub fn new(padded_channels: FastDivmod) -> Self {
        TmaWeightLayout { padded_channels }
    }
}

#[cube]
impl Layout for TmaWeightLayout {
    /// Logical coordinates: `(k, n)`.
    type Coordinates = Coords2d;
    /// Physical (tensor-map) coordinates: `(n, k_idx, in_c)`.
    type SourceCoordinates = Coords3d;

    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
        let (k, n) = pos;
        // Split the flattened `k` into kernel index and in-channel index.
        let (k_idx, in_c) = self.padded_channels.div_mod(k);
        // Note the reordering: `n` becomes the leading source dimension.
        (n, k_idx, in_c)
    }

    fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
        // Always in bounds from this layout's perspective.
        // NOTE(review): presumably out-of-bounds handling is delegated to the
        // TMA load itself — confirm against the tensor-map configuration.
        true.runtime()
    }

    fn shape(&self) -> Self::Coordinates {
        // Sentinel "unbounded" shape; this layout does not track a real extent.
        (u32::MAX, u32::MAX).runtime()
    }

    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
        // Combined translation + bounds check (the check is trivially true here).
        (self.to_source_pos(pos), self.is_in_bounds(pos))
    }
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
pub mod bias;
22
pub mod im2col_tma;
3+
pub mod layout;
34
pub mod weight_tma;
Lines changed: 25 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,31 @@
1-
use core::marker::PhantomData;
2-
31
use cubecl_core::prelude::*;
42
use cubecl_core::{self as cubecl, prelude::barrier::Barrier};
53
use cubecl_matmul::components::{
6-
MatmulIdent, MatrixPrecision, StageIdent, stage::StageMemoryConfig,
4+
MatrixPrecision, StageIdent,
5+
global::memory::{GlobalIterator, ViewDirection},
6+
stage::StageMemoryConfig,
77
};
8-
use cubecl_std::FastDivmod;
8+
use cubecl_std::tensor::{View, layout::Coords2d};
99

1010
use cubecl_matmul::components::stage::RowMajorTilingOrder;
11-
use cubecl_matmul::components::{
12-
global::memory::MappedTensorReader,
13-
stage::{ContiguousTilingLayout, StageConfig, StridedStage},
14-
};
15-
16-
use crate::kernels::layered::selector::RuntimeArgs;
11+
use cubecl_matmul::components::stage::{ContiguousTilingLayout, StridedStage};
1712

1813
pub type TmaWeightTiling = ContiguousTilingLayout<RowMajorTilingOrder>;
1914
pub type TmaWeightStage<IP> = StridedStage<<IP as MatrixPrecision>::Stage, TmaWeightTiling>;
2015

2116
#[derive(CubeType)]
22-
pub struct TmaWeightGlobalReader<IP: MatrixPrecision, S: StageConfig> {
23-
pub tensor_view: MappedTensorReader<IP::Global>,
17+
pub struct TmaWeightGlobalReader<IP: MatrixPrecision> {
18+
pub global_iter: GlobalIterator<IP::Global>,
2419
pub stages: Sequence<StridedStage<IP::Stage, TmaWeightTiling>>,
25-
padded_channels: FastDivmod,
2620
#[cube(comptime)]
27-
_config: PhantomData<S>,
21+
config: StageMemoryConfig,
2822
}
2923

3024
#[cube]
31-
impl<IP: MatrixPrecision, S: StageConfig> TmaWeightGlobalReader<IP, S> {
25+
impl<IP: MatrixPrecision> TmaWeightGlobalReader<IP> {
3226
pub fn new(
33-
tensor: TensorMap<IP::Global>,
34-
x: u32,
35-
y: u32,
36-
runtime_args: &RuntimeArgs,
27+
global_view: View<IP::Global, Coords2d>,
28+
k_step: u32,
3729
#[comptime] num_stages: u32,
3830
#[comptime] config: StageMemoryConfig,
3931
) -> Self {
@@ -44,42 +36,32 @@ impl<IP: MatrixPrecision, S: StageConfig> TmaWeightGlobalReader<IP, S> {
4436
stages.push(StridedStage::new_aligned(StageIdent::Rhs, 128u32, config));
4537
}
4638

47-
let tensor_view = MappedTensorReader::new(tensor, x, y, 0);
39+
let global_iter = GlobalIterator::new(global_view, k_step, ViewDirection::Row, false);
4840

49-
TmaWeightGlobalReader::<IP, S> {
50-
tensor_view,
41+
TmaWeightGlobalReader::<IP> {
42+
global_iter,
5143
stages,
52-
padded_channels: runtime_args.padded_channels,
53-
_config: PhantomData::<S>,
44+
config,
5445
}
5546
}
5647

57-
pub fn fill_stage(
58-
&mut self,
59-
barrier: &Barrier,
60-
#[comptime] stage_idx: u32,
61-
#[comptime] config: S,
62-
) {
48+
pub fn fill_stage(&mut self, barrier: &Barrier, #[comptime] stage_idx: u32) {
6349
let stage = self.stages.index_mut(stage_idx);
50+
let config = comptime![self.config];
6451

6552
if UNIT_POS == 0 {
66-
let k = self.tensor_view.tile_x;
67-
let out_c = self.tensor_view.tile_y;
53+
let global_view = self.global_iter.view();
6854

69-
let tensor = self.tensor_view.tensor.try_cast_unchecked();
7055
let mut stage = stage.as_slice_mut(1u32);
71-
let slice_size = config.tiling_scheme().elements_in_stage_n()
72-
* config.tiling_scheme().elements_in_tile_k();
56+
let slice_size = config.elements_in_stage_col() * config.elements_in_tile_row;
7357

7458
#[unroll]
75-
for tile_k in 0..config.tiling_scheme().tiles_in_stage_k() {
59+
for tile_k in 0..config.tiles_in_stage_row {
7660
let slice_start = slice_size * tile_k;
77-
let mut slice = stage.slice_mut(slice_start, slice_size);
78-
79-
let k = k + tile_k * config.tiling_scheme().elements_in_tile_k();
80-
let (k_idx, in_c) = self.padded_channels.div_mod(k);
61+
let slice = stage.slice_mut(slice_start, slice_size);
8162

82-
barrier.tma_load_3d(&tensor, &mut slice, out_c as i32, k_idx as i32, in_c as i32);
63+
let k = tile_k * config.elements_in_tile_row;
64+
global_view.tensor_map_load(barrier, &mut slice.try_cast_unchecked(), (k, 0));
8365
}
8466
}
8567
}
@@ -88,7 +70,7 @@ impl<IP: MatrixPrecision, S: StageConfig> TmaWeightGlobalReader<IP, S> {
8870
*self.stages.index(stage_idx)
8971
}
9072

91-
pub fn advance_view(&mut self, k_offset: u32) {
92-
self.tensor_view.update_view(k_offset, MatmulIdent::Rhs);
73+
pub fn advance_view(&mut self) {
74+
self.global_iter.advance();
9375
}
9476
}

crates/cubecl-convolution/src/components/global/single_stage/tma/convolution.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use cubecl_matmul::components::{
1515
};
1616
use cubecl_std::{
1717
CubeOption,
18-
tensor::{layout::Coords2d, r#virtual::VirtualTensor},
18+
tensor::{AsTensorView, AsTensorViewExpand, layout::Coords2d, r#virtual::VirtualTensor},
1919
};
2020

2121
use crate::{
@@ -27,6 +27,7 @@ use crate::{
2727
read::{
2828
bias::{BiasGlobalReader, BiasStage},
2929
im2col_tma::{TmaIm2colGlobalReader, TmaIm2colTiling},
30+
layout::TmaWeightLayout,
3031
weight_tma::{TmaWeightGlobalReader, TmaWeightTiling},
3132
},
3233
},
@@ -56,7 +57,7 @@ where
5657
type Config = ConvolutionConfig<SimpleTmaConfig<SMM::Config>>;
5758

5859
type LhsGlobalReader = TmaIm2colGlobalReader<MP::Lhs, Self::Config>;
59-
type RhsGlobalReader = TmaWeightGlobalReader<MP::Rhs, SMM::Config>;
60+
type RhsGlobalReader = TmaWeightGlobalReader<MP::Rhs>;
6061
type AccGlobalReader = BiasGlobalReader<MP::Acc>;
6162
type GlobalWriter = PlaneWriter<MP::Acc>;
6263

@@ -97,7 +98,7 @@ where
9798
sync_cube();
9899

99100
lhs_reader.fill_stage(&barrier, 0u32);
100-
rhs_reader.fill_stage(&barrier, 0u32, config.stage_config());
101+
rhs_reader.fill_stage(&barrier, 0u32);
101102

102103
arrive_tma(&barrier, stages_bytes);
103104

@@ -114,7 +115,7 @@ where
114115
);
115116

116117
lhs_reader.advance_view(k_step);
117-
rhs_reader.advance_view(k_step);
118+
rhs_reader.advance_view();
118119
}
119120

120121
sync_cube();
@@ -145,16 +146,15 @@ where
145146
fn init_rhs_global_reader(
146147
rhs: VirtualTensor<RhsG<MP>>,
147148
offset: Coords2d,
148-
_slice_size: Coords2d,
149+
slice_size: Coords2d,
149150
runtime_args: &RuntimeArgs,
150151
#[comptime] config: Self::Config,
151152
) -> Self::RhsGlobalReader {
152-
let (x_offset, y_offset) = offset;
153+
let layout = TmaWeightLayout::new(runtime_args.padded_channels);
154+
let rhs = rhs.as_tensor_map().unwrap().view_3d(layout);
153155
Self::RhsGlobalReader::new(
154-
rhs.as_tensor_map().unwrap(),
155-
x_offset,
156-
y_offset,
157-
runtime_args,
156+
rhs.slice(offset, slice_size),
157+
config.k_step,
158158
1u32,
159159
config.stage_memory_config(MatmulIdent::Rhs),
160160
)

crates/cubecl-matmul/src/components/global/memory/iterator.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ use cubecl_std::tensor::{View, layout::Coords2d};
44

55
#[derive(Clone, CubeType)]
66
/// An iterator over global memory, advancing along k.
7-
pub struct GlobalIterator<EI: Numeric> {
8-
global_view: View<Line<EI>, Coords2d>,
7+
pub struct GlobalIterator<EI: CubePrimitive> {
8+
global_view: View<EI, Coords2d>,
99
offset: RuntimeCell<u32>,
1010
/// The amount to advance by on each iteration
1111
step: u32,
@@ -16,8 +16,8 @@ pub struct GlobalIterator<EI: Numeric> {
1616
checked: bool,
1717
}
1818

19-
unsafe impl<EG: Numeric> Sync for GlobalIterator<EG> {}
20-
unsafe impl<EG: Numeric> Send for GlobalIterator<EG> {}
19+
unsafe impl<EG: CubePrimitive> Sync for GlobalIterator<EG> {}
20+
unsafe impl<EG: CubePrimitive> Send for GlobalIterator<EG> {}
2121

2222
#[derive(CubeType, Clone, Copy)]
2323
pub enum ViewDirection {
@@ -28,14 +28,14 @@ pub enum ViewDirection {
2828
}
2929

3030
#[cube]
31-
impl<EG: Numeric> GlobalIterator<EG> {
31+
impl<EG: CubePrimitive> GlobalIterator<EG> {
3232
/// Instantiate a read iterator over the given global view, which should be sliced to the size
3333
/// of one `m`/`n` stage and the full range of `k` handled by this matmul instance.
3434
///
3535
/// `step` is the amount advanced in `view_direction` each iteration.
3636
/// `checked` determines whether the slices should be created as checked or unchecked.
3737
pub fn new(
38-
global_view: View<Line<EG>, Coords2d>,
38+
global_view: View<EG, Coords2d>,
3939
step: u32,
4040
#[comptime] view_direction: ViewDirection,
4141
#[comptime] checked: bool,
@@ -63,7 +63,7 @@ impl<EG: Numeric> GlobalIterator<EG> {
6363
}
6464

6565
/// Returns the current view slice of the iterator
66-
pub fn view(&self) -> View<Line<EG>, Coords2d> {
66+
pub fn view(&self) -> View<EG, Coords2d> {
6767
let offset = match comptime![self.view_direction] {
6868
ViewDirection::Row => (self.offset.read(), 0u32),
6969
ViewDirection::Col => (0u32, self.offset.read()),

0 commit comments

Comments
 (0)