Skip to content

Commit 31818e5

Browse files
authored
refactor: Launch (#944)
1 parent bb716c9 commit 31818e5

File tree

32 files changed

+805
-418
lines changed

32 files changed

+805
-418
lines changed

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ serde_json = { version = "1.0.119", default-features = false }
3535
toml = "0.9.1"
3636
variadics_please = "1"
3737

38+
# no_std compatibility
3839
dashmap = "6.1.0"
3940
foldhash = { version = "0.1.2", default-features = false }
4041
hashbrown = "0.15.5"
@@ -86,11 +87,10 @@ tracy-client = { version = "0.18.0" }
8687
strum = { version = "0.27.1", features = ["derive"] }
8788
tracel-xtask = { version = "=2.1.8" }
8889

89-
90-
portable-atomic-util = { version = "0.2.4", features = ["alloc"] }
9190
portable-atomic = { version = "1.11", default-features = false, features = [
9291
"serde",
9392
] }
93+
portable-atomic-util = { version = "0.2.4", features = ["alloc"] }
9494
pretty_assertions = "1.4"
9595

9696
# Async

crates/cubecl-convolution/src/kernels/layered/selector/base.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use cubecl_core as cubecl;
2-
use cubecl_core::{CubeLaunch, CubeType, prelude::*};
2+
use cubecl_core::{CubeType, prelude::*};
33
use cubecl_std::FastDivmod;
44

55
#[derive(CubeType, CubeLaunch, Clone)]

crates/cubecl-core/src/frontend/container/array/launch.rs

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@ use crate::{
77
compute::{KernelBuilder, KernelLauncher},
88
ir::{Id, LineSize, Type},
99
prelude::{
10-
ArgSettings, CompilationArg, CubePrimitive, ExpandElementTyped, LaunchArg, LaunchArgExpand,
11-
TensorHandleRef,
10+
ArgSettings, CompilationArg, CubePrimitive, ExpandElementTyped, LaunchArg, TensorHandleRef,
1211
},
1312
};
1413

@@ -30,30 +29,6 @@ pub struct ArrayHandleRef<'a, R: Runtime> {
3029
runtime: PhantomData<R>,
3130
}
3231

33-
impl<C: CubePrimitive> LaunchArgExpand for Array<C> {
34-
type CompilationArg = ArrayCompilationArg;
35-
36-
fn expand(
37-
arg: &Self::CompilationArg,
38-
builder: &mut KernelBuilder,
39-
) -> ExpandElementTyped<Array<C>> {
40-
builder
41-
.input_array(Type::new(C::as_type(&builder.scope)).line(arg.line_size))
42-
.into()
43-
}
44-
fn expand_output(
45-
arg: &Self::CompilationArg,
46-
builder: &mut KernelBuilder,
47-
) -> ExpandElementTyped<Array<C>> {
48-
match arg.inplace {
49-
Some(id) => builder.inplace_output(id).into(),
50-
None => builder
51-
.output_array(Type::new(C::as_type(&builder.scope)).line(arg.line_size))
52-
.into(),
53-
}
54-
}
55-
}
56-
5732
pub enum ArrayArg<'a, R: Runtime> {
5833
/// The array is passed with an array handle.
5934
Handle {
@@ -153,6 +128,7 @@ impl<'a, R: Runtime> ArrayHandleRef<'a, R> {
153128

154129
impl<C: CubePrimitive> LaunchArg for Array<C> {
155130
type RuntimeArg<'a, R: Runtime> = ArrayArg<'a, R>;
131+
type CompilationArg = ArrayCompilationArg;
156132

157133
fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
158134
match runtime_arg {
@@ -166,4 +142,24 @@ impl<C: CubePrimitive> LaunchArg for Array<C> {
166142
},
167143
}
168144
}
145+
146+
fn expand(
147+
arg: &Self::CompilationArg,
148+
builder: &mut KernelBuilder,
149+
) -> ExpandElementTyped<Array<C>> {
150+
builder
151+
.input_array(Type::new(C::as_type(&builder.scope)).line(arg.line_size))
152+
.into()
153+
}
154+
fn expand_output(
155+
arg: &Self::CompilationArg,
156+
builder: &mut KernelBuilder,
157+
) -> ExpandElementTyped<Array<C>> {
158+
match arg.inplace {
159+
Some(id) => builder.inplace_output(id).into(),
160+
None => builder
161+
.output_array(Type::new(C::as_type(&builder.scope)).line(arg.line_size))
162+
.into(),
163+
}
164+
}
169165
}

crates/cubecl-core/src/frontend/container/sequence/launch.rs

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
use std::{cell::RefCell, rc::Rc};
22

3-
use serde::{Deserialize, Serialize};
4-
53
use crate::{
64
Runtime,
75
compute::KernelBuilder,
8-
prelude::{ArgSettings, CompilationArg, LaunchArg, LaunchArgExpand},
6+
prelude::{ArgSettings, CompilationArg, LaunchArg},
97
};
108

119
use super::{Sequence, SequenceExpand};
@@ -29,7 +27,6 @@ impl<'a, R: Runtime, T: LaunchArg> SequenceArg<'a, R, T> {
2927
}
3028
}
3129

