Skip to content

Commit 6b0216c

Browse files
Feat/multi streams (#3775)
1 parent 8e3ca6d commit 6b0216c

File tree

3 files changed

+39
-39
lines changed

3 files changed

+39
-39
lines changed

Cargo.lock

Lines changed: 32 additions & 32 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,9 @@ portable-atomic = { version = "1.11.1" }
169169
portable-atomic-util = { version = "0.2.4", features = ["alloc"] }
170170

171171
### For the main burn branch. ###
172-
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "98eb9d27624375fb11d4a97febd364b049933fdc" }
173-
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "98eb9d27624375fb11d4a97febd364b049933fdc" }
174-
cubecl-quant = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "98eb9d27624375fb11d4a97febd364b049933fdc" }
172+
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "c038fb30634ab4b27bc55f49729eed91c38ecf55" }
173+
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "c038fb30634ab4b27bc55f49729eed91c38ecf55" }
174+
cubecl-quant = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "c038fb30634ab4b27bc55f49729eed91c38ecf55" }
175175
### For local development. ###
176176
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
177177
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }

crates/burn-fusion/src/stream/multi.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ impl<R: FusionRuntime> MultiStream<R> {
195195
/// Drain a stream
196196
pub fn drain(&mut self, handles: &mut HandleContainer<R::FusionHandle>, id: StreamId) {
197197
if let Some(stream) = self.streams.get_mut(&id) {
198+
let old = unsafe { StreamId::swap(id) };
198199
let num_executed = stream.queue.global.len();
199200
stream.processor.process(
200201
Segment::new(&mut stream.queue, handles),
@@ -208,6 +209,9 @@ impl<R: FusionRuntime> MultiStream<R> {
208209
let to_drop = self.shared_tensors.clear_tensors(cleared);
209210

210211
self.drop_shared_tensors(to_drop, handles, id);
212+
unsafe {
213+
StreamId::swap(old);
214+
};
211215
}
212216
}
213217

@@ -311,12 +315,8 @@ impl<R: FusionRuntime> MultiStream<R> {
311315
}
312316

313317
for id in streams_to_sync.drain() {
314-
let old = unsafe { StreamId::swap(id) };
315318
log::info!("Drain stream {id} for use in current {current}");
316319
self.resolve_stream(handles, id, nodes);
317-
unsafe {
318-
StreamId::swap(old);
319-
};
320320
}
321321
}
322322

0 commit comments

Comments
 (0)