Skip to content

Commit 07cba4b

Browse files
Expose profile guard (#958)
1 parent aaa77ba commit 07cba4b

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

crates/cubecl-runtime/src/client.rs

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -522,13 +522,14 @@ where
522522
// We use the same profiling lock to make sure no other task is currently using the current
523523
// device, meaning that the current persistent memory strategy will only be used for the
524524
// provided function.
525-
526525
#[cfg(multi_threading)]
527526
let stream_id = self.profile_acquire();
528527

529528
self.channel
530529
.allocation_mode(MemoryAllocationMode::Persistent, self.stream_id());
530+
531531
let output = func(input);
532+
532533
self.channel
533534
.allocation_mode(MemoryAllocationMode::Auto, self.stream_id());
534535

@@ -656,11 +657,21 @@ where
656657
alloc
657658
}
658659

659-
#[cfg(not(multi_threading))]
660-
fn profile_guard(&self) {}
660+
/// Waits until the current device can be used without impacting profiling.
661+
///
662+
/// Tasks registered on the stream that is currently profiling won't wait, which allows recursive
663+
/// profiling.
664+
///
665+
/// # Warning
666+
///
667+
/// This function normally shouldn't be used except in internal code.
668+
pub fn profile_guard(&self) {
669+
#[cfg(multi_threading)]
670+
self.profile_guard_inner();
671+
}
661672

662673
#[cfg(multi_threading)]
663-
fn profile_guard(&self) {
674+
fn profile_guard_inner(&self) {
664675
let current = self.state.current_profiling.read();
665676

666677
if let Some(current_stream_id) = current.as_ref() {

0 commit comments

Comments
 (0)