55import inspect
66import hashlib
77import json
8+ import threading
89from functools import cached_property
910from typing import Dict , Tuple , List , Optional
1011
@@ -35,6 +36,8 @@ def __init__(self, fn, arg_names, configs, key, reset_to_zero, restore_value, pr
3536 self .configs = configs
3637 self .keys = key
3738 self .cache : Dict [Tuple , Config ] = {}
39+ self ._tuning_locks : Dict [Tuple , threading .Lock ] = {}
40+ self ._tuning_locks_guard = threading .Lock ()
3841 self .arg_names = arg_names
3942 self .cache_results = cache_results or (knobs .autotuning .cache and not knobs .runtime .interpret )
4043
@@ -125,7 +128,7 @@ def do_bench(self):
125128 return driver .active .get_benchmarker ()
126129 return self ._do_bench
127130
128- def _bench (self , * args , config , ** meta ):
131+ def _bench (self , nargs , * args , config , ** meta ):
129132 from ..compiler .errors import CompileTimeAssertionFailure
130133
131134 verbose = knobs .autotuning .print
@@ -140,7 +143,7 @@ def _bench(self, *args, config, **meta):
140143 " Make sure that you don't re-define auto-tuned symbols." )
141144 # augment meta-parameters with tunable ones
142145 current = dict (meta , ** config .all_kwargs ())
143- full_nargs = {** self . nargs , ** current }
146+ full_nargs = {** nargs , ** current }
144147
145148 def kernel_call ():
146149 if config .pre_hook :
@@ -209,58 +212,73 @@ def check_disk_cache(self, tuning_key, configs, bench_fn):
209212 }), file_name , binary = False )
210213 return False
211214
def _get_tuning_lock(self, key: Tuple) -> threading.Lock:
    """Return the lock associated with tuning key ``key``, creating it on first use.

    Args:
        key: Tuning key tuple, as built in ``run``.

    Returns:
        The ``threading.Lock`` stored in ``self._tuning_locks`` for ``key``.
        Repeated calls with an equal key return the same lock object.
    """
    # The guard is held for both the lookup and the insert, so two callers
    # asking for the same new key cannot each register a different lock.
    with self._tuning_locks_guard:
        lock = self._tuning_locks.get(key)
        if lock is None:
            lock = self._tuning_locks[key] = threading.Lock()
        return lock
222+
def run(self, *args, **kwargs):
    """Launch ``self.fn`` with a config chosen for these arguments, tuning first if needed.

    When more than one config is registered, a tuning key is built from the
    arguments named in ``self.keys`` plus ``str(arg.dtype)`` for every named
    argument that has a ``dtype`` attribute. If that key is not yet in
    ``self.cache``, the pruned configs are benchmarked (or, when
    ``self.cache_results`` is set, ``check_disk_cache`` is consulted with the
    benchmark as its callback) and the fastest config is stored under the key.
    With a single config, that config is used directly and no key is built.

    Args:
        *args: Positional kernel arguments; paired with ``self.arg_names``.
        **kwargs: Keyword kernel arguments / meta-parameters.

    Returns:
        Whatever ``self.fn.run`` returns for the selected config.

    Side effects:
        Sets ``self.best_config``; after a benchmark also sets
        ``self.bench_time`` and ``self.configs_timings`` and fills
        ``self.cache[key]``.
    """
    # Positional-argument mapping is kept in a local and passed down to the
    # helpers explicitly rather than being stored on ``self``.
    nargs = dict(zip(self.arg_names, args))
    used_cached_result = True
    key = None
    if len(self.configs) > 1:
        all_args = {**nargs, **kwargs}
        _args = {k: v for (k, v) in all_args.items() if k in self.arg_names}
        key_values = [_args[key_name] for key_name in self.keys if key_name in _args]
        key_values.extend(str(arg.dtype) for arg in _args.values() if hasattr(arg, "dtype"))
        key = tuple(key_values)
        if key not in self.cache:
            # Callers that miss on the same key queue up on that key's lock.
            with self._get_tuning_lock(key):
                # Re-check after acquiring: another caller may have filled
                # the cache entry while this one was waiting.
                if key not in self.cache:
                    used_cached_result = False
                    pruned_configs = self.prune_configs(kwargs, nargs)

                    def benchmark():
                        bench_start = time.time()
                        timings = {
                            config: self._bench(nargs, *args, config=config, **kwargs)
                            for config in pruned_configs
                        }
                        bench_end = time.time()
                        self.bench_time = bench_end - bench_start
                        best_config = builtins.min(timings, key=timings.get)
                        self.cache[key] = best_config
                        full_nargs = {**nargs, **kwargs, **best_config.all_kwargs()}
                        self.pre_hook(full_nargs, reset_only=True)
                        self.configs_timings = timings

                    if self.cache_results:
                        used_cached_result = self.check_disk_cache(key, pruned_configs, benchmark)
                    else:
                        benchmark()

        config = self.cache[key]
    else:
        config = self.configs[0]
    self.best_config = config
    if knobs.autotuning.print and key is not None and not used_cached_result:
        print(f"Triton autotuning for function {self.base_fn.__name__},\nwith key as {key},\n"
              f"finished after {self.bench_time:.2f}s,\nbest config selected: {self.best_config};")
    if config.pre_hook is not None:
        # Only build the merged argument dict when a hook will consume it;
        # this method runs on every launch.
        full_nargs = {**nargs, **kwargs, **config.all_kwargs()}
        config.pre_hook(full_nargs)
    return self.fn.run(
        *args,
        **kwargs,
        **config.all_kwargs(),
    )
259277
260- def prune_configs (self , kwargs : Dict ) -> List [Config ]:
278+ def prune_configs (self , kwargs : Dict , nargs : Dict ) -> List [Config ]:
261279 pruned_configs = self .configs
262280 if self .early_config_prune :
263- pruned_configs = self .early_config_prune (self .configs , self . nargs , ** kwargs )
281+ pruned_configs = self .early_config_prune (self .configs , nargs , ** kwargs )
264282 if not pruned_configs :
265283 raise AutotunerError (
266284 "No valid autotuner configs after pruning. `early_config_prune` should return at least one config." )
@@ -275,7 +293,7 @@ def prune_configs(self, kwargs: Dict) -> List[Config]:
275293 if len (pruned_configs ) > top_k :
276294 est_timing = {
277295 config : self .perf_model (
278- ** self . nargs ,
296+ ** nargs ,
279297 ** kwargs ,
280298 ** config .all_kwargs (),
281299 )
@@ -285,15 +303,14 @@ def prune_configs(self, kwargs: Dict) -> List[Config]:
285303 return pruned_configs
286304
def warmup(self, *args, **kwargs):
    """Call ``self.fn.warmup`` once for every config that survives pruning.

    Args:
        *args: Positional kernel arguments; paired with ``self.arg_names``
            to build the mapping handed to ``prune_configs``.
        **kwargs: Keyword kernel arguments / meta-parameters, forwarded to
            both ``prune_configs`` and ``self.fn.warmup``.

    Returns:
        A list with the result of ``self.fn.warmup`` for each pruned config,
        in the order ``prune_configs`` returned them.
    """
    # Local mapping passed explicitly; nothing is stored on ``self``.
    nargs = dict(zip(self.arg_names, args))
    return [
        self.fn.warmup(
            *args,
            **kwargs,
            **autotune_config.all_kwargs(),
        )
        for autotune_config in self.prune_configs(kwargs, nargs)
    ]
298315
299316
0 commit comments