Skip to content

Commit 77bd7bc

Browse files
authored
refactor: Scalars (#1064)
1 parent c2617b8 commit 77bd7bc

File tree

18 files changed

+153
-491
lines changed

18 files changed

+153
-491
lines changed

crates/cubecl-common/src/format.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ mod tests {
104104
#[test]
105105
fn test_format_debug() {
106106
let test = Test {
107-
map: HashMap::from_iter([("Hey with space".to_string(), 8)].into_iter()),
107+
map: HashMap::from_iter([("Hey with space".to_string(), 8)]),
108108
};
109109

110110
let formatted = format_debug(&test);

crates/cubecl-core/src/compute/launcher.rs

Lines changed: 41 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
use std::marker::PhantomData;
1+
use std::{collections::BTreeMap, marker::PhantomData};
22

3-
use crate::MetadataBuilder;
3+
use crate::KernelSettings;
44
use crate::Runtime;
55
use crate::compute::KernelTask;
66
use crate::prelude::{ArrayArg, TensorArg, TensorMapArg};
7-
use crate::{KernelSettings, prelude::CubePrimitive};
8-
use bytemuck::{AnyBitPattern, NoUninit};
7+
use crate::{CubeScalar, MetadataBuilder};
8+
use cubecl_ir::StorageType;
99
use cubecl_runtime::server::{Binding, CubeCount, ScalarBinding, TensorMapBinding};
1010
use cubecl_runtime::{client::ComputeClient, server::Bindings};
1111

@@ -14,18 +14,7 @@ use super::CubeKernel;
1414
/// Prepare a kernel for [launch](KernelLauncher::launch).
1515
pub struct KernelLauncher<R: Runtime> {
1616
tensors: TensorState<R>,
17-
scalar_bf16: ScalarState<half::bf16>,
18-
scalar_f16: ScalarState<half::f16>,
19-
scalar_f32: ScalarState<f32>,
20-
scalar_f64: ScalarState<f64>,
21-
scalar_u64: ScalarState<u64>,
22-
scalar_u32: ScalarState<u32>,
23-
scalar_u16: ScalarState<u16>,
24-
scalar_u8: ScalarState<u8>,
25-
scalar_i64: ScalarState<i64>,
26-
scalar_i32: ScalarState<i32>,
27-
scalar_i16: ScalarState<i16>,
28-
scalar_i8: ScalarState<i8>,
17+
scalars: ScalarState,
2918
pub settings: KernelSettings,
3019
runtime: PhantomData<R>,
3120
}
@@ -46,64 +35,14 @@ impl<R: Runtime> KernelLauncher<R> {
4635
self.tensors.push_array(array);
4736
}
4837

49-
/// Register a u8 scalar to be launched.
50-
pub fn register_u8(&mut self, scalar: u8) {
51-
self.scalar_u8.push(scalar);
38+
/// Register a scalar to be launched.
39+
pub fn register_scalar<C: CubeScalar>(&mut self, scalar: C) {
40+
self.scalars.push(scalar);
5241
}
5342

54-
/// Register a u16 scalar to be launched.
55-
pub fn register_u16(&mut self, scalar: u16) {
56-
self.scalar_u16.push(scalar);
57-
}
58-
59-
/// Register a u32 scalar to be launched.
60-
pub fn register_u32(&mut self, scalar: u32) {
61-
self.scalar_u32.push(scalar);
62-
}
63-
64-
/// Register a u64 scalar to be launched.
65-
pub fn register_u64(&mut self, scalar: u64) {
66-
self.scalar_u64.push(scalar);
67-
}
68-
69-
/// Register a i8 scalar to be launched.
70-
pub fn register_i8(&mut self, scalar: i8) {
71-
self.scalar_i8.push(scalar);
72-
}
73-
74-
/// Register a i16 scalar to be launched.
75-
pub fn register_i16(&mut self, scalar: i16) {
76-
self.scalar_i16.push(scalar);
77-
}
78-
79-
/// Register a i32 scalar to be launched.
80-
pub fn register_i32(&mut self, scalar: i32) {
81-
self.scalar_i32.push(scalar);
82-
}
83-
84-
/// Register a i64 scalar to be launched.
85-
pub fn register_i64(&mut self, scalar: i64) {
86-
self.scalar_i64.push(scalar);
87-
}
88-
89-
/// Register a bf16 scalar to be launched.
90-
pub fn register_bf16(&mut self, scalar: half::bf16) {
91-
self.scalar_bf16.push(scalar);
92-
}
93-
94-
/// Register a f16 scalar to be launched.
95-
pub fn register_f16(&mut self, scalar: half::f16) {
96-
self.scalar_f16.push(scalar);
97-
}
98-
99-
/// Register a f32 scalar to be launched.
100-
pub fn register_f32(&mut self, scalar: f32) {
101-
self.scalar_f32.push(scalar);
102-
}
103-
104-
/// Register a f64 scalar to be launched.
105-
pub fn register_f64(&mut self, scalar: f64) {
106-
self.scalar_f64.push(scalar);
43+
/// Register a scalar to be launched, given as its raw bytes and storage type.
44+
pub fn register_scalar_raw(&mut self, bytes: &[u8], dtype: StorageType) {
45+
self.scalars.push_raw(bytes, dtype);
10746
}
10847

10948
/// Launch the kernel.
@@ -156,19 +95,7 @@ impl<R: Runtime> KernelLauncher<R> {
15695
let mut bindings = Bindings::new();
15796

15897
self.tensors.register(&mut bindings);
159-
160-
self.scalar_u8.register(&mut bindings);
161-
self.scalar_u16.register(&mut bindings);
162-
self.scalar_u32.register(&mut bindings);
163-
self.scalar_u64.register(&mut bindings);
164-
self.scalar_i8.register(&mut bindings);
165-
self.scalar_i16.register(&mut bindings);
166-
self.scalar_i32.register(&mut bindings);
167-
self.scalar_i64.register(&mut bindings);
168-
self.scalar_f16.register(&mut bindings);
169-
self.scalar_bf16.register(&mut bindings);
170-
self.scalar_f32.register(&mut bindings);
171-
self.scalar_f64.register(&mut bindings);
98+
self.scalars.register(&mut bindings);
17299

173100
bindings
174101
}
@@ -190,13 +117,14 @@ pub enum TensorState<R: Runtime> {
190117
/// Handles the scalar state for all element types, keyed by storage type
191118
///
192119
/// The scalars are grouped to reduce the number of buffers needed to send data to the compute device.
193-
pub enum ScalarState<T> {
194-
/// No scalar of that type is registered yet.
195-
Empty,
196-
/// The registered scalars.
197-
Some(Vec<T>),
120+
#[derive(Default, Clone)]
121+
pub struct ScalarState {
122+
data: BTreeMap<StorageType, ScalarValues>,
198123
}
199124

125+
/// Raw bytes of the scalar args registered for one storage type (the type itself is the map key)
126+
pub type ScalarValues = Vec<u8>;
127+
200128
impl<R: Runtime> TensorState<R> {
201129
fn maybe_init(&mut self) {
202130
if matches!(self, TensorState::Empty) {
@@ -316,26 +244,36 @@ impl<R: Runtime> TensorState<R> {
316244
}
317245
}
318246

319-
impl<T: NoUninit + AnyBitPattern + CubePrimitive> ScalarState<T> {
247+
impl ScalarState {
320248
/// Add a new scalar value to the state.
321-
pub fn push(&mut self, val: T) {
322-
match self {
323-
ScalarState::Empty => *self = Self::Some(vec![val]),
324-
ScalarState::Some(values) => values.push(val),
325-
}
249+
pub fn push<T: CubeScalar>(&mut self, val: T) {
250+
let val = [val];
251+
let bytes = T::as_bytes(&val);
252+
self.data
253+
.entry(T::cube_type())
254+
.or_default()
255+
.extend(bytes.iter().copied());
256+
}
257+
258+
/// Add a new raw value to the state.
259+
pub fn push_raw(&mut self, bytes: &[u8], dtype: StorageType) {
260+
self.data
261+
.entry(dtype)
262+
.or_default()
263+
.extend(bytes.iter().copied());
326264
}
327265

328266
fn register(&self, bindings: &mut Bindings) {
329-
if let ScalarState::Some(values) = self {
330-
let len = values.len();
331-
let len_u64 = len.div_ceil(size_of::<u64>() / size_of::<T>());
267+
for (ty, values) in self.data.iter() {
268+
let len = values.len() / ty.size();
269+
let len_u64 = len.div_ceil(size_of::<u64>() / ty.size());
270+
332271
let mut data = vec![0; len_u64];
333-
let slice = bytemuck::cast_slice_mut::<u64, T>(&mut data);
272+
let slice = bytemuck::cast_slice_mut::<u64, u8>(&mut data);
334273
slice[0..values.len()].copy_from_slice(values);
335-
let elem = T::as_type_native_unchecked();
336274
bindings
337275
.scalars
338-
.insert(elem, ScalarBinding::new(elem, len, data));
276+
.insert(*ty, ScalarBinding::new(*ty, len, data));
339277
}
340278
}
341279
}
@@ -344,18 +282,7 @@ impl<R: Runtime> Default for KernelLauncher<R> {
344282
fn default() -> Self {
345283
Self {
346284
tensors: TensorState::Empty,
347-
scalar_bf16: ScalarState::Empty,
348-
scalar_f16: ScalarState::Empty,
349-
scalar_f32: ScalarState::Empty,
350-
scalar_f64: ScalarState::Empty,
351-
scalar_u64: ScalarState::Empty,
352-
scalar_u32: ScalarState::Empty,
353-
scalar_u16: ScalarState::Empty,
354-
scalar_u8: ScalarState::Empty,
355-
scalar_i64: ScalarState::Empty,
356-
scalar_i32: ScalarState::Empty,
357-
scalar_i16: ScalarState::Empty,
358-
scalar_i8: ScalarState::Empty,
285+
scalars: Default::default(),
359286
settings: Default::default(),
360287
runtime: PhantomData,
361288
}

crates/cubecl-core/src/frontend/container/line/base.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ mod fill {
7979
let length = self.expand.ty.line_size();
8080
let output = scope.create_local(Type::new(P::as_type(scope)).line(length));
8181

82-
cast::expand::<P>(scope, value, output.clone().into());
82+
cast::expand::<P, Line<P>>(scope, value, output.clone().into());
8383

8484
output.into()
8585
})

crates/cubecl-core/src/frontend/element/cast.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@ pub trait Cast: CubePrimitive {
1414
scope: &mut Scope,
1515
value: ExpandElementTyped<From>,
1616
) -> <Self as CubeType>::ExpandType {
17-
if core::any::TypeId::of::<Self>() == core::any::TypeId::of::<From>() {
17+
if Self::as_type(scope) == From::as_type(scope) {
1818
return value.expand.into();
1919
}
2020
let line_size_in = value.expand.ty.line_size();
2121
let line_size_out = line_size_in * value.expand.ty.storage_type().packing_factor()
2222
/ Self::as_type(scope).packing_factor();
2323
let new_var = scope
2424
.create_local(Type::new(<Self as CubePrimitive>::as_type(scope)).line(line_size_out));
25-
cast::expand(scope, value, new_var.clone().into());
25+
cast::expand::<From, Self>(scope, value, new_var.clone().into());
2626
new_var.into()
2727
}
2828
}

crates/cubecl-core/src/frontend/element/float.rs

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -164,27 +164,3 @@ impl_float!(half f16, F16);
164164
impl_float!(half bf16, BF16);
165165
impl_float!(f32, F32);
166166
impl_float!(f64, F64);
167-
168-
impl ScalarArgSettings for f16 {
169-
fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
170-
settings.register_f16(*self);
171-
}
172-
}
173-
174-
impl ScalarArgSettings for bf16 {
175-
fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
176-
settings.register_bf16(*self);
177-
}
178-
}
179-
180-
impl ScalarArgSettings for f32 {
181-
fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
182-
settings.register_f32(*self);
183-
}
184-
}
185-
186-
impl ScalarArgSettings for f64 {
187-
fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
188-
settings.register_f64(*self);
189-
}
190-
}

crates/cubecl-core/src/frontend/element/float/fp4.rs

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
11
use cubecl_common::{e2m1, e2m1x2};
22
use cubecl_ir::{ConstantScalarValue, ElemType, ExpandElement, FloatKind, Scope, StorageType};
33

4-
use crate::{
5-
Runtime,
6-
compute::KernelLauncher,
7-
prelude::{
8-
CubePrimitive, CubeType, ExpandElementIntoMut, ExpandElementTyped, IntoRuntime,
9-
ScalarArgSettings, into_mut_expand_element, into_runtime_expand_element,
10-
},
4+
use crate::prelude::{
5+
CubePrimitive, CubeType, ExpandElementIntoMut, ExpandElementTyped, IntoRuntime,
6+
into_mut_expand_element, into_runtime_expand_element,
117
};
128

139
impl CubeType for e2m1 {
@@ -73,9 +69,3 @@ impl ExpandElementIntoMut for e2m1x2 {
7369
into_mut_expand_element(scope, elem)
7470
}
7571
}
76-
77-
impl ScalarArgSettings for e2m1x2 {
78-
fn register<R: Runtime>(&self, _settings: &mut KernelLauncher<R>) {
79-
todo!("Not yet supported for scalars")
80-
}
81-
}

crates/cubecl-core/src/frontend/element/float/fp8.rs

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
11
use cubecl_common::{e4m3, e5m2, ue8m0};
22
use cubecl_ir::{ConstantScalarValue, ElemType, ExpandElement, FloatKind, Scope, StorageType};
33

4-
use crate::{
5-
Runtime,
6-
compute::KernelLauncher,
7-
prelude::{
8-
CubePrimitive, CubeType, ExpandElementIntoMut, ExpandElementTyped, IntoRuntime, Numeric,
9-
ScalarArgSettings, into_mut_expand_element, into_runtime_expand_element,
10-
},
4+
use crate::prelude::{
5+
CubePrimitive, CubeType, ExpandElementIntoMut, ExpandElementTyped, IntoRuntime, Numeric,
6+
into_mut_expand_element, into_runtime_expand_element,
117
};
128

139
impl CubeType for e4m3 {
@@ -50,12 +46,6 @@ impl ExpandElementIntoMut for e4m3 {
5046
}
5147
}
5248

53-
impl ScalarArgSettings for e4m3 {
54-
fn register<R: Runtime>(&self, _settings: &mut KernelLauncher<R>) {
55-
todo!("Not yet supported for scalars")
56-
}
57-
}
58-
5949
impl CubeType for e5m2 {
6050
type ExpandType = ExpandElementTyped<e5m2>;
6151
}
@@ -96,12 +86,6 @@ impl ExpandElementIntoMut for e5m2 {
9686
}
9787
}
9888

99-
impl ScalarArgSettings for e5m2 {
100-
fn register<R: Runtime>(&self, _settings: &mut KernelLauncher<R>) {
101-
todo!("Not yet supported for scalars")
102-
}
103-
}
104-
10589
impl CubeType for ue8m0 {
10690
type ExpandType = ExpandElementTyped<ue8m0>;
10791
}
@@ -141,9 +125,3 @@ impl ExpandElementIntoMut for ue8m0 {
141125
into_mut_expand_element(scope, elem)
142126
}
143127
}
144-
145-
impl ScalarArgSettings for ue8m0 {
146-
fn register<R: Runtime>(&self, _settings: &mut KernelLauncher<R>) {
147-
todo!("Not yet supported for scalars")
148-
}
149-
}

crates/cubecl-core/src/frontend/element/float/relaxed.rs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use crate::prelude::{Numeric, into_runtime_expand_element};
55

66
use super::{
77
CubePrimitive, CubeType, ExpandElementIntoMut, ExpandElementTyped, Float, IntoRuntime,
8-
KernelLauncher, Runtime, ScalarArgSettings, into_mut_expand_element,
8+
into_mut_expand_element,
99
};
1010

1111
impl CubeType for flex32 {
@@ -79,9 +79,3 @@ impl Float for flex32 {
7979
flex32::from_f32(val)
8080
}
8181
}
82-
83-
impl ScalarArgSettings for flex32 {
84-
fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
85-
settings.register_f32(self.to_f32());
86-
}
87-
}

crates/cubecl-core/src/frontend/element/float/tensor_float.rs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use crate::prelude::{Numeric, into_runtime_expand_element};
66

77
use super::{
88
CubePrimitive, CubeType, ExpandElementIntoMut, ExpandElementTyped, Float, IntoRuntime,
9-
KernelLauncher, Runtime, ScalarArgSettings, into_mut_expand_element,
9+
into_mut_expand_element,
1010
};
1111

1212
impl CubeType for tf32 {
@@ -81,9 +81,3 @@ impl Float for tf32 {
8181
tf32::from_f32(val)
8282
}
8383
}
84-
85-
impl ScalarArgSettings for tf32 {
86-
fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
87-
settings.register_f32((*self).to_f32());
88-
}
89-
}

0 commit comments

Comments
 (0)