Skip to content

Commit b24b963

Browse files
Feat/event bus (#950)
1 parent 77bd7bc commit b24b963

File tree

5 files changed

+366
-2
lines changed

5 files changed

+366
-2
lines changed

crates/cubecl-macros/src/lib.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,16 @@ fn cube_impl(args: TokenStream, input: TokenStream) -> syn::Result<TokenStream>
7474
}));
7575
}
7676
Item::Trait(kernel_trait) => {
77+
let is_debug = args.debug.is_present();
7778
let expand_trait = CubeTrait::from_item_trait(kernel_trait, args)?;
7879

79-
return Ok(TokenStream::from(quote! {
80+
let tokens = TokenStream::from(quote! {
8081
#expand_trait
81-
}));
82+
});
83+
if is_debug {
84+
panic!("{tokens}");
85+
}
86+
return Ok(tokens);
8287
}
8388
Item::Impl(item_impl) => {
8489
if item_impl.trait_.is_some() {

crates/cubecl-std/src/event/mod.rs

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
use std::{
2+
any::{Any, TypeId},
3+
cell::RefCell,
4+
collections::HashMap,
5+
rc::Rc,
6+
};
7+
8+
use cubecl::prelude::*;
9+
use cubecl_core::{self as cubecl, intrinsic};
10+
11+
#[derive(CubeType, Clone)]
12+
/// This event bus allows users to trigger events at compilation time to modify the generated code.
13+
///
14+
/// # Warning
15+
///
16+
/// Recursion isn't supported with a runtime end condition, the compilation will fail with a
17+
/// strange error.
18+
pub struct ComptimeEventBus {
19+
#[allow(unused)]
20+
#[cube(comptime)]
21+
listener_family: Rc<RefCell<HashMap<TypeId, Vec<EventItem>>>>,
22+
}
23+
24+
type EventItem = Box<dyn Any>;
25+
type Call<E> =
26+
Box<dyn Fn(&mut Scope, &Box<dyn Any>, <E as CubeType>::ExpandType, ComptimeEventBusExpand)>;
27+
28+
struct Payload<E: CubeType> {
29+
listener: Box<dyn Any>,
30+
call: Call<E>,
31+
}
32+
33+
impl Default for ComptimeEventBus {
34+
fn default() -> Self {
35+
Self::new()
36+
}
37+
}
38+
39+
#[cube]
40+
impl ComptimeEventBus {
41+
/// Creates a new event bus.
42+
pub fn new() -> Self {
43+
intrinsic!(|_| {
44+
ComptimeEventBusExpand {
45+
listener_family: Rc::new(RefCell::new(HashMap::new())),
46+
}
47+
})
48+
}
49+
50+
#[allow(unused_variables)]
51+
/// Registers a new callback to be called when its event is launched.
52+
///
53+
/// # Notes
54+
///
55+
/// Multiple listeners for a single event type is supported. All the listeners will be called
56+
/// for each event in the same order they were registered.
57+
pub fn listener<L: EventListener>(&mut self, listener: L) {
58+
intrinsic!(|_| {
59+
let type_id = TypeId::of::<L::Event>();
60+
let mut listeners = self.listener_family.borrow_mut();
61+
62+
// The call dynamic function erases the [EventListener] type.
63+
//
64+
// This is necessary since we need to clone the expand type when calling the expand
65+
// method. The listener is passed as a dynamic type and casted during the function call.
66+
let call =
67+
|scope: &mut cubecl::prelude::Scope,
68+
listener: &Box<dyn Any>,
69+
event: <L::Event as cubecl::prelude::CubeType>::ExpandType,
70+
bus: <ComptimeEventBus as cubecl::prelude::CubeType>::ExpandType| {
71+
let listener: &L::ExpandType = listener.downcast_ref().unwrap();
72+
listener.clone().__expand_on_event_method(scope, event, bus)
73+
};
74+
let call: Call<L::Event> = Box::new(call);
75+
76+
let listener: Box<dyn Any> = Box::new(listener);
77+
let payload = Payload::<L::Event> { listener, call };
78+
79+
// Here we erase the event type, so that all listeners can be stored in the same event
80+
// bus which support multiple event types.
81+
let listener_dyn: Box<dyn Any> = Box::new(payload);
82+
83+
match listeners.get_mut(&type_id) {
84+
Some(list) => list.push(listener_dyn),
85+
None => {
86+
listeners.insert(type_id, vec![listener_dyn]);
87+
}
88+
}
89+
})
90+
}
91+
92+
#[allow(unused_variables)]
93+
/// Registers a new event to be processed by [registered listeners](EventListener).
94+
pub fn event<E: CubeType + 'static>(&mut self, event: E) {
95+
intrinsic!(|scope| {
96+
let type_id = TypeId::of::<E>();
97+
let family = self.listener_family.borrow();
98+
let listeners = match family.get(&type_id) {
99+
Some(val) => val,
100+
None => return,
101+
};
102+
103+
for listener in listeners.iter() {
104+
let payload = listener.downcast_ref::<Payload<E>>().unwrap();
105+
let call = &payload.call;
106+
call(scope, &payload.listener, event.clone(), self.clone());
107+
}
108+
})
109+
}
110+
}
111+
112+
#[cube]
113+
/// Defines a listener that is called each time an event is triggered on an
114+
/// [event bus](ComptimeEventBus).
115+
pub trait EventListener: 'static {
116+
/// The event type triggering the [EventListener::on_event] callback.
117+
type Event: CubeType + 'static;
118+
119+
/// The function called when an event of the type [EventListener::Event] is registered on the
120+
/// [ComptimeEventBus].
121+
fn on_event(&mut self, event: Self::Event, bus: &mut ComptimeEventBus);
122+
}

