Commit 0303f39

refactor: Matmul inputs (#949)
* Implement tensor map view launching
* WIP
* Use views for matmul
* Fix fusion and TMA
* Fix convolution launch
* Fix typos
1 parent caacb7b commit 0303f39

File tree

55 files changed: +1360 -2269 lines


crates/cubecl-attention/src/tests/macros/mod.rs

Lines changed: 1 addition & 9 deletions
```diff
@@ -8,20 +8,12 @@ use crate::{
     tests::attention_test_launcher::test_attention_algorithm,
 };
 
+#[derive(Default)]
 pub struct TestOptions {
     pub reuse_key_value: bool,
     pub two_rows_in_array_tile: bool,
 }
 
-impl Default for TestOptions {
-    fn default() -> Self {
-        Self {
-            reuse_key_value: false,
-            two_rows_in_array_tile: false,
-        }
-    }
-}
-
 pub fn attention_test_launch<A: Algorithm, R: Runtime>(
     client: ComputeClient<R::Server, R::Channel>,
     tiling_scheme: AttentionTilingScheme,
```
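The derive is a drop-in replacement for the deleted manual impl, since `Default` for `bool` is `false`. A standalone sketch of the equivalence (plain Rust, no cubecl dependencies):

```rust
// Deriving `Default` produces the same values the hand-written impl did,
// because `bool::default()` is `false`.
#[derive(Default, Debug, PartialEq)]
pub struct TestOptions {
    pub reuse_key_value: bool,
    pub two_rows_in_array_tile: bool,
}

fn main() {
    let opts = TestOptions::default();
    assert_eq!(
        opts,
        TestOptions {
            reuse_key_value: false,
            two_rows_in_array_tile: false,
        }
    );
}
```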

crates/cubecl-convolution/src/components/config.rs

Lines changed: 22 additions & 40 deletions
```diff
@@ -15,29 +15,25 @@ use super::*;
 /// Convolution specific config, extends regular matmul [`Config`](global::Config)
 pub trait ConvGemmConfig: GlobalConfig {
     /// The size of the convolution kernel at `dim`
-    fn kernel_size(&self, dim: u32) -> u32;
-    /// The dilation of the kernel at `dim`
-    fn dilation(&self, dim: u32) -> u32;
-    /// The stride of the kernel at `dim`
-    fn stride(&self, dim: u32) -> u32;
-    /// The padding of the kernel at `dim`
-    fn padding(&self, dim: u32) -> i32;
-    /// The dimensionality of the kernel
-    fn dimensionality(&self) -> Dimensionality;
-
+    fn convolution_params(&self) -> ConvolutionParams;
     fn line_sizes(&self) -> MatmulLineSizes;
     fn check_spatial_bounds(&self) -> bool;
 }
 
 #[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
 pub struct ConvolutionConfig<M: GlobalConfig> {
     matmul: M,
+    params: ConvolutionParams,
+    num_stages: u32,
+}
+
+#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
+pub struct ConvolutionParams {
     pub kernel_size: [u32; 3],
     pub stride: [u32; 3],
     pub dilation: [u32; 3],
     pub padding: [i32; 3],
-    dimensionality: Dimensionality,
-    num_stages: u32,
+    pub dimensionality: Dimensionality,
 }
 
 impl<M: GlobalConfig> Deref for ConvolutionConfig<M> {
@@ -121,24 +117,8 @@ impl<M: GlobalConfig> GlobalConfig for ConvolutionConfig<M> {
 }
 
 impl<M: GlobalConfig> ConvGemmConfig for ConvolutionConfig<M> {
-    fn kernel_size(&self, dim: u32) -> u32 {
-        self.kernel_size[dim as usize]
-    }
-
-    fn dilation(&self, dim: u32) -> u32 {
-        self.dilation[dim as usize]
-    }
-
-    fn stride(&self, dim: u32) -> u32 {
-        self.stride[dim as usize]
-    }
-
-    fn padding(&self, dim: u32) -> i32 {
-        self.padding[dim as usize]
-    }
-
-    fn dimensionality(&self) -> Dimensionality {
-        self.dimensionality
+    fn convolution_params(&self) -> ConvolutionParams {
+        self.params
     }
 
     fn line_sizes(&self) -> cubecl_matmul::components::MatmulLineSizes {
@@ -150,10 +130,10 @@ impl<M: GlobalConfig> ConvGemmConfig for ConvolutionConfig<M> {
     }
 
     fn check_spatial_bounds(&self) -> bool {
-        let spatial_dims = self.dimensionality.num_dims();
+        let spatial_dims = self.params.dimensionality.num_dims();
         let mut has_padding = false;
         for i in 0..spatial_dims {
-            has_padding |= self.padding[i as usize] != 0;
+            has_padding |= self.params.padding[i as usize] != 0;
         }
         has_padding
     }
@@ -172,20 +152,22 @@ impl<M: GlobalConfig> ConvolutionConfig<M> {
     ) -> Result<Self, MatmulSetupError> {
         let dims = kernel_size.len();
 
-        let mut this = Self {
-            matmul,
+        let mut params = ConvolutionParams {
             kernel_size: [0; 3],
             stride: [0; 3],
             dilation: [0; 3],
             padding: [0; 3],
             dimensionality: dim,
-            num_stages,
         };
-        this.kernel_size[0..dims].copy_from_slice(kernel_size);
-        this.stride[0..dims].copy_from_slice(stride);
-        this.dilation[0..dims].copy_from_slice(dilation);
-        this.padding[0..dims].copy_from_slice(padding);
-        Ok(this)
+        params.kernel_size[0..dims].copy_from_slice(kernel_size);
+        params.stride[0..dims].copy_from_slice(stride);
+        params.dilation[0..dims].copy_from_slice(dilation);
+        params.padding[0..dims].copy_from_slice(padding);
+        Ok(Self {
+            matmul,
+            params,
+            num_stages,
+        })
     }
 
     pub fn to_matmul_config(self) -> M {
```
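For callers, the per-dimension accessors (`kernel_size(dim)`, `stride(dim)`, ...) are replaced by a single `convolution_params()` call returning a `Copy` struct whose arrays are indexed directly. A standalone sketch of the new access pattern; `Dimensionality` is stubbed here and the consumer function is purely illustrative:

```rust
// Stub standing in for cubecl's `Dimensionality` enum.
#[allow(dead_code)]
#[derive(Copy, Clone, Debug)]
enum Dimensionality { Dim1, Dim2, Dim3 }

// Mirrors the `ConvolutionParams` struct added in this commit.
#[allow(dead_code)]
#[derive(Copy, Clone, Debug)]
struct ConvolutionParams {
    kernel_size: [u32; 3],
    stride: [u32; 3],
    dilation: [u32; 3],
    padding: [i32; 3],
    dimensionality: Dimensionality,
}

// Illustrative consumer: effective span of a dilated kernel, (k - 1) * d + 1.
// Before this commit it would have called `config.kernel_size(dim)` and
// `config.dilation(dim)`; now it reads one params struct and indexes it.
fn effective_kernel_span(params: ConvolutionParams, dim: usize) -> u32 {
    (params.kernel_size[dim] - 1) * params.dilation[dim] + 1
}

fn main() {
    let params = ConvolutionParams {
        kernel_size: [3, 3, 0],
        stride: [1, 1, 0],
        dilation: [2, 2, 0],
        padding: [1, 1, 0],
        dimensionality: Dimensionality::Dim2,
    };
    // A 3x3 kernel with dilation 2 covers a 5x5 window.
    assert_eq!(effective_kernel_span(params, 0), 5);
}
```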

crates/cubecl-convolution/src/components/global/args.rs

Lines changed: 120 additions & 11 deletions
```diff
@@ -2,61 +2,156 @@ use std::any::TypeId;
 
 use cubecl::prelude::*;
 use cubecl_core as cubecl;
+use cubecl_std::{
+    FastDivmodArgs,
+    tensor::{
+        View,
+        launch::ViewArg,
+        layout::{
+            Coords3d,
+            chain::{Chain, ChainLaunch},
+        },
+    },
+};
 
 use crate::{
-    components::ConvolutionProblem,
+    components::{
+        ConvGemmConfig, ConvolutionProblem,
+        global::{
+            layout::{
+                BiasLayout, BiasLayoutLaunch, Im2colLayout, Im2colLayoutLaunch, NhwcLayout,
+                NhwcLayoutLaunch, OutLayout, OutLayoutLaunch, WeightLayout, WeightLayoutLaunch,
+            },
+            read::layout::{
+                TmaDummyLayout, TmaDummyLayoutLaunch, TmaWeightLayout, TmaWeightLayoutLaunch,
+            },
+        },
+    },
     kernels::layered::algorithm::simple_tma::{calculate_lower_corner, calculate_upper_corner},
 };
 use cubecl_matmul::{
     MatmulInputHandleRef,
     components::{
-        MatmulLineSizes, MatmulSelection,
+        MatmulIdent, MatmulLineSizes, MatmulSelection,
         global::args::{TensorInputs, TensorInputsLaunch, TensorMapInputs, TensorMapInputsLaunch},
     },
 };
 
 /// Create the input runtime arguments for a matmul kernel that works on concrete inputs and
 /// output (not fused).
 pub trait ConcreteInputsFactory: LaunchArg {
+    #[allow(clippy::too_many_arguments)]
     fn create<'a, R: Runtime>(
+        client: &ComputeClient<R::Server, R::Channel>,
         lhs: &'a MatmulInputHandleRef<'a, R>,
         rhs: &'a MatmulInputHandleRef<'a, R>,
         bias: Option<&'a TensorHandleRef<'a, R>>,
         selection: &MatmulSelection,
         problem: &ConvolutionProblem,
         line_sizes: &MatmulLineSizes,
+        config: impl ConvGemmConfig,
+    ) -> Self::RuntimeArg<'a, R>;
+}
+
+/// Create the output runtime arguments for a matmul kernel that works on concrete inputs and
+/// output (not fused).
+pub trait ConcreteOutputFactory: LaunchArg {
+    fn create<'a, R: Runtime>(
+        client: &ComputeClient<R::Server, R::Channel>,
+        out: &'a TensorHandleRef<'a, R>,
+        selection: &MatmulSelection,
+        problem: &ConvolutionProblem,
+        line_sizes: &MatmulLineSizes,
+        config: impl ConvGemmConfig,
     ) -> Self::RuntimeArg<'a, R>;
 }
 
 impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory for TensorInputs<Lhs, Rhs, EO> {
     fn create<'a, R: Runtime>(
+        client: &ComputeClient<R::Server, R::Channel>,
         lhs: &'a MatmulInputHandleRef<'a, R>,
         rhs: &'a MatmulInputHandleRef<'a, R>,
         bias: Option<&'a TensorHandleRef<'a, R>>,
         _selection: &MatmulSelection,
-        _problem: &ConvolutionProblem,
+        problem: &ConvolutionProblem,
         line_sizes: &MatmulLineSizes,
+        config: impl ConvGemmConfig,
     ) -> Self::RuntimeArg<'a, R> {
+        type LhsLayout = Chain<NhwcLayout, Im2colLayout>;
+        type RhsLayout = Chain<NhwcLayout, WeightLayout>;
+
+        let layout_nhwc = |handle, line_size, check| {
+            NhwcLayoutLaunch::from_handle(handle, line_size as u32, check)
+        };
+        let layout_lhs = Im2colLayoutLaunch::from_args(
+            client,
+            problem,
+            config.convolution_params(),
+            config.global_memory_config(MatmulIdent::Lhs),
+        );
+        let layout_rhs = WeightLayoutLaunch::from_args(
+            client,
+            problem,
+            config.convolution_params(),
+            config.global_memory_config(MatmulIdent::Rhs),
+        );
+        let layout_bias =
+            BiasLayoutLaunch::new(ScalarArg::new(problem.n as u32), line_sizes.out as u32);
+
+        let layout_lhs = {
+            let global = layout_nhwc(lhs.data(), line_sizes.lhs, config.check_spatial_bounds());
+            ChainLaunch::new(global, layout_lhs)
+        };
+        let layout_rhs = {
+            let global = layout_nhwc(rhs.data(), line_sizes.rhs, false);
+            ChainLaunch::new(global, layout_rhs)
+        };
+
         TensorInputsLaunch::new(
-            lhs.data().as_tensor_arg(line_sizes.lhs),
-            lhs.scale().map(|it| it.as_tensor_arg(1)).into(),
-            rhs.data().as_tensor_arg(line_sizes.rhs),
-            rhs.scale().map(|it| it.as_tensor_arg(1)).into(),
-            bias.map(|it| it.as_tensor_arg(line_sizes.out)).into(),
+            ViewArg::new::<LhsLayout>(lhs.data().as_array_arg(line_sizes.lhs), layout_lhs),
+            ViewArg::new::<RhsLayout>(rhs.data().as_array_arg(line_sizes.rhs), layout_rhs),
+            bias.map(|bias| {
+                ViewArg::new::<BiasLayout>(bias.as_array_arg(line_sizes.out), layout_bias)
+            })
+            .into(),
         )
     }
 }
 
+impl<EG: Numeric> ConcreteOutputFactory for View<Line<EG>, Coords3d, ReadWrite> {
+    fn create<'a, R: Runtime>(
+        client: &ComputeClient<R::Server, R::Channel>,
+        out: &'a TensorHandleRef<'a, R>,
+        _selection: &MatmulSelection,
+        problem: &ConvolutionProblem,
+        line_sizes: &MatmulLineSizes,
+        config: impl ConvGemmConfig,
+    ) -> Self::RuntimeArg<'a, R> {
+        type Layout = Chain<NhwcLayout, OutLayout>;
+
+        let global = NhwcLayoutLaunch::from_handle(out, line_sizes.out as u32, false);
+        let layout = OutLayoutLaunch::from_args(
+            client,
+            problem,
+            config.global_memory_config(MatmulIdent::Out),
+        );
+        let layout = ChainLaunch::new(global, layout);
+        ViewArg::new::<Layout>(out.as_array_arg(line_sizes.out), layout)
+    }
+}
+
 impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory
     for TensorMapInputs<Lhs, Rhs, EO>
 {
     fn create<'a, R: Runtime>(
+        client: &ComputeClient<R::Server, R::Channel>,
         lhs: &'a MatmulInputHandleRef<'a, R>,
         rhs: &'a MatmulInputHandleRef<'a, R>,
         bias: Option<&'a TensorHandleRef<'a, R>>,
         selection: &MatmulSelection,
         problem: &ConvolutionProblem,
         line_sizes: &MatmulLineSizes,
+        config: impl ConvGemmConfig,
     ) -> Self::RuntimeArg<'a, R> {
         let tiling_scheme = selection.tiling_scheme;
         let stage_m = tiling_scheme.elements_in_stage_m();
@@ -119,9 +214,23 @@ impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory
         )
         .with_prefetch(prefetch_rhs);
 
-        let bias = bias.map(|it| it.as_tensor_arg(line_sizes.out));
+        let padded_channels =
+            (problem.channels as u32).next_multiple_of(config.tiling_scheme().elements_in_tile_k());
+
+        // Dummy layout since we don't support im2col loading rn
+        let lhs_layout = TmaDummyLayoutLaunch::new();
+        let rhs_layout = TmaWeightLayoutLaunch::new(FastDivmodArgs::new(client, padded_channels));
 
-        // TODO: Think about how to handle scales with TMA
-        TensorMapInputsLaunch::new(lhs, rhs, bias.into())
+        let bias = bias.map(|bias| {
+            let layout =
+                BiasLayoutLaunch::new(ScalarArg::new(problem.n as u32), line_sizes.out as u32);
+            ViewArg::new::<BiasLayout>(bias.as_array_arg(line_sizes.out), layout)
+        });
+
+        TensorMapInputsLaunch::new(
+            ViewArg::new_tensor_map::<TmaDummyLayout>(lhs, lhs_layout),
+            ViewArg::new_tensor_map::<TmaWeightLayout>(rhs, rhs_layout),
+            bias.into(),
+        )
     }
 }
```
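The core technique in this file is layout chaining: each handle is wrapped in a `ViewArg` whose layout is a `Chain` of an inner coordinate mapping (im2col for lhs, weight for rhs) and an `NhwcLayout` that linearizes tensor coordinates into memory offsets. The sketch below mimics that composition with hypothetical types; it is not cubecl's actual `Layout`/`Chain` API, just the idea of resolving a matmul coordinate through two stages:

```rust
// Illustrative layout chaining, not cubecl's actual types.
// A `Layout` maps coordinates from one space to another; `Chain` composes two.
trait Layout {
    type In;
    type Out;
    fn map(&self, pos: Self::In) -> Self::Out;
}

/// Composes two layouts: `inner` first, then `outer`.
struct Chain<A, B> {
    inner: A,
    outer: B,
}

impl<A, B> Layout for Chain<A, B>
where
    A: Layout,
    B: Layout<In = A::Out>,
{
    type In = A::In;
    type Out = B::Out;
    fn map(&self, pos: Self::In) -> Self::Out {
        self.outer.map(self.inner.map(pos))
    }
}

/// Simplified im2col-style inner layout: a matmul (m, k) coordinate becomes
/// an (h, w, c) feature-map coordinate (1x1 kernel, stride 1, single batch).
struct Im2col { width: u32, channels: u32 }
impl Layout for Im2col {
    type In = (u32, u32);       // (m, k) in the virtual matmul
    type Out = (u32, u32, u32); // (h, w, c) in the feature map
    fn map(&self, (m, k): (u32, u32)) -> (u32, u32, u32) {
        (m / self.width, m % self.width, k % self.channels)
    }
}

/// NHWC-style outer layout: feature-map coordinate -> linear offset.
struct Nhwc { width: u32, channels: u32 }
impl Layout for Nhwc {
    type In = (u32, u32, u32);
    type Out = u32;
    fn map(&self, (h, w, c): (u32, u32, u32)) -> u32 {
        (h * self.width + w) * self.channels + c
    }
}

fn main() {
    let layout = Chain {
        inner: Im2col { width: 4, channels: 8 },
        outer: Nhwc { width: 4, channels: 8 },
    };
    // Matmul element (m=5, k=3) resolves to one linear offset in memory.
    println!("offset = {}", layout.map((5, 3)));
}
```

Because the composition happens in the layout, the matmul kernel itself only ever sees `(row, col)` coordinates; convolution-specific indexing lives entirely in the chained views.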

crates/cubecl-convolution/src/components/global/base.rs

Lines changed: 6 additions & 14 deletions
```diff
@@ -8,7 +8,7 @@ use cubecl_matmul::components::{
 };
 use cubecl_std::{
     CubeOption,
-    tensor::{layout::Coords2d, r#virtual::VirtualTensor},
+    tensor::{View, layout::Coords2d},
 };
 
 use crate::{
@@ -69,36 +69,28 @@ pub trait GlobalConvolution<MP: MatmulPrecision>: 'static + Send + Sync {
 
     /// Initializes the global reader for the input feature map with an appropriate layout
     fn init_lhs_global_reader(
-        lhs: VirtualTensor<LhsG<MP>>,
+        lhs: View<Line<LhsG<MP>>, Coords2d>,
         offset: Coords2d,
-        view_shape: Coords2d,
+        slice_size: Coords2d,
         runtime_args: &RuntimeArgs,
         #[comptime] config: Self::Config,
     ) -> Self::LhsGlobalReader;
 
     /// Initializes the global reader for the weights with an appropriate layout
     fn init_rhs_global_reader(
-        rhs: VirtualTensor<RhsG<MP>>,
-        offset: Coords2d,
-        view_shape: Coords2d,
-        runtime_args: &RuntimeArgs,
+        rhs: View<Line<RhsG<MP>>, Coords2d>,
         #[comptime] config: Self::Config,
     ) -> Self::RhsGlobalReader;
 
     /// Initializes the global reader for the bias with an appropriate layout
     fn init_bias_global_reader(
-        bias: CubeOption<VirtualTensor<AccG<MP>>>,
-        n_offset: u32,
-        slice_size: u32,
+        bias: CubeOption<View<Line<AccG<MP>>, Coords2d>>,
         #[comptime] config: Self::Config,
     ) -> Self::AccGlobalReader;
 
     /// Initializes the output feature map global writer with an appropriate layout
     fn init_global_writer(
-        out: VirtualTensor<AccG<MP>, ReadWrite>,
-        offset: Coords2d,
-        view_shape: Coords2d,
-        runtime_args: &RuntimeArgs,
+        out: View<Line<AccG<MP>>, Coords2d, ReadWrite>,
         #[comptime] config: Self::Config,
     ) -> Self::GlobalWriter;
 
```