Skip to content

Commit 290a4c7

Browse files
authored
feat: Matmul global quant (#960)
1 parent 35e917c commit 290a4c7

File tree

25 files changed

+510
-191
lines changed

25 files changed

+510
-191
lines changed

crates/cubecl-common/src/quant/scheme.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,9 @@ pub struct BlockSize {
224224
}
225225

226226
impl BlockSize {
227+
/// Max number of dimensions for block size
228+
pub const MAX_DIMS: usize = MAX_DIMS;
229+
227230
/// Create a new blocksize from a set of values. The number of values must be `<= MAX_DIMS`.
228231
pub fn new(values: impl AsRef<[u8]>) -> Self {
229232
let values = values.as_ref();

crates/cubecl-convolution/src/components/global/read/reader/bias.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use cubecl_std::{
66
};
77

88
use cubecl_matmul::components::{
9-
MatmulIdent, MatrixPrecision, StageIdent,
9+
MatrixPrecision, StageIdent,
1010
global::GlobalConfig,
1111
stage::{StageMemoryConfig, StridedStage},
1212
};
@@ -33,7 +33,7 @@ impl<IP: MatrixPrecision> BiasGlobalReader<IP> {
3333
pub fn load_stage<G: GlobalConfig>(&mut self, #[comptime] config: G) {
3434
match self {
3535
BiasGlobalReader::Some { view, stage } => {
36-
let line_size = config.global_line_size(MatmulIdent::Out);
36+
let line_size = view.line_size();
3737
let num_stage_elements = config.tiling_scheme().elements_in_stage_n();
3838

3939
let unit_id = UNIT_POS_Y * config.plane_dim() + UNIT_POS_X;

crates/cubecl-matmul/src/base.rs

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
use cubecl_common::quant::scheme::QuantScheme;
12
use cubecl_core::{
23
Runtime,
34
client::ComputeClient,
4-
prelude::{Numeric, TensorHandleRef},
5+
prelude::{CubePrimitive, Numeric, TensorHandleRef},
56
};
67

78
use cubecl_std::tensor::TensorHandle;
@@ -94,33 +95,49 @@ pub enum AsyncReadingStrategy {
9495
Tma,
9596
}
9697

97-
pub enum MatmulInputHandle<R: Runtime, E: Numeric> {
98+
pub enum MatmulInputHandle<R: Runtime, E: CubePrimitive, S: CubePrimitive = f32> {
9899
Normal(TensorHandle<R, E>),
99100
Quantized {
100101
data: TensorHandle<R, E>,
101-
scale: TensorHandle<R, f32>,
102+
scale: TensorHandle<R, S>,
103+
shape: Vec<usize>,
104+
scheme: QuantScheme,
102105
},
103106
}
104107

105108
impl<R: Runtime, E: Numeric> MatmulInputHandle<R, E> {
106109
pub fn as_ref(&self) -> MatmulInputHandleRef<'_, R> {
107110
match self {
108111
MatmulInputHandle::Normal(handle) => MatmulInputHandleRef::Normal(handle.as_ref()),
109-
MatmulInputHandle::Quantized { data, scale } => MatmulInputHandleRef::Quantized {
112+
MatmulInputHandle::Quantized {
113+
data,
114+
scale,
115+
shape,
116+
scheme,
117+
} => MatmulInputHandleRef::Quantized {
110118
data: data.as_ref(),
111119
scale: scale.as_ref(),
120+
shape,
121+
scheme,
112122
},
113123
}
114124
}
115125
}
116126

117-
impl<R: Runtime, E: Numeric> Clone for MatmulInputHandle<R, E> {
127+
impl<R: Runtime, E: CubePrimitive> Clone for MatmulInputHandle<R, E> {
118128
fn clone(&self) -> Self {
119129
match self {
120130
Self::Normal(handle) => Self::Normal(handle.clone()),
121-
Self::Quantized { data, scale } => Self::Quantized {
131+
Self::Quantized {
132+
data,
133+
scale,
134+
shape,
135+
scheme,
136+
} => Self::Quantized {
122137
data: data.clone(),
123138
scale: scale.clone(),
139+
shape: shape.clone(),
140+
scheme: *scheme,
124141
},
125142
}
126143
}
@@ -132,6 +149,9 @@ pub enum MatmulInputHandleRef<'a, R: Runtime> {
132149
Quantized {
133150
data: TensorHandleRef<'a, R>,
134151
scale: TensorHandleRef<'a, R>,
152+
/// Unpacked shape, excluding padding
153+
shape: &'a [usize],
154+
scheme: &'a QuantScheme,
135155
},
136156
}
137157

@@ -148,8 +168,18 @@ impl<'a, R: Runtime> MatmulInputHandleRef<'a, R> {
148168
Self::Normal(data)
149169
}
150170

151-
pub fn quantized(data: TensorHandleRef<'a, R>, scale: TensorHandleRef<'a, R>) -> Self {
152-
Self::Quantized { data, scale }
171+
pub fn quantized(
172+
data: TensorHandleRef<'a, R>,
173+
scale: TensorHandleRef<'a, R>,
174+
shape: &'a [usize],
175+
scheme: &'a QuantScheme,
176+
) -> Self {
177+
Self::Quantized {
178+
data,
179+
scale,
180+
shape,
181+
scheme,
182+
}
153183
}
154184

155185
pub fn data(&self) -> &TensorHandleRef<'a, R> {
@@ -172,6 +202,20 @@ impl<'a, R: Runtime> MatmulInputHandleRef<'a, R> {
172202
MatmulInputHandleRef::Quantized { scale, .. } => Some(scale),
173203
}
174204
}
205+
206+
pub fn scheme(&self) -> Option<&QuantScheme> {
207+
match self {
208+
MatmulInputHandleRef::Normal(_) => None,
209+
MatmulInputHandleRef::Quantized { scheme, .. } => Some(scheme),
210+
}
211+
}
212+
213+
pub fn shape(&self) -> &[usize] {
214+
match self {
215+
MatmulInputHandleRef::Normal(handle) => handle.shape,
216+
MatmulInputHandleRef::Quantized { shape, .. } => shape,
217+
}
218+
}
175219
}
176220

177221
#[allow(clippy::result_large_err)]
@@ -310,7 +354,7 @@ pub fn launch_ref<R: Runtime, MP: MatmulPrecision>(
310354
layered::launch_ref::<R, MP, DoubleUnitAlgorithm>(client, lhs, rhs, out, selection)
311355
}
312356
Strategy::Naive => {
313-
naive::launch_ref::<R, LhsG<MP>, AccG<MP>>(client, lhs.data(), rhs.data(), out)?;
357+
naive::launch_ref::<R, LhsG<MP>, AccG<MP>>(client, lhs, rhs, out)?;
314358
Ok(())
315359
}
316360
Strategy::Auto => {

crates/cubecl-matmul/src/components/global/args.rs

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ use crate::{
1515
global::{
1616
GlobalConfig,
1717
memory::{
18-
BatchedGlobalLayout, BatchedGlobalLayoutLaunch, SimpleTmaGlobalLayout,
19-
SimpleTmaGlobalLayoutLaunch,
18+
BatchedGlobalLayout, BatchedGlobalLayoutLaunch, BatchedGlobalScaleLayout,
19+
SimpleTmaGlobalLayout, SimpleTmaGlobalLayoutLaunch,
2020
},
2121
},
2222
},
@@ -128,19 +128,44 @@ impl<Lhs: Numeric, Rhs: Numeric, Acc: Numeric> ConcreteInputsFactory
128128
config: impl BatchConfig,
129129
) -> Self::RuntimeArg<'a, R> {
130130
let config = config.global_config();
131-
let view = |handle, ident, line_size| {
132-
let layout = BatchedGlobalLayoutLaunch::from_handle(
133-
client,
134-
handle,
135-
problem,
136-
config.global_memory_config(ident),
137-
);
138-
ViewArg::new::<BatchedGlobalLayout>(handle.as_array_arg(line_size), layout)
131+
let view = |handle: &'a MatmulInputHandleRef<'a, R>, ident, line_size| match handle {
132+
MatmulInputHandleRef::Normal(handle) => {
133+
let layout = BatchedGlobalLayoutLaunch::from_handle(
134+
client,
135+
handle,
136+
problem,
137+
config.global_memory_config(ident),
138+
);
139+
ViewArg::new::<BatchedGlobalLayout>(handle.as_array_arg(line_size), layout)
140+
}
141+
MatmulInputHandleRef::Quantized {
142+
data,
143+
scale,
144+
shape,
145+
scheme,
146+
} => {
147+
let (data_layout, scales_layout) = BatchedGlobalLayoutLaunch::from_quantized_handle(
148+
client,
149+
data,
150+
scale,
151+
shape,
152+
problem,
153+
config.global_memory_config(ident),
154+
**scheme,
155+
);
156+
let data_view =
157+
ViewArg::new::<BatchedGlobalLayout>(data.as_array_arg(line_size), data_layout);
158+
let scales_view = ViewArg::new::<BatchedGlobalScaleLayout>(
159+
scale.as_array_arg(line_size),
160+
scales_layout,
161+
);
162+
ViewArg::new_quantized(data_view, scales_view, **scheme)
163+
}
139164
};
140165

141166
TensorInputsLaunch::new(
142-
view(lhs.data(), MatmulIdent::Lhs, line_sizes.lhs),
143-
view(rhs.data(), MatmulIdent::Rhs, line_sizes.rhs),
167+
view(lhs, MatmulIdent::Lhs, line_sizes.lhs),
168+
view(rhs, MatmulIdent::Rhs, line_sizes.rhs),
144169
CubeOptionArgs::None,
145170
)
146171
}

crates/cubecl-matmul/src/components/global/memory/config.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::{fmt::Debug, hash::Hash};
22

33
use crate::components::MatrixLayout;
44

5-
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
5+
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, Default)]
66
pub struct GlobalMemoryConfig {
77
pub elements_in_tile_row: u32,
88
pub elements_in_tile_col: u32,

0 commit comments

Comments (0)