Skip to content

Commit 3b4d937

Browse files
authored
feat: Quantized view (#954)
1 parent a2ba01a commit 3b4d937

File tree

19 files changed

+816
-52
lines changed

19 files changed

+816
-52
lines changed

crates/cubecl-common/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ pub mod reader;
4848
/// Future utils with a compatible API for native, non-std and wasm environments.
4949
pub mod future;
5050

51+
/// Quantization primitives required outside of `cubecl-quant`
52+
pub mod quant;
53+
5154
/// Various utilities to create IDs.
5255
extern crate alloc;
5356

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
/// Types representing the quantization scheme
2+
pub mod scheme;

crates/cubecl-quant/src/scheme.rs renamed to crates/cubecl-common/src/quant/scheme.rs

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
use alloc::vec;
22
use alloc::vec::Vec;
33
use core::{default::Default, ops::Deref};
4-
use cubecl_common::{e4m3, e5m2};
54
use serde::{Deserialize, Serialize};
65

76
/// Describes a quantization scheme/configuration.
@@ -79,6 +78,12 @@ impl QuantScheme {
7978
pub fn num_quants(&self) -> usize {
8079
self.size_bits_stored() / self.value.size_bits()
8180
}
81+
82+
/// Returns the native packing factor for the values. When native packing > 1, the packed
83+
/// representation stores `num_quants` elements grouped into packs of `native_packing` size.
84+
pub fn native_packing(&self) -> usize {
85+
self.value.native_packing()
86+
}
8287
}
8388

8489
/// Level or granularity of quantization.
@@ -91,6 +96,7 @@ pub enum QuantLevel {
9196
}
9297

9398
impl QuantLevel {
99+
/// Converting constructor for [`QuantLevel::Block`]
94100
pub fn block(values: impl AsRef<[u8]>) -> Self {
95101
QuantLevel::Block(BlockSize::new(values))
96102
}
@@ -129,6 +135,15 @@ impl QuantValue {
129135
}
130136
}
131137

138+
/// Packing factor for the native representation used for intermediate values. If > 1, values
139+
/// should always be processed in `native_packing` sized chunks.
140+
pub fn native_packing(&self) -> usize {
141+
match self {
142+
QuantValue::E2M1 => 2,
143+
_ => 1,
144+
}
145+
}
146+
132147
/// The possible range of values allowed by the quant value.
133148
pub fn range(&self) -> (f32, f32) {
134149
match self {
@@ -138,8 +153,8 @@ impl QuantValue {
138153
QuantValue::Q8S => (-i8::MAX as f32, i8::MAX as f32),
139154
QuantValue::Q4S => (-7.0, 7.0),
140155
QuantValue::Q2S => (-1.0, 1.0),
141-
QuantValue::E4M3 => (e4m3::MIN as f32, e4m3::MAX as f32),
142-
QuantValue::E5M2 => (e5m2::MIN as f32, e5m2::MAX as f32),
156+
QuantValue::E4M3 => (-448.0, 448.0),
157+
QuantValue::E5M2 => (-57344.0, 57344.0),
143158
QuantValue::E2M1 => (-6.0, 6.0), // Hardcoded because of no-std
144159
}
145160
}
@@ -253,10 +268,12 @@ impl BlockSize {
253268
out
254269
}
255270

271+
/// Create an iterator over all stored dimensions
256272
pub fn iter(&self) -> impl Iterator<Item = &u8> {
257273
self.as_slice().iter()
258274
}
259275

276+
/// Returns the total number of elements in each block
260277
pub fn num_elements(&self) -> usize {
261278
self.iter().map(|it| *it as usize).product()
262279
}

crates/cubecl-cpu/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ mod tests {
1111
cubecl_core::testgen_all!(f32: [f16, f32, f64], i32: [i8, i16, i32, i64], u32: [u8, u16, u32, u64]);
1212
cubecl_std::testgen!();
1313
cubecl_std::testgen_tensor_identity!([f16, f32, u32]);
14+
cubecl_std::testgen_quantized_view!(f32);
1415
cubecl_random::testgen_random!();
1516
cubecl_matmul::testgen_matmul_simple!([f16, f32]);
1617
cubecl_matmul::testgen_matmul_unit!();

crates/cubecl-cuda/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ mod tests {
8787
// TODO: re-instate matmul quantized tests
8888
cubecl_matmul::testgen_matmul_simple!([f16, bf16, f32]);
8989
cubecl_std::testgen_tensor_identity!([f16, bf16, f32, u32]);
90+
cubecl_std::testgen_quantized_view!(f16);
9091
cubecl_convolution::testgen_conv2d_accelerated!([f16: f16, bf16: bf16, f32: tf32]);
9192
cubecl_reduce::testgen_reduce!([f16, bf16, f32, f64]);
9293
cubecl_random::testgen_random!();

crates/cubecl-macros/src/parse/expression.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -482,7 +482,7 @@ fn fn_associated_type(path: &Expression) -> Option<(Path, Option<QSelf>, PathSeg
482482
// All supported primitives. Primitives don't start with an uppercase letter
483483
const PRIMITIVES: &[&str] = &[
484484
"bool", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "f16", "bf16", "f32", "f64",
485-
"flex32", "e2m1", "e2m3", "e3m2", "e4m3", "e5m2", "ue8m0",
485+
"flex32", "e2m1", "e2m1x2", "e2m3", "e3m2", "e4m3", "e5m2", "ue8m0",
486486
];
487487
if !matches!(path, Expression::Path { .. }) {
488488
panic!("path: {path:?}");

crates/cubecl-quant/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ pub mod quantize;
1111
#[cfg(feature = "kernels")]
1212
pub mod layout;
1313

14-
pub mod scheme;
14+
pub use cubecl_common::quant::scheme;
1515

1616
#[cfg(feature = "export_tests")]
1717
pub mod tests;

crates/cubecl-std/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ export_tests = []
1919

2020
[dependencies]
2121

22+
cubecl-common = { path = "../cubecl-common", version = "0.7.0", default-features = false }
2223
cubecl-core = { path = "../cubecl-core", version = "0.7.0", default-features = false }
2324
cubecl-runtime = { path = "../cubecl-runtime", version = "0.7.0", default-features = false }
2425
half.workspace = true

crates/cubecl-std/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ pub use fast_math::*;
99
mod option;
1010
pub use option::*;
1111

12+
/// Quantization functionality required in views
13+
pub mod quant;
1214
pub mod tensor;
1315

1416
#[cfg(feature = "export_tests")]
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
use cubecl::prelude::*;
2+
use cubecl_common::quant::scheme::*;
3+
use cubecl_common::{e2m1x2, e4m3, e5m2};
4+
use cubecl_core::{self as cubecl, intrinsic};
5+
6+
/// Dequantize a line of values, where `line_size * num_quants` is a power of two.
7+
/// Unaligned values can't be dequantized in place.
8+
#[cube]
9+
pub fn dequantize_aligned<Q: CubePrimitive, S: CubePrimitive, F: Float>(
10+
value: Line<Q>,
11+
scale: S,
12+
#[comptime] scheme: QuantScheme,
13+
) -> Line<F> {
14+
let q_values = match scheme.store {
15+
QuantStore::Native => Line::<F>::cast_from(value),
16+
QuantStore::U32 => unpack_cast_u32::<F>(Line::cast_from(value), scheme),
17+
};
18+
let scale = Line::<F>::cast_from(scale);
19+
20+
match scheme.mode {
21+
QuantMode::Symmetric => q_values * scale,
22+
}
23+
}
24+
25+
/// Unpack a set of values from u32, and convert to the specified floating point format.
26+
#[cube]
27+
pub fn unpack_cast_u32<F: Float>(value: Line<u32>, #[comptime] scheme: QuantScheme) -> Line<F> {
28+
let num_quants = comptime![scheme.num_quants() as u32];
29+
let native_packing = comptime![scheme.native_packing() as u32];
30+
let out_line_size = comptime![value.line_size() * num_quants];
31+
let size_bits = comptime![scheme.size_bits_value() as u32];
32+
let mask = comptime![packing_mask(scheme)];
33+
34+
let mut out = Line::<F>::empty(out_line_size);
35+
36+
#[unroll]
37+
for line_idx in 0..value.line_size() {
38+
let line_idx = unwrap(line_idx);
39+
let packed_val = value[line_idx];
40+
let out_offset = comptime![line_idx * num_quants];
41+
#[unroll]
42+
for packed_idx in range_stepped(0, num_quants, native_packing) {
43+
let packed_idx = unwrap(packed_idx);
44+
let shift = packed_idx * size_bits;
45+
let value = (packed_val >> shift) & mask;
46+
47+
let float_value = cast_masked::<F>(value, scheme);
48+
49+
#[unroll]
50+
for native_idx in 0..native_packing {
51+
let native_idx = unwrap(native_idx);
52+
let out_offset = comptime![out_offset + packed_idx + native_idx];
53+
out[out_offset] = float_value[native_idx];
54+
}
55+
}
56+
}
57+
58+
out
59+
}
60+
61+
/// The mask required for each packed value, taking into account the native packing required for
62+
/// `e2m1`.
63+
fn packing_mask(scheme: QuantScheme) -> u32 {
64+
let bits = match scheme.value {
65+
QuantValue::E2M1 => 8, // Packed conversion
66+
other => other.size_bits(),
67+
};
68+
(1u32 << bits) - 1
69+
}
70+
71+
/// Cast a masked-out value in the low `n` bits of a `u32` to the specified float type.
72+
/// Applies sign conversion for integer quantization before casting to the float type,
73+
/// while minifloats are simply truncated to `u8`, reinterpreted and then cast.
74+
/// For `e2m1`, casting is done on the packed `e2m1x2` representation.
75+
///
76+
/// # Returns
77+
/// Two floating point numbers for `e2m1`, one for all other formats.
78+
#[cube]
79+
fn cast_masked<F: Float>(value: u32, #[comptime] scheme: QuantScheme) -> Line<F> {
80+
match scheme.value {
81+
// For minifloat we can assume if they're supported then u8 is supported
82+
QuantValue::E5M2 => Line::<F>::cast_from(e5m2::reinterpret(value as u8)),
83+
QuantValue::E4M3 => Line::<F>::cast_from(e4m3::reinterpret(value as u8)),
84+
QuantValue::E2M1 => Line::<F>::cast_from(e2m1x2::reinterpret(value as u8)),
85+
QuantValue::Q8F
86+
| QuantValue::Q4F
87+
| QuantValue::Q2F
88+
| QuantValue::Q8S
89+
| QuantValue::Q4S
90+
| QuantValue::Q2S => {
91+
let size_quant = comptime!(scheme.size_bits_value() as u32);
92+
let sign_bit = comptime!(1u32 << (size_quant - 1));
93+
let two_pow_n = comptime!(1 << size_quant);
94+
95+
// Branchless two's complement conversion
96+
// If raw >= 2^(n-1), then result = raw - 2^n
97+
let raw_i32 = value as i32;
98+
let is_negative = (value >= sign_bit) as i32; // 1 if negative, 0 if positive
99+
let signed_value = raw_i32 - (is_negative * two_pow_n);
100+
Line::<F>::cast_from(signed_value)
101+
}
102+
}
103+
}
104+
105+
/// Turns a `u32` whose value is a constant into a comptime `u32`.
///
/// Used by the unrolled loops in this file so that loop indices can take part in
/// `comptime!` arithmetic.
///
/// # Panics
/// Fails with "Must be constant" when `v.constant()` yields no value.
// NOTE(review): the `allow` is presumably needed because `v` is only referenced inside the
// `intrinsic!` closure — confirm against the `#[cube]` expansion.
#[allow(unused_variables)]
106+
#[cube]
107+
pub(crate) fn unwrap(v: u32) -> comptime_type!(u32) {
108+
intrinsic!(|_| v.constant().expect("Must be constant").as_u32())
109+
}

0 commit comments

Comments
 (0)