@@ -9,6 +9,7 @@ use core::{
99 fmt:: Display ,
1010 hash:: Hash ,
1111} ;
12+ use cubecl_common:: map:: SharedStateMap ;
1213use hashbrown:: HashMap ;
1314
1415#[ cfg( not( feature = "std" ) ) ]
@@ -17,7 +18,7 @@ use alloc::string::ToString;
1718/// A local tuner allows creating a tuner for a specific key that can be different from the server
1819/// key.
1920pub struct LocalTuner < AK : AutotuneKey , ID > {
20- state : spin :: Mutex < Option < HashMap < ID , Tuner < AK > > > > ,
21+ state : SharedStateMap < ID , Tuner < AK > > ,
2122 name : & ' static str ,
2223 sets : spin:: RwLock < Option < HashMap < TypeId , Arc < dyn Any + Send + Sync > > > > ,
2324}
4546 /// Create a new local tuner.
4647 pub const fn new ( name : & ' static str ) -> Self {
4748 Self {
48- state : spin :: Mutex :: new ( None ) ,
49+ state : SharedStateMap :: new ( ) ,
4950 name,
5051 sets : spin:: RwLock :: new ( None ) ,
5152 }
9495
9596 /// Clear the autotune state.
9697 pub fn clear ( & self ) {
97- let mut state = self . state . lock ( ) ;
98- * state = None ;
98+ self . state . clear ( )
9999 }
100100
101101 #[ cfg( feature = "autotune-checks" ) ]
@@ -131,20 +131,16 @@ where
131131
132132 // If this is cached and ready, use the operation.
133133 let autotune_job = {
134- let mut state = self . state . lock ( ) ;
135- let map = state. get_or_insert_with ( Default :: default) ;
136-
137- let tuner = match map. get_mut ( id) {
138- Some ( val) => val,
139- None => map. entry ( id. clone ( ) ) . or_insert_with ( move || {
140- let name = self . name . replace ( "::" , "-" ) ;
141- Tuner :: new ( & name, & id. to_string ( ) )
142- } ) ,
143- } ;
134+ let tuner_state = self . state . get_or_init ( id, move |id| {
135+ let name = self . name . replace ( "::" , "-" ) ;
136+ Tuner :: new ( & name, & id. to_string ( ) )
137+ } ) ;
138+ let tuner = tuner_state. read ( ) ;
144139
145- match tuner. fastest ( & key) {
140+ let mut tuner = match tuner. fastest ( & key) {
146141 TuneCacheResult :: Hit { fastest_index } => {
147- core:: mem:: drop ( state) ;
142+ core:: mem:: drop ( tuner) ;
143+ core:: mem:: drop ( tuner_state) ;
148144
149145 #[ cfg( feature = "autotune-checks" ) ]
150146 self . checks ( & operations, & inputs) ;
@@ -156,7 +152,8 @@ where
156152 return result;
157153 }
158154 TuneCacheResult :: Pending => {
159- core:: mem:: drop ( state) ;
155+ core:: mem:: drop ( tuner) ;
156+ core:: mem:: drop ( tuner_state) ;
160157
161158 let op = operations. fastest ( 0 ) ;
162159 let result = op
@@ -166,26 +163,38 @@ where
166163 }
167164 #[ cfg( std_io) ]
168165 TuneCacheResult :: Unchecked => {
166+ core:: mem:: drop ( tuner) ;
167+ let mut tuner = tuner_state. write ( ) ;
168+
169169 // If the cache checksum hasn't been checked, do so now, and retry.
170170 let checksum = operations. compute_checksum ( ) ;
171171 tuner. validate_checksum ( & key, & checksum) ;
172172
173173 // Check whether, after validation, the cached result can be used.
174174 if let TuneCacheResult :: Hit { fastest_index } = tuner. fastest ( & key) {
175- core:: mem:: drop ( state) ;
175+ core:: mem:: drop ( tuner) ;
176+ core:: mem:: drop ( tuner_state) ;
176177
177178 let op = operations. fastest ( fastest_index) ;
178179 let result = op
179180 . execute ( inputs)
180181 . expect ( "Should run when selected by autotune." ) ;
181182 return result;
182183 }
184+
185+ tuner
183186 }
184187
185188 #[ cfg( not( std_io) ) ]
186- TuneCacheResult :: Unchecked => ( ) ,
187- TuneCacheResult :: Miss => ( ) ,
188- }
189+ TuneCacheResult :: Unchecked => {
190+ core:: mem:: drop ( tuner) ;
191+ tuner_state. write ( )
192+ }
193+ TuneCacheResult :: Miss => {
194+ core:: mem:: drop ( tuner) ;
195+ tuner_state. write ( )
196+ }
197+ } ;
189198
190199 if tuner. autotuning . contains ( & key) {
191200 Box :: new ( move || { } )
@@ -198,9 +207,8 @@ where
198207 autotune_job ( ) ;
199208
200209 let index_to_run = {
201- let mut state = self . state . lock ( ) ;
202- let map = state. as_mut ( ) . unwrap ( ) ;
203- let tuner = map. get_mut ( id) . unwrap ( ) ;
210+ let tuner_state = self . state . get ( id) . unwrap ( ) ;
211+ let mut tuner = tuner_state. write ( ) ;
204212
205213 tuner. handle_results ( ) ;
206214
0 commit comments