1717from triton ._C .libtriton import get_cache_invalidating_env_vars
1818
1919
20- class AutotunerThreadState :
21- """Per-thread autotune cache and metadata."""
20+ class CacheFuture :
2221
2322 def __init__ (self ):
24- self .cache : Dict [Tuple , Config ] = {}
25- self .configs_timings : Dict [Config , List [float ]] | None = None
23+ self .event = threading .Event ()
24+ self .config : Config | None = None
25+ self .error : BaseException | None = None
26+ self .used_cached_result : bool = True
2627 self .bench_time : float | None = None
2728
2829
@@ -44,7 +45,9 @@ def __init__(self, fn, arg_names, configs, key, reset_to_zero, restore_value, pr
4445 else :
4546 self .configs = configs
4647 self .keys = key
47- self ._thread_state = threading .local ()
48+ self ._cache : Dict [Tuple , Config ] = {}
49+ self ._cache_lock = threading .RLock ()
50+ self ._cache_futures : Dict [Tuple , CacheFuture ] = {}
4851 self .arg_names = arg_names
4952 self .cache_results = cache_results or (knobs .autotuning .cache and not knobs .runtime .interpret )
5053
@@ -135,13 +138,6 @@ def do_bench(self):
135138 return driver .active .get_benchmarker ()
136139 return self ._do_bench
137140
138- def _get_thread_state (self ) -> AutotunerThreadState :
139- state = getattr (self ._thread_state , "value" , None )
140- if state is None :
141- state = AutotunerThreadState ()
142- self ._thread_state .value = state
143- return state
144-
145141 def _bench (self , nargs , * args , config , ** meta ):
146142 from ..compiler .errors import CompileTimeAssertionFailure
147143
@@ -184,11 +180,11 @@ def kernel_call():
184180 print (f"Autotuning failed with { e } " )
185181 return [float ("inf" ), float ("inf" ), float ("inf" )]
186182
187- def check_disk_cache (self , tuning_key , configs , bench_fn , state : AutotunerThreadState ):
183+ def check_disk_cache (self , tuning_key , configs , bench_fn ):
188184 # We can't serialize prehooks, so just give up and run the benchmarks.
189185 if not tuning_key or any (cfg .pre_hook for cfg in configs ):
190- bench_fn ()
191- return False
186+ configs_timings , bench_time , best_config = bench_fn ()
187+ return False , bench_time , configs_timings , best_config
192188
193189 from triton .compiler .compiler import make_backend
194190
@@ -212,26 +208,82 @@ def check_disk_cache(self, tuning_key, configs, bench_fn, state: AutotunerThread
212208 with open (path , "r" ) as cached_configs :
213209 timings = json .load (cached_configs )["configs_timings" ]
214210 timings = {Config (** config ): timing for config , timing in timings }
215- state .cache [tuning_key ] = builtins .min (timings , key = timings .get )
216- state .configs_timings = timings
217- return True
211+ best_config = builtins .min (timings , key = timings .get )
212+ return True , None , timings , best_config
218213
219- bench_fn ()
214+ configs_timings , bench_time , best_config = bench_fn ()
220215 cache .put (
221216 json .dumps ({
222217 "key" :
223218 tuning_key ,
224219 "configs_timings" : [(config .__dict__ , timings )
225- for config , timings in (state . configs_timings or {}).items ()
220+ for config , timings in (configs_timings or {}).items ()
226221 if not config .pre_hook ],
227222 }), file_name , binary = False )
228- return False
223+ return False , bench_time , configs_timings , best_config
224+
225+ def _get_config_for_key (self , key , nargs , args , kwargs ):
226+ with self ._cache_lock :
227+ cached = self ._cache .get (key )
228+ if cached is not None :
229+ return cached , True , None
230+
231+ future = self ._cache_futures .get (key )
232+ if future is None :
233+ future = CacheFuture ()
234+ self ._cache_futures [key ] = future
235+ runner = True
236+ else :
237+ runner = False
238+
239+ if not runner :
240+ future .event .wait ()
241+ if future .error is not None :
242+ raise future .error
243+ return future .config , future .used_cached_result , future .bench_time
244+
245+ pruned_configs = self .prune_configs (kwargs , nargs )
246+
247+ def benchmark ():
248+ bench_start = time .time ()
249+ timings = {config : self ._bench (nargs , * args , config = config , ** kwargs ) for config in pruned_configs }
250+ bench_duration = time .time () - bench_start
251+ best_config = builtins .min (timings , key = timings .get )
252+ full_nargs_local = {** nargs , ** kwargs , ** best_config .all_kwargs ()}
253+ self .pre_hook (full_nargs_local , reset_only = True )
254+ return timings , bench_duration , best_config
255+
256+ used_cached_result = False
257+ bench_time = None
258+
259+ try :
260+ if self .cache_results :
261+ used_cached_result , bench_time , configs_timings , best_config = self .check_disk_cache (
262+ key , pruned_configs , benchmark )
263+ else :
264+ configs_timings , bench_time , best_config = benchmark ()
265+ used_cached_result = False
266+
267+ if best_config is not None :
268+ with self ._cache_lock :
269+ self ._cache [key ] = best_config
270+
271+ future .config = best_config
272+ future .used_cached_result = used_cached_result
273+ future .bench_time = bench_time
274+ return best_config , used_cached_result , bench_time
275+ except BaseException as exc :
276+ future .error = exc
277+ raise
278+ finally :
279+ future .event .set ()
280+ with self ._cache_lock :
281+ self ._cache_futures .pop (key , None )
229282
230283 def run (self , * args , ** kwargs ):
231- state = self ._get_thread_state ()
232- cache = state .cache
233284 nargs = dict (zip (self .arg_names , args ))
234285 used_cached_result = True
286+ bench_time = None
235287 key = None
236288 if len (self .configs ) > 1 :
237289 all_args = {** nargs , ** kwargs }
@@ -241,34 +293,14 @@ def run(self, *args, **kwargs):
241293 if hasattr (arg , "dtype" ):
242294 key_values .append (str (arg .dtype ))
243295 key = tuple (key_values )
244- if key not in cache :
245- used_cached_result = False
246- pruned_configs = self .prune_configs (kwargs , nargs )
247-
248- def benchmark ():
249- bench_start = time .time ()
250- timings = {config : self ._bench (nargs , * args , config = config , ** kwargs ) for config in pruned_configs }
251- bench_end = time .time ()
252- state .bench_time = bench_end - bench_start
253- best_config = builtins .min (timings , key = timings .get )
254- cache [key ] = best_config
255- full_nargs_local = {** nargs , ** kwargs , ** best_config .all_kwargs ()}
256- self .pre_hook (full_nargs_local , reset_only = True )
257- state .configs_timings = timings
258-
259- if self .cache_results :
260- used_cached_result = self .check_disk_cache (key , pruned_configs , benchmark , state )
261- else :
262- benchmark ()
263-
264- config = cache [key ]
296+ config , used_cached_result , bench_time = self ._get_config_for_key (key , nargs , args , kwargs )
265297 else :
266298 config = self .configs [0 ]
267299 self .best_config = config
268300 if knobs .autotuning .print and key is not None and not used_cached_result :
269- bench_time = state . bench_time or 0.0
301+ bench_time_value = bench_time or 0.0
270302 print (f"Triton autotuning for function { self .base_fn .__name__ } ,\n with key as { key } ,\n "
271- f"finished after { bench_time :.2f} s,\n best config selected: { self .best_config } ;" )
303+ f"finished after { bench_time_value :.2f} s,\n best config selected: { self .best_config } ;" )
272304 full_nargs = {** nargs , ** kwargs , ** config .all_kwargs ()}
273305 if config .pre_hook is not None :
274306 config .pre_hook (full_nargs )
0 commit comments