32-
#[derive(Serialize, Deserialize)]
3330
pub struct SequenceCompilationArg<C: LaunchArg> {
3431
pub values: Vec<C::CompilationArg>,
3532
}
@@ -65,6 +62,7 @@ impl<C: LaunchArg> core::cmp::Eq for SequenceCompilationArg<C> {}
6562

6663
impl<C: LaunchArg> LaunchArg for Sequence<C> {
6764
type RuntimeArg<'a, R: Runtime> = SequenceArg<'a, R, C>;
65+
type CompilationArg = SequenceCompilationArg<C>;
6866

6967
fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
7068
SequenceCompilationArg {
@@ -75,16 +73,6 @@ impl<C: LaunchArg> LaunchArg for Sequence<C> {
7573
.collect(),
7674
}
7775
}
78-
}
79-
80-
impl<R: Runtime, T: LaunchArg> ArgSettings<R> for SequenceArg<'_, R, T> {
81-
fn register(&self, launcher: &mut crate::prelude::KernelLauncher<R>) {
82-
self.values.iter().for_each(|arg| arg.register(launcher));
83-
}
84-
}
85-
86-
impl<C: LaunchArg> LaunchArgExpand for Sequence<C> {
87-
type CompilationArg = SequenceCompilationArg<C>;
8876

8977
fn expand(arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> SequenceExpand<C> {
9078
let values = arg
@@ -110,3 +98,9 @@ impl<C: LaunchArg> LaunchArgExpand for Sequence<C> {
11098
}
11199
}
112100
}
101+
102+
impl<R: Runtime, T: LaunchArg> ArgSettings<R> for SequenceArg<'_, R, T> {
103+
fn register(&self, launcher: &mut crate::prelude::KernelLauncher<R>) {
104+
self.values.iter().for_each(|arg| arg.register(launcher));
105+
}
106+
}

crates/cubecl-core/src/frontend/container/slice/base.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ impl<E: CubePrimitive, IO: SliceVisibility> Iterator for Slice<E, IO> {
4848
}
4949
}
5050

51-
pub trait SliceVisibility: Clone + Copy {}
51+
pub trait SliceVisibility: Clone + Copy + Send + Sync + 'static {}
5252

5353
impl SliceVisibility for ReadOnly {}
5454

crates/cubecl-core/src/frontend/container/tensor/launch.rs

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@ use crate::{
66
Runtime,
77
compute::{KernelBuilder, KernelLauncher},
88
ir::{Id, LineSize, Type},
9-
prelude::{
10-
ArgSettings, CompilationArg, CubePrimitive, ExpandElementTyped, LaunchArg, LaunchArgExpand,
11-
},
9+
prelude::{ArgSettings, CompilationArg, CubePrimitive, ExpandElementTyped, LaunchArg},
1210
};
1311

