1717from triton ._C .libtriton import get_cache_invalidating_env_vars
1818
1919
class AutotunerThreadState:
    """Autotune cache and benchmark metadata held separately for each thread.

    Attributes:
        cache: Maps a tuning key (tuple) to the ``Config`` chosen for it.
        configs_timings: Timings per ``Config`` from the latest tuning pass,
            or ``None`` while nothing has been recorded yet.
        bench_time: Duration recorded for the latest benchmark pass, or
            ``None`` while nothing has been recorded yet.
    """

    def __init__(self):
        # Metadata starts out unset; the cache starts out as a fresh, empty
        # dict so that no two instances ever share the same mapping.
        self.bench_time: float | None = None
        self.configs_timings: Dict[Config, List[float]] | None = None
        self.cache: Dict[Tuple, Config] = {}
27-
28-
2920class Autotuner (KernelInterface ):
3021
3122 def __init__ (self , fn , arg_names , configs , key , reset_to_zero , restore_value , pre_hook = None , post_hook = None ,
@@ -44,7 +35,11 @@ def __init__(self, fn, arg_names, configs, key, reset_to_zero, restore_value, pr
4435 else :
4536 self .configs = configs
4637 self .keys = key
47- self ._thread_state = threading .local ()
38+ self .cache : Dict [Tuple , Config ] = {}
39+ self .configs_timings : Dict [Config , List [float ]] = {}
40+ self .bench_time : Optional [float ] = None
41+ self ._tuning_locks : Dict [Tuple , threading .Lock ] = {}
42+ self ._tuning_locks_guard = threading .Lock ()
4843 self .arg_names = arg_names
4944 self .cache_results = cache_results or (knobs .autotuning .cache and not knobs .runtime .interpret )
5045
@@ -135,13 +130,6 @@ def do_bench(self):
135130 return driver .active .get_benchmarker ()
136131 return self ._do_bench
137132
def _get_thread_state(self) -> AutotunerThreadState:
    """Return the autotune state stored on ``self._thread_state``.

    The state lives in the ``value`` attribute of ``self._thread_state``.
    When that attribute is missing (or holds ``None``), a fresh
    ``AutotunerThreadState`` is created, stored there, and returned.
    """
    slot = self._thread_state
    try:
        existing = slot.value
    except AttributeError:
        # Nothing has been stored on this slot yet.
        existing = None
    if existing is not None:
        return existing
    created = slot.value = AutotunerThreadState()
    return created
144-
145133 def _bench (self , nargs , * args , config , ** meta ):
146134 from ..compiler .errors import CompileTimeAssertionFailure
147135
@@ -184,7 +172,7 @@ def kernel_call():
184172 print (f"Autotuning failed with { e } " )
185173 return [float ("inf" ), float ("inf" ), float ("inf" )]
186174
187- def check_disk_cache (self , tuning_key , configs , bench_fn , state : AutotunerThreadState ):
175+ def check_disk_cache (self , tuning_key , configs , bench_fn ):
188176 # We can't serialize prehooks, so just give up and run the benchmarks.
189177 if not tuning_key or any (cfg .pre_hook for cfg in configs ):
190178 bench_fn ()
@@ -212,24 +200,32 @@ def check_disk_cache(self, tuning_key, configs, bench_fn, state: AutotunerThread
212200 with open (path , "r" ) as cached_configs :
213201 timings = json .load (cached_configs )["configs_timings" ]
214202 timings = {Config (** config ): timing for config , timing in timings }
215- state .cache [tuning_key ] = builtins .min (timings , key = timings .get )
216- state .configs_timings = timings
203+ self .cache [tuning_key ] = builtins .min (timings , key = timings .get )
204+ self .configs_timings = timings
217205 return True
218206
219207 bench_fn ()
220208 cache .put (
221209 json .dumps ({
222210 "key" :
223211 tuning_key ,
224- "configs_timings" : [(config .__dict__ , timings )
225- for config , timings in (state .configs_timings or {}).items ()
226- if not config .pre_hook ],
227- }), file_name , binary = False )
212+ "configs_timings" :
213+ [(config .__dict__ , timings ) for config , timings in self .configs_timings .items () if not config .pre_hook ],
214+ }),
215+ file_name ,
216+ binary = False ,
217+ )
228218 return False
229219
def _get_tuning_lock(self, key: Tuple):
    """Return the lock associated with ``key``, creating it on first request.

    Locks are kept in ``self._tuning_locks``; every lookup and insertion
    happens while ``self._tuning_locks_guard`` is held, so repeated calls
    with the same key hand back the same lock object.
    """
    with self._tuning_locks_guard:
        registry = self._tuning_locks
        try:
            return registry[key]
        except KeyError:
            # First request for this key: register a new lock before
            # releasing the guard.
            fresh = registry[key] = threading.Lock()
            return fresh
