File tree Expand file tree Collapse file tree 1 file changed +15
-4
lines changed
crates/cubecl-runtime/src Expand file tree Collapse file tree 1 file changed +15
-4
lines changed Original file line number Diff line number Diff line change @@ -522,13 +522,14 @@ where
522522 // We use the same profiling lock to make sure no other task is currently using the current
523523 // device, meaning that the current persistent memory strategy will only be used for the
524524 // provided function.
525-
526525 #[ cfg( multi_threading) ]
527526 let stream_id = self . profile_acquire ( ) ;
528527
529528 self . channel
530529 . allocation_mode ( MemoryAllocationMode :: Persistent , self . stream_id ( ) ) ;
530+
531531 let output = func ( input) ;
532+
532533 self . channel
533534 . allocation_mode ( MemoryAllocationMode :: Auto , self . stream_id ( ) ) ;
534535
@@ -656,11 +657,21 @@ where
656657 alloc
657658 }
658659
659- #[ cfg( not( multi_threading) ) ]
660- fn profile_guard ( & self ) { }
660+ /// Waits until the current device can be used without impacting profiling.
661+ ///
662+ /// Tasks registered on the stream that is currently profiling won't wait, which allows
663+ /// recursive profiling.
664+ ///
665+ /// # Warning
666+ ///
667+ /// This function normally shouldn't be used except in internal code.
668+ pub fn profile_guard ( & self ) {
669+ #[ cfg( multi_threading) ]
670+ self . profile_guard_inner ( ) ;
671+ }
661672
662673 #[ cfg( multi_threading) ]
663- fn profile_guard ( & self ) {
674+ fn profile_guard_inner ( & self ) {
664675 let current = self . state . current_profiling . read ( ) ;
665676
666677 if let Some ( current_stream_id) = current. as_ref ( ) {
You can’t perform that action at this time.
0 commit comments