1717from triton ._C .libtriton import get_cache_invalidating_env_vars
1818
1919
20- class AutotunerThreadState :
21- """Per-thread autotune cache and metadata."""
22-
20+ class CacheFuture :
2321 def __init__ (self ):
24- self .cache : Dict [Tuple , Config ] = {}
25- self .configs_timings : Dict [Config , List [float ]] | None = None
22+ self .event = threading .Event ()
23+ self .config : Config | None = None
24+ self .error : BaseException | None = None
25+ self .used_cached_result : bool = True
2626 self .bench_time : float | None = None
2727
2828
@@ -44,7 +44,9 @@ def __init__(self, fn, arg_names, configs, key, reset_to_zero, restore_value, pr
4444 else :
4545 self .configs = configs
4646 self .keys = key
47- self ._thread_state = threading .local ()
47+ self ._cache : Dict [Tuple , Config ] = {}
48+ self ._cache_lock = threading .RLock ()
49+ self ._cache_futures : Dict [Tuple , CacheFuture ] = {}
4850 self .arg_names = arg_names
4951 self .cache_results = cache_results or (knobs .autotuning .cache and not knobs .runtime .interpret )
5052
@@ -135,13 +137,6 @@ def do_bench(self):
135137 return driver .active .get_benchmarker ()
136138 return self ._do_bench
137139
138- def _get_thread_state (self ) -> AutotunerThreadState :
139- state = getattr (self ._thread_state , "value" , None )
140- if state is None :
141- state = AutotunerThreadState ()
142- self ._thread_state .value = state
143- return state
144-
145140 def _bench (self , nargs , * args , config , ** meta ):
146141 from ..compiler .errors import CompileTimeAssertionFailure
147142
@@ -184,11 +179,11 @@ def kernel_call():
184179 print (f"Autotuning failed with { e } " )
185180 return [float ("inf" ), float ("inf" ), float ("inf" )]
186181
187- def check_disk_cache (self , tuning_key , configs , bench_fn , state : AutotunerThreadState ):
182+ def check_disk_cache (self , tuning_key , configs , bench_fn ):
188183 # We can't serialize prehooks, so just give up and run the benchmarks.
189184 if not tuning_key or any (cfg .pre_hook for cfg in configs ):
190- bench_fn ()
191- return False
185+ configs_timings , bench_time , best_config = bench_fn ()
186+ return False , bench_time , configs_timings , best_config
192187
193188 from triton .compiler .compiler import make_backend
194189
@@ -212,26 +207,82 @@ def check_disk_cache(self, tuning_key, configs, bench_fn, state: AutotunerThread
212207 with open (path , "r" ) as cached_configs :
213208 timings = json .load (cached_configs )["configs_timings" ]
214209 timings = {Config (** config ): timing for config , timing in timings }
215- state .cache [tuning_key ] = builtins .min (timings , key = timings .get )
216- state .configs_timings = timings
217- return True
210+ best_config = builtins .min (timings , key = timings .get )
211+ return True , None , timings , best_config
218212
219- bench_fn ()
213+ configs_timings , bench_time , best_config = bench_fn ()
220214 cache .put (
221215 json .dumps ({
222216 "key" :
223217 tuning_key ,
224218 "configs_timings" : [(config .__dict__ , timings )
225- for config , timings in (state . configs_timings or {}).items ()
219+ for config , timings in (configs_timings or {}).items ()
226220 if not config .pre_hook ],
227221 }), file_name , binary = False )
228- return False
222+ return False , bench_time , configs_timings , best_config
223+
224+ def _get_config_for_key (self , key , nargs , args , kwargs ):
225+ with self ._cache_lock :
226+ cached = self ._cache .get (key )
227+ if cached is not None :
228+ return cached , True , None
229+
230+ future = self ._cache_futures .get (key )
231+ if future is None :
232+ future = CacheFuture ()
233+ self ._cache_futures [key ] = future
234+ runner = True
235+ else :
236+ runner = False
237+
238+ if not runner :
239+ future .event .wait ()
240+ if future .error is not None :
241+ raise future .error
242+ return future .config , future .used_cached_result , future .bench_time
243+
244+ pruned_configs = self .prune_configs (kwargs , nargs )
245+
246+ def benchmark ():
247+ bench_start = time .time ()
248+ timings = {config : self ._bench (nargs , * args , config = config , ** kwargs ) for config in pruned_configs }
249+ bench_duration = time .time () - bench_start
250+ best_config = builtins .min (timings , key = timings .get )
251+ full_nargs_local = {** nargs , ** kwargs , ** best_config .all_kwargs ()}
252+ self .pre_hook (full_nargs_local , reset_only = True )
253+ return timings , bench_duration , best_config
254+
255+ used_cached_result = False
256+ bench_time = None
257+
258+ try :
259+ if self .cache_results :
260+ used_cached_result , bench_time , configs_timings , best_config = self .check_disk_cache (
261+ key , pruned_configs , benchmark )
262+ else :
263+ configs_timings , bench_time , best_config = benchmark ()
264+ used_cached_result = False
265+
266+ if best_config is not None :
267+ with self ._cache_lock :
268+ self ._cache [key ] = best_config
269+
270+ future .config = best_config
271+ future .used_cached_result = used_cached_result
272+ future .bench_time = bench_time
273+ return best_config , used_cached_result , bench_time
274+ except BaseException as exc :
275+ future .error = exc
276+ raise
277+ finally :
278+ future .event .set ()
279+ with self ._cache_lock :
280+ self ._cache_futures .pop (key , None )
229281
230282 def run (self , * args , ** kwargs ):
231- state = self ._get_thread_state ()
232- cache = state .cache
233283 nargs = dict (zip (self .arg_names , args ))
234284 used_cached_result = True
285+ bench_time = None
235286 key = None
236287 if len (self .configs ) > 1 :
237288 all_args = {** nargs , ** kwargs }
@@ -241,34 +292,14 @@ def run(self, *args, **kwargs):
241292 if hasattr (arg , "dtype" ):
242293 key_values .append (str (arg .dtype ))
243294 key = tuple (key_values )
244- if key not in cache :
245- used_cached_result = False
246- pruned_configs = self .prune_configs (kwargs , nargs )
247-
248- def benchmark ():
249- bench_start = time .time ()
250- timings = {config : self ._bench (nargs , * args , config = config , ** kwargs ) for config in pruned_configs }
251- bench_end = time .time ()
252- state .bench_time = bench_end - bench_start
253- best_config = builtins .min (timings , key = timings .get )
254- cache [key ] = best_config
255- full_nargs_local = {** nargs , ** kwargs , ** best_config .all_kwargs ()}
256- self .pre_hook (full_nargs_local , reset_only = True )
257- state .configs_timings = timings
258-
259- if self .cache_results :
260- used_cached_result = self .check_disk_cache (key , pruned_configs , benchmark , state )
261- else :
262- benchmark ()
263-
264- config = cache [key ]
295+ config , used_cached_result , bench_time = self ._get_config_for_key (key , nargs , args , kwargs )
265296 else :
266297 config = self .configs [0 ]
267298 self .best_config = config
268299 if knobs .autotuning .print and key is not None and not used_cached_result :
269- bench_time = state . bench_time or 0.0
300+ bench_time_value = bench_time or 0.0
270301 print (f"Triton autotuning for function { self .base_fn .__name__ } ,\n with key as { key } ,\n "
271- f"finished after { bench_time :.2f} s,\n best config selected: { self .best_config } ;" )
302+ f"finished after { bench_time_value :.2f} s,\n best config selected: { self .best_config } ;" )
272303 full_nargs = {** nargs , ** kwargs , ** config .all_kwargs ()}
273304 if config .pre_hook is not None :
274305 config .pre_hook (full_nargs )
0 commit comments