227+
230228 def run (self , * args , ** kwargs ):
231- state = self ._get_thread_state ()
232- cache = state .cache
233229 nargs = dict (zip (self .arg_names , args ))
234230 used_cached_result = True
235231 key = None
@@ -241,32 +237,38 @@ def run(self, *args, **kwargs):
241237 if hasattr (arg , "dtype" ):
242238 key_values .append (str (arg .dtype ))
243239 key = tuple (key_values )
244- if key not in cache :
245- used_cached_result = False
246- pruned_configs = self .prune_configs (kwargs , nargs )
247-
248- def benchmark ():
249- bench_start = time .time ()
250- timings = {config : self ._bench (nargs , * args , config = config , ** kwargs ) for config in pruned_configs }
251- bench_end = time .time ()
252- state .bench_time = bench_end - bench_start
253- best_config = builtins .min (timings , key = timings .get )
254- cache [key ] = best_config
255- full_nargs_local = {** nargs , ** kwargs , ** best_config .all_kwargs ()}
256- self .pre_hook (full_nargs_local , reset_only = True )
257- state .configs_timings = timings
258-
259- if self .cache_results :
260- used_cached_result = self .check_disk_cache (key , pruned_configs , benchmark , state )
261- else :
262- benchmark ()
263-
264- config = cache [key ]
240+ if key not in self .cache :
241+ lock = self ._get_tuning_lock (key )
242+ with lock :
243+ if key not in self .cache :
244+ used_cached_result = False
245+ pruned_configs = self .prune_configs (kwargs , nargs )
246+
247+ def benchmark ():
248+ bench_start = time .time ()
249+ timings = {
250+ config : self ._bench (nargs , * args , config = config , ** kwargs )
251+ for config in pruned_configs
252+ }
253+ bench_end = time .time ()
254+ self .bench_time = bench_end - bench_start
255+ best_config = builtins .min (timings , key = timings .get )
256+ self .cache [key ] = best_config
257+ full_nargs_local = {** nargs , ** kwargs , ** best_config .all_kwargs ()}
258+ self .pre_hook (full_nargs_local , reset_only = True )
259+ self .configs_timings = timings
260+
261+ if self .cache_results :
262+ used_cached_result = self .check_disk_cache (key , pruned_configs , benchmark )
263+ else :
264+ benchmark ()
265+
266+ config = self .cache [key ]
265267 else :
266268 config = self .configs [0 ]
267269 self .best_config = config
268270 if knobs .autotuning .print and key is not None and not used_cached_result :
269- bench_time = state .bench_time or 0.0
271+ bench_time = self .bench_time or 0.0
270272 print (f"Triton autotuning for function { self .base_fn .__name__ } ,\n with key as { key } ,\n "
271273 f"finished after { bench_time :.2f} s,\n best config selected: { self .best_config } ;" )
272274 full_nargs = {** nargs , ** kwargs , ** config .all_kwargs ()}
0 commit comments