1412
use super::Tensor;
@@ -73,9 +71,23 @@ pub struct TensorCompilationArg {
7371

7472
impl CompilationArg for TensorCompilationArg {}
7573

76-
impl<C: CubePrimitive> LaunchArgExpand for Tensor<C> {
74+
impl<C: CubePrimitive> LaunchArg for Tensor<C> {
75+
type RuntimeArg<'a, R: Runtime> = TensorArg<'a, R>;
7776
type CompilationArg = TensorCompilationArg;
7877

78+
fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
79+
match runtime_arg {
80+
TensorArg::Handle { line_size, .. } => TensorCompilationArg {
81+
inplace: None,
82+
line_size: *line_size as u32,
83+
},
84+
TensorArg::Alias { input_pos } => TensorCompilationArg {
85+
inplace: Some(*input_pos as Id),
86+
line_size: 0,
87+
},
88+
}
89+
}
90+
7991
fn expand(
8092
arg: &Self::CompilationArg,
8193
builder: &mut KernelBuilder,
@@ -97,23 +109,6 @@ impl<C: CubePrimitive> LaunchArgExpand for Tensor<C> {
97109
}
98110
}
99111

100-
impl<C: CubePrimitive> LaunchArg for Tensor<C> {
101-
type RuntimeArg<'a, R: Runtime> = TensorArg<'a, R>;
102-
103-
fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
104-
match runtime_arg {
105-
TensorArg::Handle { line_size, .. } => TensorCompilationArg {
106-
inplace: None,
107-
line_size: *line_size as u32,
108-
},
109-
TensorArg::Alias { input_pos } => TensorCompilationArg {
110-
inplace: Some(*input_pos as Id),
111-
line_size: 0,
112-
},
113-
}
114-
}
115-
}
116-
117112
impl<'a, R: Runtime> TensorArg<'a, R> {
118113
/// Create a new tensor argument specified with its vectorization factor.
119114
///

crates/cubecl-core/src/frontend/container/tensor/tensormap.rs

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,14 @@ pub struct TensorMapCompilationArg;
119119

120120
impl CompilationArg for TensorMapCompilationArg {}
121121

122-
impl<E: CubePrimitive> LaunchArgExpand for TensorMap<E> {
122+
impl<E: CubePrimitive> LaunchArg for TensorMap<E> {
123+
type RuntimeArg<'a, R: Runtime> = TensorMapArg<'a, R>;
123124
type CompilationArg = TensorMapCompilationArg;
124125

126+
fn compilation_arg<R: Runtime>(_runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
127+
TensorMapCompilationArg
128+
}
129+
125130
fn expand(
126131
_arg: &Self::CompilationArg,
127132
builder: &mut KernelBuilder,
@@ -138,14 +143,6 @@ impl<E: CubePrimitive> LaunchArgExpand for TensorMap<E> {
138143
}
139144
}
140145

141-
impl<E: CubePrimitive> LaunchArg for TensorMap<E> {
142-
type RuntimeArg<'a, R: Runtime> = TensorMapArg<'a, R>;
143-
144-
fn compilation_arg<R: Runtime>(_runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
145-
TensorMapCompilationArg
146-
}
147-
}
148-
149146
/// Commit an async tensor operation. The mechanism is poorly documented, but it must be called
150147
/// after a write, but not after reads.
151148
pub fn tma_group_commit() {

crates/cubecl-core/src/frontend/element/atomic.rs

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
11
use cubecl_ir::{AtomicOp, ExpandElement, StorageType};
22

3-
use super::{
4-
ExpandElementIntoMut, ExpandElementTyped, Int, LaunchArgExpand, Numeric,
5-
into_mut_expand_element,
6-
};
3+
use super::{ExpandElementIntoMut, ExpandElementTyped, Int, Numeric, into_mut_expand_element};
74
use crate::{
85
frontend::{CubePrimitive, CubeType},
96
ir::{BinaryOperator, CompareAndSwapOperator, Instruction, Scope, Type, UnaryOperator},
10-
prelude::KernelBuilder,
117
unexpanded,
128
};
139

@@ -317,11 +313,3 @@ impl<Inner: CubePrimitive> ExpandElementIntoMut for Atomic<Inner> {
317313
into_mut_expand_element(scope, elem)
318314
}
319315
}
320-
321-
impl<Inner: CubePrimitive> LaunchArgExpand for Atomic<Inner> {
322-
type CompilationArg = ();
323-
324-
fn expand(_: &Self::CompilationArg, builder: &mut KernelBuilder) -> ExpandElementTyped<Self> {
325-
builder.scalar(Self::as_type_native_unchecked()).into()
326-
}
327-
}

0 commit comments

Comments
 (0)