11#![ allow( missing_docs) ] // pub cube modules
22
33use cubecl:: prelude:: * ;
4+ use cubecl_common:: { e2m1x2, e4m3, e5m2, ue8m0} ;
45use cubecl_core:: { self as cubecl, calculate_cube_count_elemwise, tensor_line_size_parallel} ;
56use cubecl_runtime:: TypeUsage ;
67
78use crate :: {
8- qparams :: QParams ,
9+ layout :: { ScalesView , scales_view } ,
910 scheme:: { QuantLevel , QuantMode , QuantParam , QuantScheme , QuantStore , QuantValue } ,
1011} ;
1112use cubecl_std:: tensor:: {
@@ -16,7 +17,7 @@ use half::{bf16, f16};
1617
1718/// Dequantize a line of values into floating-point values using the provided scale.
1819#[ cube]
19- pub fn dequantize_symmetric < F : Float , FS : Float > ( value : Line < F > , scale : FS ) -> Line < F > {
20+ pub fn dequantize_symmetric < F : Float , FS : CubePrimitive > ( value : Line < F > , scale : FS ) -> Line < F > {
2021 // x = scale * x_q
2122 Line :: cast_from ( scale) * value
2223}
@@ -26,11 +27,11 @@ pub fn dequantize_symmetric<F: Float, FS: Float>(value: Line<F>, scale: FS) -> L
2627/// Returns a line of floating-point values. The number of values in the line depends on the number of packed
2728/// values in the stored quantization type.
2829#[ cube]
29- pub fn dequantize_symmetric_packed_values < F : Float , FS : Float , QI : Int > (
30+ pub fn dequantize_symmetric_packed_values < F : Float , FS : CubePrimitive , QI : Int > (
3031 position : u32 ,
3132 values : & View < Line < QI > , u32 > ,
32- scales : & View < Line < FS > , u32 > ,
33- #[ comptime] scheme : QuantScheme ,
33+ scales : & View < FS , u32 > ,
34+ #[ comptime] scheme : & QuantScheme ,
3435) -> Array < Line < F > > {
3536 dequantize_symmetric_packed_value_at :: < F , FS , QI > ( position, values[ position] , scales, scheme)
3637}
@@ -40,36 +41,34 @@ pub fn dequantize_symmetric_packed_values<F: Float, FS: Float, QI: Int>(
4041/// Returns a line of floating-point values. The number of values in the line depends on the number of packed
4142/// values in the stored quantization type.
4243#[ cube]
43- pub fn dequantize_symmetric_packed_value_at < F : Float , FS : Float , QI : Int > (
44+ pub fn dequantize_symmetric_packed_value_at < F : Float , FS : CubePrimitive , QI : Int > (
4445 position : u32 ,
4546 values : Line < QI > ,
46- scales : & View < Line < FS > , u32 > ,
47- #[ comptime] scheme : QuantScheme ,
47+ scales : & View < FS , u32 > ,
48+ #[ comptime] scheme : & QuantScheme ,
4849) -> Array < Line < F > > {
49- let qparams = QParams :: new ( scheme) ;
50- dequantize_symmetric_packed_value :: < F , FS , QI > ( values, scales, qparams, position, scheme)
50+ dequantize_symmetric_packed_value :: < F , FS , QI > ( values, scales, position, scheme)
5151}
5252
5353/// Dequantize a single packed value using the scale provided.
5454///
5555/// Returns a line of floating-point values. The number of values in the line depends on the number of packed
5656/// values in the stored quantization type.
5757#[ cube]
58- pub fn dequantize_symmetric_packed_value < F : Float , FS : Float , QS : Int > (
58+ pub fn dequantize_symmetric_packed_value < F : Float , FS : CubePrimitive , QS : Int > (
5959 values : Line < QS > ,
60- scales : & View < Line < FS > , u32 > ,
61- qparams : QParams ,
60+ scales : & View < FS , u32 > ,
6261 position : u32 ,
63- #[ comptime] scheme : QuantScheme ,
62+ #[ comptime] scheme : & QuantScheme ,
6463) -> Array < Line < F > > {
6564 let line_size_values = values. line_size ( ) ;
66- let num_quants = comptime ! ( qparams . num_quants) ;
65+ let num_quants = comptime ! ( scheme . num_quants( ) as u32 ) ;
6766 let mut tmp = Array :: vectorized ( line_size_values, num_quants) ;
6867
6968 #[ unroll]
7069 for i in 0 ..line_size_values {
7170 let floats = unpack_q :: < F , QS > ( values[ i] , scheme. value , scheme. store ) ;
72- let scale = qparams . scale ( scales, ( position * line_size_values) + i) ;
71+ let scale = scales[ ( position * line_size_values) + i * num_quants ] ;
7372 let values = dequantize_symmetric :: < F , FS > ( floats, scale) ;
7473 tmp[ i] = values;
7574 }
@@ -117,33 +116,27 @@ fn unpack_q<F: Float, QS: Int>(
117116}
118117
119118#[ cube( launch_unchecked) ]
120- fn dequantize_symmetric_packed_kernel < F : Float , FS : Float > (
119+ fn dequantize_symmetric_packed_kernel < F : Float , FS : CubePrimitive > (
121120 input : & LinearView < Line < u32 > > ,
122- scales : & LinearView < Line < FS > > ,
121+ scales : & ScalesView < FS > ,
123122 output : & mut LinearView < Line < F > , ReadWrite > ,
124- #[ comptime] scheme : QuantScheme ,
123+ #[ comptime] scheme : & QuantScheme ,
125124) {
126125 if !input. is_in_bounds ( ABSOLUTE_POS ) {
127126 terminate ! ( ) ;
128127 }
129128
130- let qparams = QParams :: new ( scheme) ;
131129 let line_size_in = input. line_size ( ) ;
132130 let line_size_out = output. line_size ( ) ;
133131
134132 comptime ! {
135- assert_eq!( line_size_out, qparams . num_quants) ;
133+ assert_eq!( line_size_out, scheme . num_quants( ) as u32 ) ;
136134 }
137135
138136 let values = input[ ABSOLUTE_POS ] ;
137+ let packed_pos = ABSOLUTE_POS * comptime ! [ scheme. num_quants( ) as u32 ] ;
139138
140- let out = dequantize_symmetric_packed_value :: < F , FS , u32 > (
141- values,
142- scales,
143- qparams,
144- ABSOLUTE_POS ,
145- scheme,
146- ) ;
139+ let out = dequantize_symmetric_packed_value :: < F , FS , u32 > ( values, scales, packed_pos, scheme) ;
147140
148141 #[ unroll]
149142 for i in 0 ..line_size_in {
@@ -152,19 +145,18 @@ fn dequantize_symmetric_packed_kernel<F: Float, FS: Float>(
152145}
153146
154147#[ cube( launch_unchecked) ]
155- fn dequantize_symmetric_int8_native_kernel < F : Float , FS : Float > (
156- input : & LinearView < Line < i8 > > ,
157- scale : & LinearView < Line < FS > > ,
148+ fn dequantize_symmetric_native_kernel < F : Float , FS : CubePrimitive , Q : CubePrimitive > (
149+ input : & LinearView < Line < Q > > ,
150+ scale : & ScalesView < FS > ,
158151 output : & mut LinearView < Line < F > , ReadWrite > ,
159- #[ comptime] scheme : QuantScheme ,
160152) {
161153 if !input. is_in_bounds ( ABSOLUTE_POS ) {
162154 terminate ! ( ) ;
163155 }
164156
165- let qparams = QParams :: new ( scheme ) ;
157+ let native_packing = Q :: packing_factor ( ) ;
166158 // Absolute pos represents the logical block (scale) used to dequantize, not layout
167- let scale = qparams . scale ( scale , ABSOLUTE_POS * input. line_size ( ) ) ;
159+ let scale = scale[ ABSOLUTE_POS * input. line_size ( ) * native_packing ] ;
168160
169161 output[ ABSOLUTE_POS ] =
170162 dequantize_symmetric :: < F , FS > ( Line :: cast_from ( input[ ABSOLUTE_POS ] ) , scale) ;
@@ -193,9 +185,20 @@ pub fn launch_ref<R: Runtime, F: Float>(
193185 QuantParam :: BF16 => {
194186 dequantize_packed :: < R , F , bf16 > ( client, values, scheme, params, output)
195187 }
188+ QuantParam :: UE8M0 => {
189+ dequantize_packed :: < R , F , ue8m0 > ( client, values, scheme, params, output)
190+ }
191+ QuantParam :: UE4M3 => {
192+ dequantize_packed :: < R , F , e4m3 > ( client, values, scheme, params, output)
193+ }
196194 } ,
197195 QuantScheme {
198- value : QuantValue :: Q8F | QuantValue :: Q8S ,
196+ value :
197+ QuantValue :: Q8F
198+ | QuantValue :: Q8S
199+ | QuantValue :: E4M3
200+ | QuantValue :: E5M2
201+ | QuantValue :: E2M1 ,
199202 store : QuantStore :: Native ,
200203 ..
201204 } => {
@@ -216,6 +219,12 @@ pub fn launch_ref<R: Runtime, F: Float>(
216219 QuantParam :: BF16 => {
217220 dequantize_native :: < R , F , bf16 > ( client, values, scheme, params, output)
218221 }
222+ QuantParam :: UE8M0 => {
223+ dequantize_native :: < R , F , ue8m0 > ( client, values, scheme, params, output)
224+ }
225+ QuantParam :: UE4M3 => {
226+ dequantize_native :: < R , F , e4m3 > ( client, values, scheme, params, output)
227+ }
219228 }
220229 }
221230 QuantScheme {
@@ -228,7 +237,7 @@ pub fn launch_ref<R: Runtime, F: Float>(
228237 }
229238}
230239
231- fn dequantize_packed < R : Runtime , F : Float , FS : Float > (
240+ fn dequantize_packed < R : Runtime , F : Float , FS : CubePrimitive > (
232241 client : & ComputeClient < R :: Server , R :: Channel > ,
233242 input : & TensorHandleRef < R > ,
234243 scheme : & QuantScheme ,
@@ -268,17 +277,17 @@ fn dequantize_packed<R: Runtime, F: Float, FS: Float>(
268277 cube_count,
269278 cube_dim,
270279 linear_view ( client, input, & line_size_in) ,
271- linear_view ( client, scale, & 1 ) ,
280+ scales_view ( client, input , scale, & 1 , scheme ) ,
272281 linear_view ( client, output, & line_size_out) ,
273- * scheme,
282+ scheme. clone ( ) ,
274283 )
275284 } ;
276285 }
277286 QuantScheme { .. } => panic ! ( "Unsupported quantization scheme {scheme:?}" ) ,
278287 }
279288}
280289
281- fn dequantize_native < R : Runtime , F : Float , FS : Float > (
290+ fn dequantize_native < R : Runtime , F : Float , FS : CubePrimitive > (
282291 client : & ComputeClient < R :: Server , R :: Channel > ,
283292 input : & TensorHandleRef < R > ,
284293 scheme : & QuantScheme ,
@@ -299,19 +308,34 @@ fn dequantize_native<R: Runtime, F: Float, FS: Float>(
299308 QuantScheme {
300309 level : QuantLevel :: Tensor | QuantLevel :: Block ( _) ,
301310 mode : QuantMode :: Symmetric ,
302- value : QuantValue :: Q8F | QuantValue :: Q8S ,
311+ value,
303312 store : QuantStore :: Native ,
304313 ..
305314 } => {
315+ let launch = match value {
316+ QuantValue :: Q8F | QuantValue :: Q8S => {
317+ dequantize_symmetric_native_kernel:: launch_unchecked :: < F , FS , i8 , R >
318+ }
319+ QuantValue :: E4M3 => {
320+ dequantize_symmetric_native_kernel:: launch_unchecked :: < F , FS , e4m3 , R >
321+ }
322+ QuantValue :: E5M2 => {
323+ dequantize_symmetric_native_kernel:: launch_unchecked :: < F , FS , e5m2 , R >
324+ }
325+ QuantValue :: E2M1 => {
326+ dequantize_symmetric_native_kernel:: launch_unchecked :: < F , FS , e2m1x2 , R >
327+ }
328+ other => panic ! ( "Unsupported quantization value {other:?}" ) ,
329+ } ;
330+
306331 unsafe {
307- dequantize_symmetric_int8_native_kernel :: launch_unchecked :: < F , FS , R > (
332+ launch (
308333 client,
309334 cube_count,
310335 cube_dim,
311336 linear_view ( client, input, & line_size) ,
312- linear_view ( client, scale, & 1 ) ,
337+ scales_view ( client, input , scale, & 1 , scheme ) ,
313338 linear_view ( client, output, & line_size) ,
314- * scheme,
315339 )
316340 } ;
317341 }
0 commit comments