Skip to content

Commit dcea7e7

Browse files
authored
refactor: Refactor quantization to use layouts and multi-dimensional block sizes (#938)
1 parent 13ebbd5 commit dcea7e7

File tree

14 files changed

+446
-193
lines changed

14 files changed

+446
-193
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ serial_test = "3.1.1"
6060

6161
bytemuck = "1.16.1"
6262
float4 = "0.1"
63-
float8 = "0.4"
63+
float8 = { version = "0.4", default-features = false }
6464
half = { version = "2.5", features = [
6565
"alloc",
6666
"num-traits",

crates/cubecl-attention/src/components/spec.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use crate::components::{
66
spec::attention_types::*,
77
};
88

9-
/// Attention spec definiting each element types used in the computation as well as
9+
/// Attention spec defining each element types used in the computation as well as
1010
/// how the arguments are passed to the kernel.
1111
pub trait AttentionSpec: Send + Sync + Clone + 'static {
1212
type Precision: AttentionPrecision;

crates/cubecl-common/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ version.workspace = true
1717
cache = ["std", "serde_json", "dirs", "sanitize-filename"]
1818
default = ["std"]
1919
fp4 = ["float4"]
20-
fp8 = ["float8", "float4"]
20+
fp8 = ["float8"]
2121
serde = ["serde_bytes"]
2222
std = ["rand/std", "futures-lite", "rand/thread_rng", "serde_json?/std"]
2323

crates/cubecl-cpp/src/hip/dialect.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,7 @@ impl<M: DialectWmmaCompiler<Self>> DialectInstructions<Self> for HipDialect<M> {
413413
_rhs: impl Display,
414414
_item: Item<Self>,
415415
) -> std::fmt::Result {
416-
unimplemented!("No native instrution exists, Should be replaced in a preprocessor");
416+
unimplemented!("No native instruction exists, Should be replaced in a preprocessor");
417417
}
418418

419419
fn compile_saturating_sub(
@@ -422,7 +422,7 @@ impl<M: DialectWmmaCompiler<Self>> DialectInstructions<Self> for HipDialect<M> {
422422
_rhs: impl Display,
423423
_item: Item<Self>,
424424
) -> std::fmt::Result {
425-
unimplemented!("No native instrution exists, Should be replaced in a preprocessor");
425+
unimplemented!("No native instruction exists, Should be replaced in a preprocessor");
426426
}
427427

428428
// others

crates/cubecl-quant/Cargo.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,14 @@ version.workspace = true
1212

1313
[features]
1414
default = ["kernels"]
15-
std = ["cubecl-core?/std", "cubecl-runtime?/std"]
16-
kernels = ["std", "cubecl-core", "cubecl-runtime", "cubecl-std"]
1715
export_tests = []
16+
kernels = ["std", "cubecl-core", "cubecl-runtime", "cubecl-std"]
17+
std = ["cubecl-core?/std", "cubecl-runtime?/std"]
1818

1919
[dependencies]
20+
cubecl-common = { path = "../cubecl-common", version = "0.7.0", default-features = false, features = [
21+
"fp8",
22+
] }
2023
cubecl-core = { path = "../cubecl-core", version = "0.7.0", default-features = false, optional = true }
2124
cubecl-runtime = { path = "../cubecl-runtime", version = "0.7.0", default-features = false, optional = true }
2225
cubecl-std = { path = "../cubecl-std", version = "0.7.0", default-features = false, optional = true }

crates/cubecl-quant/src/dequantize.rs

Lines changed: 67 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
#![allow(missing_docs)] // pub cube modules
22

33
use cubecl::prelude::*;
4+
use cubecl_common::{e2m1x2, e4m3, e5m2, ue8m0};
45
use cubecl_core::{self as cubecl, calculate_cube_count_elemwise, tensor_line_size_parallel};
56
use cubecl_runtime::TypeUsage;
67

78
use crate::{
8-
qparams::QParams,
9+
layout::{ScalesView, scales_view},
910
scheme::{QuantLevel, QuantMode, QuantParam, QuantScheme, QuantStore, QuantValue},
1011
};
1112
use cubecl_std::tensor::{
@@ -16,7 +17,7 @@ use half::{bf16, f16};
1617

1718
/// Dequantize a line of values into floating-point values using the provided scale.
1819
#[cube]
19-
pub fn dequantize_symmetric<F: Float, FS: Float>(value: Line<F>, scale: FS) -> Line<F> {
20+
pub fn dequantize_symmetric<F: Float, FS: CubePrimitive>(value: Line<F>, scale: FS) -> Line<F> {
2021
// x = scale * x_q
2122
Line::cast_from(scale) * value
2223
}
@@ -26,11 +27,11 @@ pub fn dequantize_symmetric<F: Float, FS: Float>(value: Line<F>, scale: FS) -> L
2627
/// Returns a line of floating-point values. The number of values in the line depends on the number of packed
2728
/// values in the stored quantization type.
2829
#[cube]
29-
pub fn dequantize_symmetric_packed_values<F: Float, FS: Float, QI: Int>(
30+
pub fn dequantize_symmetric_packed_values<F: Float, FS: CubePrimitive, QI: Int>(
3031
position: u32,
3132
values: &View<Line<QI>, u32>,
32-
scales: &View<Line<FS>, u32>,
33-
#[comptime] scheme: QuantScheme,
33+
scales: &View<FS, u32>,
34+
#[comptime] scheme: &QuantScheme,
3435
) -> Array<Line<F>> {
3536
dequantize_symmetric_packed_value_at::<F, FS, QI>(position, values[position], scales, scheme)
3637
}
@@ -40,36 +41,34 @@ pub fn dequantize_symmetric_packed_values<F: Float, FS: Float, QI: Int>(
4041
/// Returns a line of floating-point values. The number of values in the line depends on the number of packed
4142
/// values in the stored quantization type.
4243
#[cube]
43-
pub fn dequantize_symmetric_packed_value_at<F: Float, FS: Float, QI: Int>(
44+
pub fn dequantize_symmetric_packed_value_at<F: Float, FS: CubePrimitive, QI: Int>(
4445
position: u32,
4546
values: Line<QI>,
46-
scales: &View<Line<FS>, u32>,
47-
#[comptime] scheme: QuantScheme,
47+
scales: &View<FS, u32>,
48+
#[comptime] scheme: &QuantScheme,
4849
) -> Array<Line<F>> {
49-
let qparams = QParams::new(scheme);
50-
dequantize_symmetric_packed_value::<F, FS, QI>(values, scales, qparams, position, scheme)
50+
dequantize_symmetric_packed_value::<F, FS, QI>(values, scales, position, scheme)
5151
}
5252

5353
/// Dequantize a single packed value using the scale provided.
5454
///
5555
/// Returns a line of floating-point values. The number of values in the line depends on the number of packed
5656
/// values in the stored quantization type.
5757
#[cube]
58-
pub fn dequantize_symmetric_packed_value<F: Float, FS: Float, QS: Int>(
58+
pub fn dequantize_symmetric_packed_value<F: Float, FS: CubePrimitive, QS: Int>(
5959
values: Line<QS>,
60-
scales: &View<Line<FS>, u32>,
61-
qparams: QParams,
60+
scales: &View<FS, u32>,
6261
position: u32,
63-
#[comptime] scheme: QuantScheme,
62+
#[comptime] scheme: &QuantScheme,
6463
) -> Array<Line<F>> {
6564
let line_size_values = values.line_size();
66-
let num_quants = comptime!(qparams.num_quants);
65+
let num_quants = comptime!(scheme.num_quants() as u32);
6766
let mut tmp = Array::vectorized(line_size_values, num_quants);
6867

6968
#[unroll]
7069
for i in 0..line_size_values {
7170
let floats = unpack_q::<F, QS>(values[i], scheme.value, scheme.store);
72-
let scale = qparams.scale(scales, (position * line_size_values) + i);
71+
let scale = scales[(position * line_size_values) + i * num_quants];
7372
let values = dequantize_symmetric::<F, FS>(floats, scale);
7473
tmp[i] = values;
7574
}
@@ -117,33 +116,27 @@ fn unpack_q<F: Float, QS: Int>(
117116
}
118117

119118
#[cube(launch_unchecked)]
120-
fn dequantize_symmetric_packed_kernel<F: Float, FS: Float>(
119+
fn dequantize_symmetric_packed_kernel<F: Float, FS: CubePrimitive>(
121120
input: &LinearView<Line<u32>>,
122-
scales: &LinearView<Line<FS>>,
121+
scales: &ScalesView<FS>,
123122
output: &mut LinearView<Line<F>, ReadWrite>,
124-
#[comptime] scheme: QuantScheme,
123+
#[comptime] scheme: &QuantScheme,
125124
) {
126125
if !input.is_in_bounds(ABSOLUTE_POS) {
127126
terminate!();
128127
}
129128

130-
let qparams = QParams::new(scheme);
131129
let line_size_in = input.line_size();
132130
let line_size_out = output.line_size();
133131

134132
comptime! {
135-
assert_eq!(line_size_out, qparams.num_quants);
133+
assert_eq!(line_size_out, scheme.num_quants() as u32);
136134
}
137135

138136
let values = input[ABSOLUTE_POS];
137+
let packed_pos = ABSOLUTE_POS * comptime![scheme.num_quants() as u32];
139138

140-
let out = dequantize_symmetric_packed_value::<F, FS, u32>(
141-
values,
142-
scales,
143-
qparams,
144-
ABSOLUTE_POS,
145-
scheme,
146-
);
139+
let out = dequantize_symmetric_packed_value::<F, FS, u32>(values, scales, packed_pos, scheme);
147140

148141
#[unroll]
149142
for i in 0..line_size_in {
@@ -152,19 +145,18 @@ fn dequantize_symmetric_packed_kernel<F: Float, FS: Float>(
152145
}
153146

154147
#[cube(launch_unchecked)]
155-
fn dequantize_symmetric_int8_native_kernel<F: Float, FS: Float>(
156-
input: &LinearView<Line<i8>>,
157-
scale: &LinearView<Line<FS>>,
148+
fn dequantize_symmetric_native_kernel<F: Float, FS: CubePrimitive, Q: CubePrimitive>(
149+
input: &LinearView<Line<Q>>,
150+
scale: &ScalesView<FS>,
158151
output: &mut LinearView<Line<F>, ReadWrite>,
159-
#[comptime] scheme: QuantScheme,
160152
) {
161153
if !input.is_in_bounds(ABSOLUTE_POS) {
162154
terminate!();
163155
}
164156

165-
let qparams = QParams::new(scheme);
157+
let native_packing = Q::packing_factor();
166158
// Absolute pos represents the logical block (scale) used to dequantize, not layout
167-
let scale = qparams.scale(scale, ABSOLUTE_POS * input.line_size());
159+
let scale = scale[ABSOLUTE_POS * input.line_size() * native_packing];
168160

169161
output[ABSOLUTE_POS] =
170162
dequantize_symmetric::<F, FS>(Line::cast_from(input[ABSOLUTE_POS]), scale);
@@ -193,9 +185,20 @@ pub fn launch_ref<R: Runtime, F: Float>(
193185
QuantParam::BF16 => {
194186
dequantize_packed::<R, F, bf16>(client, values, scheme, params, output)
195187
}
188+
QuantParam::UE8M0 => {
189+
dequantize_packed::<R, F, ue8m0>(client, values, scheme, params, output)
190+
}
191+
QuantParam::UE4M3 => {
192+
dequantize_packed::<R, F, e4m3>(client, values, scheme, params, output)
193+
}
196194
},
197195
QuantScheme {
198-
value: QuantValue::Q8F | QuantValue::Q8S,
196+
value:
197+
QuantValue::Q8F
198+
| QuantValue::Q8S
199+
| QuantValue::E4M3
200+
| QuantValue::E5M2
201+
| QuantValue::E2M1,
199202
store: QuantStore::Native,
200203
..
201204
} => {
@@ -216,6 +219,12 @@ pub fn launch_ref<R: Runtime, F: Float>(
216219
QuantParam::BF16 => {
217220
dequantize_native::<R, F, bf16>(client, values, scheme, params, output)
218221
}
222+
QuantParam::UE8M0 => {
223+
dequantize_native::<R, F, ue8m0>(client, values, scheme, params, output)
224+
}
225+
QuantParam::UE4M3 => {
226+
dequantize_native::<R, F, e4m3>(client, values, scheme, params, output)
227+
}
219228
}
220229
}
221230
QuantScheme {
@@ -228,7 +237,7 @@ pub fn launch_ref<R: Runtime, F: Float>(
228237
}
229238
}
230239

231-
fn dequantize_packed<R: Runtime, F: Float, FS: Float>(
240+
fn dequantize_packed<R: Runtime, F: Float, FS: CubePrimitive>(
232241
client: &ComputeClient<R::Server, R::Channel>,
233242
input: &TensorHandleRef<R>,
234243
scheme: &QuantScheme,
@@ -268,17 +277,17 @@ fn dequantize_packed<R: Runtime, F: Float, FS: Float>(
268277
cube_count,
269278
cube_dim,
270279
linear_view(client, input, &line_size_in),
271-
linear_view(client, scale, &1),
280+
scales_view(client, input, scale, &1, scheme),
272281
linear_view(client, output, &line_size_out),
273-
*scheme,
282+
scheme.clone(),
274283
)
275284
};
276285
}
277286
QuantScheme { .. } => panic!("Unsupported quantization scheme {scheme:?}"),
278287
}
279288
}
280289

281-
fn dequantize_native<R: Runtime, F: Float, FS: Float>(
290+
fn dequantize_native<R: Runtime, F: Float, FS: CubePrimitive>(
282291
client: &ComputeClient<R::Server, R::Channel>,
283292
input: &TensorHandleRef<R>,
284293
scheme: &QuantScheme,
@@ -299,19 +308,34 @@ fn dequantize_native<R: Runtime, F: Float, FS: Float>(
299308
QuantScheme {
300309
level: QuantLevel::Tensor | QuantLevel::Block(_),
301310
mode: QuantMode::Symmetric,
302-
value: QuantValue::Q8F | QuantValue::Q8S,
311+
value,
303312
store: QuantStore::Native,
304313
..
305314
} => {
315+
let launch = match value {
316+
QuantValue::Q8F | QuantValue::Q8S => {
317+
dequantize_symmetric_native_kernel::launch_unchecked::<F, FS, i8, R>
318+
}
319+
QuantValue::E4M3 => {
320+
dequantize_symmetric_native_kernel::launch_unchecked::<F, FS, e4m3, R>
321+
}
322+
QuantValue::E5M2 => {
323+
dequantize_symmetric_native_kernel::launch_unchecked::<F, FS, e5m2, R>
324+
}
325+
QuantValue::E2M1 => {
326+
dequantize_symmetric_native_kernel::launch_unchecked::<F, FS, e2m1x2, R>
327+
}
328+
other => panic!("Unsupported quantization value {other:?}"),
329+
};
330+
306331
unsafe {
307-
dequantize_symmetric_int8_native_kernel::launch_unchecked::<F, FS, R>(
332+
launch(
308333
client,
309334
cube_count,
310335
cube_dim,
311336
linear_view(client, input, &line_size),
312-
linear_view(client, scale, &1),
337+
scales_view(client, input, scale, &1, scheme),
313338
linear_view(client, output, &line_size),
314-
*scheme,
315339
)
316340
};
317341
}
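crates/cubecl-quant/src/layout/mod.rs (file path missing from this capture; inferred from `use crate::layout::…` and `mod scales;` — confirm against the commit)

Lines changed: 3 additions & 0 deletions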
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
mod scales;
2+
3+
pub use scales::*;

0 commit comments

Comments
 (0)