1- use std:: marker:: PhantomData ;
1+ use std:: { collections :: BTreeMap , marker:: PhantomData } ;
22
3- use crate :: MetadataBuilder ;
3+ use crate :: KernelSettings ;
44use crate :: Runtime ;
55use crate :: compute:: KernelTask ;
66use crate :: prelude:: { ArrayArg , TensorArg , TensorMapArg } ;
7- use crate :: { KernelSettings , prelude :: CubePrimitive } ;
8- use bytemuck :: { AnyBitPattern , NoUninit } ;
7+ use crate :: { CubeScalar , MetadataBuilder } ;
8+ use cubecl_ir :: StorageType ;
99use cubecl_runtime:: server:: { Binding , CubeCount , ScalarBinding , TensorMapBinding } ;
1010use cubecl_runtime:: { client:: ComputeClient , server:: Bindings } ;
1111
@@ -14,18 +14,7 @@ use super::CubeKernel;
1414/// Prepare a kernel for [launch](KernelLauncher::launch).
1515pub struct KernelLauncher < R : Runtime > {
1616 tensors : TensorState < R > ,
17- scalar_bf16 : ScalarState < half:: bf16 > ,
18- scalar_f16 : ScalarState < half:: f16 > ,
19- scalar_f32 : ScalarState < f32 > ,
20- scalar_f64 : ScalarState < f64 > ,
21- scalar_u64 : ScalarState < u64 > ,
22- scalar_u32 : ScalarState < u32 > ,
23- scalar_u16 : ScalarState < u16 > ,
24- scalar_u8 : ScalarState < u8 > ,
25- scalar_i64 : ScalarState < i64 > ,
26- scalar_i32 : ScalarState < i32 > ,
27- scalar_i16 : ScalarState < i16 > ,
28- scalar_i8 : ScalarState < i8 > ,
17+ scalars : ScalarState ,
2918 pub settings : KernelSettings ,
3019 runtime : PhantomData < R > ,
3120}
@@ -46,64 +35,14 @@ impl<R: Runtime> KernelLauncher<R> {
4635 self . tensors . push_array ( array) ;
4736 }
4837
49- /// Register a u8 scalar to be launched.
50- pub fn register_u8 ( & mut self , scalar : u8 ) {
51- self . scalar_u8 . push ( scalar) ;
38+ /// Register a scalar to be launched.
39+ pub fn register_scalar < C : CubeScalar > ( & mut self , scalar : C ) {
40+ self . scalars . push ( scalar) ;
5241 }
5342
54- /// Register a u16 scalar to be launched.
55- pub fn register_u16 ( & mut self , scalar : u16 ) {
56- self . scalar_u16 . push ( scalar) ;
57- }
58-
59- /// Register a u32 scalar to be launched.
60- pub fn register_u32 ( & mut self , scalar : u32 ) {
61- self . scalar_u32 . push ( scalar) ;
62- }
63-
64- /// Register a u64 scalar to be launched.
65- pub fn register_u64 ( & mut self , scalar : u64 ) {
66- self . scalar_u64 . push ( scalar) ;
67- }
68-
69- /// Register a i8 scalar to be launched.
70- pub fn register_i8 ( & mut self , scalar : i8 ) {
71- self . scalar_i8 . push ( scalar) ;
72- }
73-
74- /// Register a i16 scalar to be launched.
75- pub fn register_i16 ( & mut self , scalar : i16 ) {
76- self . scalar_i16 . push ( scalar) ;
77- }
78-
79- /// Register a i32 scalar to be launched.
80- pub fn register_i32 ( & mut self , scalar : i32 ) {
81- self . scalar_i32 . push ( scalar) ;
82- }
83-
84- /// Register a i64 scalar to be launched.
85- pub fn register_i64 ( & mut self , scalar : i64 ) {
86- self . scalar_i64 . push ( scalar) ;
87- }
88-
89- /// Register a bf16 scalar to be launched.
90- pub fn register_bf16 ( & mut self , scalar : half:: bf16 ) {
91- self . scalar_bf16 . push ( scalar) ;
92- }
93-
94- /// Register a f16 scalar to be launched.
95- pub fn register_f16 ( & mut self , scalar : half:: f16 ) {
96- self . scalar_f16 . push ( scalar) ;
97- }
98-
99- /// Register a f32 scalar to be launched.
100- pub fn register_f32 ( & mut self , scalar : f32 ) {
101- self . scalar_f32 . push ( scalar) ;
102- }
103-
104- /// Register a f64 scalar to be launched.
105- pub fn register_f64 ( & mut self , scalar : f64 ) {
106- self . scalar_f64 . push ( scalar) ;
43+ /// Register a scalar to be launched from raw data.
44+ pub fn register_scalar_raw ( & mut self , bytes : & [ u8 ] , dtype : StorageType ) {
45+ self . scalars . push_raw ( bytes, dtype) ;
10746 }
10847
10948 /// Launch the kernel.
@@ -156,19 +95,7 @@ impl<R: Runtime> KernelLauncher<R> {
15695 let mut bindings = Bindings :: new ( ) ;
15796
15897 self . tensors . register ( & mut bindings) ;
159-
160- self . scalar_u8 . register ( & mut bindings) ;
161- self . scalar_u16 . register ( & mut bindings) ;
162- self . scalar_u32 . register ( & mut bindings) ;
163- self . scalar_u64 . register ( & mut bindings) ;
164- self . scalar_i8 . register ( & mut bindings) ;
165- self . scalar_i16 . register ( & mut bindings) ;
166- self . scalar_i32 . register ( & mut bindings) ;
167- self . scalar_i64 . register ( & mut bindings) ;
168- self . scalar_f16 . register ( & mut bindings) ;
169- self . scalar_bf16 . register ( & mut bindings) ;
170- self . scalar_f32 . register ( & mut bindings) ;
171- self . scalar_f64 . register ( & mut bindings) ;
98+ self . scalars . register ( & mut bindings) ;
17299
173100 bindings
174101 }
@@ -190,13 +117,14 @@ pub enum TensorState<R: Runtime> {
190117/// Handles the scalar state of an element type
191118///
192119/// The scalars are grouped to reduce the number of buffers needed to send data to the compute device.
193- pub enum ScalarState < T > {
194- /// No scalar of that type is registered yet.
195- Empty ,
196- /// The registered scalars.
197- Some ( Vec < T > ) ,
120+ #[ derive( Default , Clone ) ]
121+ pub struct ScalarState {
122+ data : BTreeMap < StorageType , ScalarValues > ,
198123}
199124
125+ /// Stores the data and type for a scalar arg
126+ pub type ScalarValues = Vec < u8 > ;
127+
200128impl < R : Runtime > TensorState < R > {
201129 fn maybe_init ( & mut self ) {
202130 if matches ! ( self , TensorState :: Empty ) {
@@ -316,26 +244,36 @@ impl<R: Runtime> TensorState<R> {
316244 }
317245}
318246
319- impl < T : NoUninit + AnyBitPattern + CubePrimitive > ScalarState < T > {
247+ impl ScalarState {
320248 /// Add a new scalar value to the state.
321- pub fn push ( & mut self , val : T ) {
322- match self {
323- ScalarState :: Empty => * self = Self :: Some ( vec ! [ val] ) ,
324- ScalarState :: Some ( values) => values. push ( val) ,
325- }
249+ pub fn push < T : CubeScalar > ( & mut self , val : T ) {
250+ let val = [ val] ;
251+ let bytes = T :: as_bytes ( & val) ;
252+ self . data
253+ . entry ( T :: cube_type ( ) )
254+ . or_default ( )
255+ . extend ( bytes. iter ( ) . copied ( ) ) ;
256+ }
257+
258+ /// Add a new raw value to the state.
259+ pub fn push_raw ( & mut self , bytes : & [ u8 ] , dtype : StorageType ) {
260+ self . data
261+ . entry ( dtype)
262+ . or_default ( )
263+ . extend ( bytes. iter ( ) . copied ( ) ) ;
326264 }
327265
328266 fn register ( & self , bindings : & mut Bindings ) {
329- if let ScalarState :: Some ( values) = self {
330- let len = values. len ( ) ;
331- let len_u64 = len. div_ceil ( size_of :: < u64 > ( ) / size_of :: < T > ( ) ) ;
267+ for ( ty, values) in self . data . iter ( ) {
268+ let len = values. len ( ) / ty. size ( ) ;
269+ let len_u64 = len. div_ceil ( size_of :: < u64 > ( ) / ty. size ( ) ) ;
270+
332271 let mut data = vec ! [ 0 ; len_u64] ;
333- let slice = bytemuck:: cast_slice_mut :: < u64 , T > ( & mut data) ;
272+ let slice = bytemuck:: cast_slice_mut :: < u64 , u8 > ( & mut data) ;
334273 slice[ 0 ..values. len ( ) ] . copy_from_slice ( values) ;
335- let elem = T :: as_type_native_unchecked ( ) ;
336274 bindings
337275 . scalars
338- . insert ( elem , ScalarBinding :: new ( elem , len, data) ) ;
276+ . insert ( * ty , ScalarBinding :: new ( * ty , len, data) ) ;
339277 }
340278 }
341279}
@@ -344,18 +282,7 @@ impl<R: Runtime> Default for KernelLauncher<R> {
344282 fn default ( ) -> Self {
345283 Self {
346284 tensors : TensorState :: Empty ,
347- scalar_bf16 : ScalarState :: Empty ,
348- scalar_f16 : ScalarState :: Empty ,
349- scalar_f32 : ScalarState :: Empty ,
350- scalar_f64 : ScalarState :: Empty ,
351- scalar_u64 : ScalarState :: Empty ,
352- scalar_u32 : ScalarState :: Empty ,
353- scalar_u16 : ScalarState :: Empty ,
354- scalar_u8 : ScalarState :: Empty ,
355- scalar_i64 : ScalarState :: Empty ,
356- scalar_i32 : ScalarState :: Empty ,
357- scalar_i16 : ScalarState :: Empty ,
358- scalar_i8 : ScalarState :: Empty ,
285+ scalars : Default :: default ( ) ,
359286 settings : Default :: default ( ) ,
360287 runtime : PhantomData ,
361288 }
0 commit comments