Skip to content

Commit 011f553

Browse files
Fix/wgpu multi streams (#961)
* Fix * Cleanup * Removes extra flush
1 parent 290a4c7 commit 011f553

File tree

2 files changed

+27
-3
lines changed

2 files changed

+27
-3
lines changed

crates/cubecl-runtime/src/stream/scheduler.rs

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ pub trait SchedulerStreamBackend {
1818

1919
/// Enqueues a task onto a given stream for execution.
2020
fn enqueue(task: Self::Task, stream: &mut Self::Stream);
21+
/// Flush the inner stream queue to ensure ordering between different streams.
22+
fn flush(stream: &mut Self::Stream);
2123
/// Returns a mutable reference to the stream factory.
2224
fn factory(&mut self) -> &mut Self::Factory;
2325
}
@@ -219,11 +221,28 @@ impl<B: SchedulerStreamBackend> SchedulerMultiStream<B> {
219221
// Enqueue each task on the stream.
220222
B::enqueue(task, &mut stream.stream);
221223
}
224+
225+
// Makes sure the tasks are ordered on the compute queue.
226+
B::flush(&mut stream.stream);
222227
}
223228
}
224229

225-
/// Executes schedules in an interleaved manner, alternating tasks across streams.
230+
//// Executes schedules in an interleaved manner, alternating tasks from different streams.
231+
///
232+
/// We chose the first stream as the one executing the tasks, ensuring proper ordering by
233+
/// flushing all other streams first and flushing the execution stream at the end.
234+
/// This way, we ensure that most tasks are actually interleaved on the real compute queue
235+
/// shared across all streams.
226236
fn execute_schedules_interleave(&mut self, mut schedules: Vec<Schedule<B>>) {
237+
// Makes sure the tasks are ordered on the compute queue.
238+
for schedule in schedules.iter_mut().skip(1) {
239+
let stream = unsafe { self.pool.get_mut_index(schedule.stream_index) };
240+
B::flush(&mut stream.stream);
241+
}
242+
243+
let execution_index = schedules.first().expect("At least one stream").stream_index;
244+
let stream = unsafe { self.pool.get_mut_index(execution_index) };
245+
227246
// Find the maximum number of tasks across all schedules.
228247
let num_tasks_max = schedules
229248
.iter()
@@ -236,12 +255,13 @@ impl<B: SchedulerStreamBackend> SchedulerMultiStream<B> {
236255
for schedule in schedules.iter_mut() {
237256
// If there are tasks remaining in the schedule, enqueue the next one.
238257
if let Some(task) = schedule.tasks.next() {
239-
// Note: `unsafe` usage assumes valid index.
240-
let stream = unsafe { self.pool.get_mut_index(schedule.stream_index) };
241258
B::enqueue(task, &mut stream.stream);
242259
}
243260
}
244261
}
262+
263+
// Making sure all tasks are registered to the queue.
264+
B::flush(&mut stream.stream);
245265
}
246266
}
247267

crates/cubecl-wgpu/src/compute/schedule.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,10 @@ impl SchedulerStreamBackend for ScheduledWgpuBackend {
135135
stream.enqueue_task(task);
136136
}
137137

138+
fn flush(stream: &mut Self::Stream) {
139+
stream.flush();
140+
}
141+
138142
fn factory(&mut self) -> &mut Self::Factory {
139143
&mut self.factory
140144
}

0 commit comments

Comments
 (0)