1717from triton ._C .libtriton import get_cache_invalidating_env_vars
1818
1919
20+ class AutotunerThreadState :
21+ """
22+ Per thread (thread-local) auto-tuner cache/configs
23+ """
24+
25+ __slots__ = ("cache" , "configs_timings" , "bench_time" )
26+
27+ def __init__ (self ):
28+ self .cache : Dict [Tuple , Config ] = {}
29+ self .configs_timings : Dict [Config , List [float ]] | None = None
30+ self .bench_time : float | None = None
31+
32+
2033class Autotuner (KernelInterface ):
2134
2235 def __init__ (self , fn , arg_names , configs , key , reset_to_zero , restore_value , pre_hook = None , post_hook = None ,
@@ -35,9 +48,7 @@ def __init__(self, fn, arg_names, configs, key, reset_to_zero, restore_value, pr
3548 else :
3649 self .configs = configs
3750 self .keys = key
38- self .cache : Dict [Tuple , Config ] = {}
39- self ._tuning_locks : Dict [Tuple , threading .Lock ] = {}
40- self ._tuning_locks_guard = threading .Lock ()
51+ self ._thread_state = threading .local ()
4152 self .arg_names = arg_names
4253 self .cache_results = cache_results or (knobs .autotuning .cache and not knobs .runtime .interpret )
4354
@@ -128,6 +139,13 @@ def do_bench(self):
128139 return driver .active .get_benchmarker ()
129140 return self ._do_bench
130141
142+ def _get_thread_state (self ) -> AutotunerThreadState :
143+ state = getattr (self ._thread_state , "value" , None )
144+ if state is None :
145+ state = AutotunerThreadState ()
146+ self ._thread_state .value = state
147+ return state
148+
131149 def _bench (self , nargs , * args , config , ** meta ):
132150 from ..compiler .errors import CompileTimeAssertionFailure
133151
@@ -170,7 +188,7 @@ def kernel_call():
170188 print (f"Autotuning failed with { e } " )
171189 return [float ("inf" ), float ("inf" ), float ("inf" )]
172190
173- def check_disk_cache (self , tuning_key , configs , bench_fn ):
191+ def check_disk_cache (self , tuning_key , configs , bench_fn , state : AutotunerThreadState ):
174192 # We can't serialize prehooks, so just give up and run the benchmarks.
175193 if not tuning_key or any (cfg .pre_hook for cfg in configs ):
176194 bench_fn ()
@@ -198,8 +216,8 @@ def check_disk_cache(self, tuning_key, configs, bench_fn):
198216 with open (path , "r" ) as cached_configs :
199217 timings = json .load (cached_configs )["configs_timings" ]
200218 timings = {Config (** config ): timing for config , timing in timings }
201- self .cache [tuning_key ] = builtins .min (timings , key = timings .get )
202- self .configs_timings = timings
219+ state .cache [tuning_key ] = builtins .min (timings , key = timings .get )
220+ state .configs_timings = timings
203221 return True
204222
205223 bench_fn ()
@@ -208,19 +226,17 @@ def check_disk_cache(self, tuning_key, configs, bench_fn):
208226 "key" :
209227 tuning_key ,
210228 "configs_timings" :
211- [(config .__dict__ , timings ) for config , timings in self .configs_timings .items () if not config .pre_hook ],
229+ [
230+ (config .__dict__ , timings )
231+ for config , timings in (state .configs_timings or {}).items ()
232+ if not config .pre_hook
233+ ],
212234 }), file_name , binary = False )
213235 return False
214236
215- def _get_tuning_lock (self , key : Tuple ):
216- with self ._tuning_locks_guard :
217- lock = self ._tuning_locks .get (key )
218- if lock is None :
219- lock = threading .Lock ()
220- self ._tuning_locks [key ] = lock
221- return lock
222-
223237 def run (self , * args , ** kwargs ):
238+ state = self ._get_thread_state ()
239+ cache = state .cache
224240 nargs = dict (zip (self .arg_names , args ))
225241 used_cached_result = True
226242 key = None
@@ -232,39 +248,37 @@ def run(self, *args, **kwargs):
232248 if hasattr (arg , "dtype" ):
233249 key_values .append (str (arg .dtype ))
234250 key = tuple (key_values )
235- if key not in self .cache :
236- lock = self ._get_tuning_lock (key )
237- with lock :
238- if key not in self .cache :
239- used_cached_result = False
240- pruned_configs = self .prune_configs (kwargs , nargs )
241-
242- def benchmark ():
243- bench_start = time .time ()
244- timings = {
245- config : self ._bench (nargs , * args , config = config , ** kwargs )
246- for config in pruned_configs
247- }
248- bench_end = time .time ()
249- self .bench_time = bench_end - bench_start
250- best_config = builtins .min (timings , key = timings .get )
251- self .cache [key ] = best_config
252- full_nargs = {** nargs , ** kwargs , ** best_config .all_kwargs ()}
253- self .pre_hook (full_nargs , reset_only = True )
254- self .configs_timings = timings
255-
256- if self .cache_results :
257- used_cached_result = self .check_disk_cache (key , pruned_configs , benchmark )
258- else :
259- benchmark ()
260-
261- config = self .cache [key ]
251+ if key not in cache :
252+ used_cached_result = False
253+ pruned_configs = self .prune_configs (kwargs , nargs )
254+
255+ def benchmark ():
256+ bench_start = time .time ()
257+ timings = {
258+ config : self ._bench (nargs , * args , config = config , ** kwargs )
259+ for config in pruned_configs
260+ }
261+ bench_end = time .time ()
262+ state .bench_time = bench_end - bench_start
263+ best_config = builtins .min (timings , key = timings .get )
264+ cache [key ] = best_config
265+ full_nargs_local = {** nargs , ** kwargs , ** best_config .all_kwargs ()}
266+ self .pre_hook (full_nargs_local , reset_only = True )
267+ state .configs_timings = timings
268+
269+ if self .cache_results :
270+ used_cached_result = self .check_disk_cache (key , pruned_configs , benchmark , state )
271+ else :
272+ benchmark ()
273+
274+ config = cache [key ]
262275 else :
263276 config = self .configs [0 ]
264277 self .best_config = config
265278 if knobs .autotuning .print and key is not None and not used_cached_result :
279+ bench_time = state .bench_time or 0.0
266280 print (f"Triton autotuning for function { self .base_fn .__name__ } ,\n with key as { key } ,\n "
267- f"finished after { self . bench_time :.2f} s,\n best config selected: { self .best_config } ;" )
281+ f"finished after { bench_time :.2f} s,\n best config selected: { self .best_config } ;" )
268282 full_nargs = {** nargs , ** kwargs , ** config .all_kwargs ()}
269283 if config .pre_hook is not None :
270284 config .pre_hook (full_nargs )
0 commit comments