crates/cubecl-std/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,8 @@ pub mod quant;
2121
pub mod scalar;
2222
pub mod tensor;
2323

24+
/// Event utilities.
25+
pub mod event;
26+
2427
#[cfg(feature = "export_tests")]
2528
pub mod tests;
Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
use crate::event::{ComptimeEventBus, EventListener, EventListenerExpand};
2+
use cubecl::prelude::*;
3+
use cubecl_core as cubecl;
4+
5+
#[derive(CubeType)]
6+
pub struct EventUInt {
7+
#[cube(comptime)]
8+
pub value: u32,
9+
}
10+
11+
#[derive(CubeType)]
12+
pub struct EventFloat {
13+
#[cube(comptime)]
14+
pub value: f32,
15+
}
16+
17+
#[derive(CubeType, Clone)]
18+
pub struct EventListenerPosZero {
19+
items: SliceMut<f32>,
20+
}
21+
22+
#[derive(CubeType, Clone)]
23+
pub struct EventListenerPosOne {
24+
items: SliceMut<f32>,
25+
}
26+
27+
#[derive(CubeType, Clone)]
28+
pub struct EventListenerPosTwo {
29+
items: SliceMut<f32>,
30+
times: ComptimeCell<Counter>,
31+
}
32+
33+
#[derive(CubeType, Clone)]
34+
pub struct Counter {
35+
#[cube(comptime)]
36+
value: u32,
37+
}
38+
39+
#[cube]
40+
impl EventListener for EventListenerPosZero {
41+
type Event = EventUInt;
42+
43+
fn on_event(&mut self, event: Self::Event, bus: &mut ComptimeEventBus) {
44+
if comptime!(event.value < 10) {
45+
comment!("On event pos zero < 10");
46+
bus.event::<EventUInt>(EventUInt {
47+
value: comptime!(15u32 + event.value),
48+
});
49+
} else {
50+
comment!("On event pos zero >= 10");
51+
self.items[0] = f32::cast_from(event.value);
52+
}
53+
}
54+
}
55+
56+
#[cube]
57+
impl EventListener for EventListenerPosOne {
58+
type Event = EventUInt;
59+
60+
fn on_event(&mut self, event: Self::Event, _bus: &mut ComptimeEventBus) {
61+
comment!("On event pos one");
62+
self.items[1] = (f32::cast_from(event.value) * 2.0) + self.items[1];
63+
}
64+
}
65+
66+
#[cube]
67+
impl EventListener for EventListenerPosTwo {
68+
type Event = EventFloat;
69+
70+
fn on_event(&mut self, event: Self::Event, bus: &mut ComptimeEventBus) {
71+
comment!("On event pos two");
72+
self.items[2] = event.value + self.items[2];
73+
74+
let times = self.times.read();
75+
self.times.store(Counter {
76+
value: comptime!(times.value + 1),
77+
});
78+
79+
if comptime!(times.value < 4) {
80+
bus.event::<EventFloat>(EventFloat {
81+
value: comptime!(event.value * 2.0),
82+
});
83+
bus.event::<EventUInt>(EventUInt {
84+
value: comptime!((event.value * 2.0) as u32),
85+
});
86+
}
87+
}
88+
}
89+
90+
#[cube]
91+
fn test_1(items: SliceMut<f32>) {
92+
let mut bus = ComptimeEventBus::new();
93+
let listener_zero = EventListenerPosZero { items };
94+
let listener_one = EventListenerPosOne { items };
95+
96+
bus.listener::<EventListenerPosZero>(listener_zero);
97+
bus.listener::<EventListenerPosOne>(listener_one);
98+
99+
bus.event::<EventUInt>(EventUInt { value: 5u32 });
100+
}
101+
102+
#[cube]
103+
fn test_2(items: SliceMut<f32>) {
104+
let mut bus = ComptimeEventBus::new();
105+
let listener_zero = EventListenerPosZero { items };
106+
let listener_one = EventListenerPosOne { items };
107+
108+
bus.listener::<EventListenerPosZero>(listener_zero);
109+
bus.listener::<EventListenerPosOne>(listener_one);
110+
111+
bus.event::<EventUInt>(EventUInt { value: 15u32 });
112+
}
113+
114+
#[cube]
115+
fn test_3(items: SliceMut<f32>) {
116+
let mut bus = ComptimeEventBus::new();
117+
let listener_zero = EventListenerPosZero { items };
118+
let listener_one = EventListenerPosOne { items };
119+
let listener_two = EventListenerPosTwo {
120+
items,
121+
times: ComptimeCell::new(Counter { value: 0u32 }),
122+
};
123+
124+
bus.listener::<EventListenerPosZero>(listener_zero);
125+
bus.listener::<EventListenerPosOne>(listener_one);
126+
bus.listener::<EventListenerPosTwo>(listener_two);
127+
128+
bus.event::<EventFloat>(EventFloat { value: 15.0f32 });
129+
}
130+
131+
#[cube(launch_unchecked)]
132+
fn launch_test_1(output: &mut Array<f32>) {
133+
output[0] = 0.0;
134+
output[1] = 0.0;
135+
test_1(output.to_slice_mut());
136+
}
137+
138+
#[cube(launch_unchecked)]
139+
fn launch_test_2(output: &mut Array<f32>) {
140+
output[0] = 0.0;
141+
output[1] = 0.0;
142+
test_2(output.to_slice_mut());
143+
}
144+
145+
#[cube(launch_unchecked)]
146+
fn launch_test_3(output: &mut Array<f32>) {
147+
output[0] = 0.0;
148+
output[1] = 0.0;
149+
output[2] = 0.0;
150+
test_3(output.to_slice_mut());
151+
}
152+
153+
pub fn event_test_1<R: Runtime>(client: ComputeClient<R::Server>) {
154+
let output = client.empty(8);
155+
156+
unsafe {
157+
launch_test_1::launch_unchecked::<R>(
158+
&client,
159+
CubeCount::Static(1, 1, 1),
160+
CubeDim { x: 1, y: 1, z: 1 },
161+
ArrayArg::from_raw_parts::<f32>(&output, 2, 1),
162+
);
163+
}
164+
165+
let bytes = client.read_one(output);
166+
let actual = f32::from_bytes(&bytes);
167+
168+
assert_eq!(actual, &[20.0, 50.0]);
169+
}
170+
171+
pub fn event_test_2<R: Runtime>(client: ComputeClient<R::Server>) {
172+
let output = client.empty(8);
173+
174+
unsafe {
175+
launch_test_2::launch_unchecked::<R>(
176+
&client,
177+
CubeCount::Static(1, 1, 1),
178+
CubeDim { x: 1, y: 1, z: 1 },
179+
ArrayArg::from_raw_parts::<f32>(&output, 2, 1),
180+
);
181+
}
182+
183+
let bytes = client.read_one(output);
184+
let actual = f32::from_bytes(&bytes);
185+
186+
assert_eq!(actual, &[15.0, 30.0]);
187+
}
188+
189+
pub fn event_test_3<R: Runtime>(client: ComputeClient<R::Server>) {
190+
let output = client.empty(12);
191+
192+
unsafe {
193+
launch_test_3::launch_unchecked::<R>(
194+
&client,
195+
CubeCount::Static(1, 1, 1),
196+
CubeDim { x: 1, y: 1, z: 1 },
197+
ArrayArg::from_raw_parts::<f32>(&output, 3, 1),
198+
);
199+
}
200+
201+
let bytes = client.read_one(output);
202+
let actual = f32::from_bytes(&bytes);
203+
204+
assert_eq!(actual, &[30.0, 900.0, 465.0]);
205+
}
206+
207+
#[macro_export]
208+
macro_rules! testgen_event {
209+
() => {
210+
mod event {
211+
use super::*;
212+
213+
#[test]
214+
fn test_1() {
215+
let client = TestRuntime::client(&Default::default());
216+
cubecl_std::tests::event::event_test_1::<TestRuntime>(client);
217+
}
218+
219+
#[test]
220+
fn test_2() {
221+
let client = TestRuntime::client(&Default::default());
222+
cubecl_std::tests::event::event_test_2::<TestRuntime>(client);
223+
}
224+
225+
#[test]
226+
fn test_3() {
227+
let client = TestRuntime::client(&Default::default());
228+
cubecl_std::tests::event::event_test_3::<TestRuntime>(client);
229+
}
230+
}
231+
};
232+
}

crates/cubecl-std/src/tests/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
pub mod event;
12
pub mod reinterpret_slice;
23
pub mod tensor;
34
pub mod trigonometry;
@@ -12,6 +13,7 @@ macro_rules! testgen {
1213

1314
cubecl_std::testgen_reinterpret_slice!();
1415
cubecl_std::testgen_trigonometry!();
16+
cubecl_std::testgen_event!();
1517
}
1618
};
1719
}

0 commit comments

Comments
 (0)