Commit 10a5601

use global cache with per-key future events for multi-thread sync on same key/kernel autotune
1 parent a12e370 commit 10a5601

File tree

1 file changed: +78 −46 lines changed


python/triton/runtime/autotuner.py

Lines changed: 78 additions & 46 deletions
@@ -17,12 +17,13 @@
 from triton._C.libtriton import get_cache_invalidating_env_vars
 
 
-class AutotunerThreadState:
-    """Per-thread autotune cache and metadata."""
+class CacheFuture:
 
     def __init__(self):
-        self.cache: Dict[Tuple, Config] = {}
-        self.configs_timings: Dict[Config, List[float]] | None = None
+        self.event = threading.Event()
+        self.config: Config | None = None
+        self.error: BaseException | None = None
+        self.used_cached_result: bool = True
         self.bench_time: float | None = None
 
 
@@ -44,7 +45,9 @@ def __init__(self, fn, arg_names, configs, key, reset_to_zero, restore_value, pr
         else:
             self.configs = configs
         self.keys = key
-        self._thread_state = threading.local()
+        self._cache: Dict[Tuple, Config] = {}
+        self._cache_lock = threading.RLock()
+        self._cache_futures: Dict[Tuple, CacheFuture] = {}
         self.arg_names = arg_names
         self.cache_results = cache_results or (knobs.autotuning.cache and not knobs.runtime.interpret)
 
@@ -135,13 +138,6 @@ def do_bench(self):
             return driver.active.get_benchmarker()
         return self._do_bench
 
-    def _get_thread_state(self) -> AutotunerThreadState:
-        state = getattr(self._thread_state, "value", None)
-        if state is None:
-            state = AutotunerThreadState()
-            self._thread_state.value = state
-        return state
-
     def _bench(self, nargs, *args, config, **meta):
         from ..compiler.errors import CompileTimeAssertionFailure
 
@@ -184,11 +180,11 @@ def kernel_call():
             print(f"Autotuning failed with {e}")
             return [float("inf"), float("inf"), float("inf")]
 
-    def check_disk_cache(self, tuning_key, configs, bench_fn, state: AutotunerThreadState):
+    def check_disk_cache(self, tuning_key, configs, bench_fn):
         # We can't serialize prehooks, so just give up and run the benchmarks.
         if not tuning_key or any(cfg.pre_hook for cfg in configs):
-            bench_fn()
-            return False
+            configs_timings, bench_time, best_config = bench_fn()
+            return False, bench_time, configs_timings, best_config
 
         from triton.compiler.compiler import make_backend
 
@@ -212,26 +208,82 @@ def check_disk_cache(self, tuning_key, configs, bench_fn, state: AutotunerThread
             with open(path, "r") as cached_configs:
                 timings = json.load(cached_configs)["configs_timings"]
                 timings = {Config(**config): timing for config, timing in timings}
-                state.cache[tuning_key] = builtins.min(timings, key=timings.get)
-                state.configs_timings = timings
-            return True
+                best_config = builtins.min(timings, key=timings.get)
+            return True, None, timings, best_config
 
-        bench_fn()
+        configs_timings, bench_time, best_config = bench_fn()
         cache.put(
             json.dumps({
                 "key":
                 tuning_key,
                 "configs_timings": [(config.__dict__, timings)
-                                    for config, timings in (state.configs_timings or {}).items()
+                                    for config, timings in (configs_timings or {}).items()
                                     if not config.pre_hook],
             }), file_name, binary=False)
-        return False
+        return False, bench_time, configs_timings, best_config
+
+    def _get_config_for_key(self, key, nargs, args, kwargs):
+        with self._cache_lock:
+            cached = self._cache.get(key)
+            if cached is not None:
+                return cached, True, None
+
+            future = self._cache_futures.get(key)
+            if future is None:
+                future = CacheFuture()
+                self._cache_futures[key] = future
+                runner = True
+            else:
+                runner = False
+
+        if not runner:
+            future.event.wait()
+            if future.error is not None:
+                raise future.error
+            return future.config, future.used_cached_result, future.bench_time
+
+        pruned_configs = self.prune_configs(kwargs, nargs)
+
+        def benchmark():
+            bench_start = time.time()
+            timings = {config: self._bench(nargs, *args, config=config, **kwargs) for config in pruned_configs}
+            bench_duration = time.time() - bench_start
+            best_config = builtins.min(timings, key=timings.get)
+            full_nargs_local = {**nargs, **kwargs, **best_config.all_kwargs()}
+            self.pre_hook(full_nargs_local, reset_only=True)
+            return timings, bench_duration, best_config
+
+        used_cached_result = False
+        bench_time = None
+
+        try:
+            if self.cache_results:
+                used_cached_result, bench_time, configs_timings, best_config = self.check_disk_cache(
+                    key, pruned_configs, benchmark)
+            else:
+                configs_timings, bench_time, best_config = benchmark()
+                used_cached_result = False
+
+            if best_config is not None:
+                with self._cache_lock:
+                    self._cache[key] = best_config
+
+            future.config = best_config
+            future.used_cached_result = used_cached_result
+            future.bench_time = bench_time
+            return best_config, used_cached_result, bench_time
+        except BaseException as exc:
+            future.error = exc
+            raise
+        finally:
+            future.event.set()
+            with self._cache_lock:
+                self._cache_futures.pop(key, None)
 
     def run(self, *args, **kwargs):
-        state = self._get_thread_state()
-        cache = state.cache
         nargs = dict(zip(self.arg_names, args))
         used_cached_result = True
+        bench_time = None
         key = None
         if len(self.configs) > 1:
             all_args = {**nargs, **kwargs}
@@ -241,34 +293,14 @@ def run(self, *args, **kwargs):
                 if hasattr(arg, "dtype"):
                     key_values.append(str(arg.dtype))
             key = tuple(key_values)
-            if key not in cache:
-                used_cached_result = False
-                pruned_configs = self.prune_configs(kwargs, nargs)
-
-                def benchmark():
-                    bench_start = time.time()
-                    timings = {config: self._bench(nargs, *args, config=config, **kwargs) for config in pruned_configs}
-                    bench_end = time.time()
-                    state.bench_time = bench_end - bench_start
-                    best_config = builtins.min(timings, key=timings.get)
-                    cache[key] = best_config
-                    full_nargs_local = {**nargs, **kwargs, **best_config.all_kwargs()}
-                    self.pre_hook(full_nargs_local, reset_only=True)
-                    state.configs_timings = timings
-
-                if self.cache_results:
-                    used_cached_result = self.check_disk_cache(key, pruned_configs, benchmark, state)
-                else:
-                    benchmark()
-
-            config = cache[key]
+            config, used_cached_result, bench_time = self._get_config_for_key(key, nargs, args, kwargs)
         else:
             config = self.configs[0]
         self.best_config = config
         if knobs.autotuning.print and key is not None and not used_cached_result:
-            bench_time = state.bench_time or 0.0
+            bench_time_value = bench_time or 0.0
             print(f"Triton autotuning for function {self.base_fn.__name__},\nwith key as {key},\n"
-                  f"finished after {bench_time:.2f}s,\nbest config selected: {self.best_config};")
+                  f"finished after {bench_time_value:.2f}s,\nbest config selected: {self.best_config};")
         full_nargs = {**nargs, **kwargs, **config.all_kwargs()}
         if config.pre_hook is not None:
             config.pre_hook(full_nargs)

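For context, here is a minimal, standalone sketch of the per-key future-event idiom that the new CacheFuture and _get_config_for_key above implement: the first thread to miss on a key becomes the runner and does the expensive work once, while any other thread asking for the same key blocks on that key's threading.Event and then reuses the runner's result or re-raises its error. This is an illustration only, not Triton code; the names KeyedCache, _Future, and compute are hypothetical stand-ins (compute plays the role of the autotune benchmark sweep).

import threading
import time


class _Future:
    """In-flight result for one key: waiters block on `event`."""

    def __init__(self):
        self.event = threading.Event()
        self.value = None
        self.error = None


class KeyedCache:
    """Global cache with one future per key (illustrative sketch)."""

    def __init__(self, compute):
        self._compute = compute      # expensive function (stand-in for the benchmark sweep)
        self._lock = threading.RLock()
        self._results = {}           # key -> finished result (the global cache)
        self._futures = {}           # key -> _Future for work currently in flight

    def get(self, key):
        with self._lock:
            if key in self._results:              # fast path: already computed
                return self._results[key]
            future = self._futures.get(key)
            if future is None:                    # nobody is working on this key: we run it
                future = self._futures[key] = _Future()
                runner = True
            else:                                 # someone else is already running it
                runner = False

        if not runner:
            future.event.wait()                   # block until the runner finishes
            if future.error is not None:
                raise future.error                # propagate the runner's failure
            return future.value

        try:
            value = self._compute(key)
            with self._lock:
                self._results[key] = value        # publish to the global cache
            future.value = value
            return value
        except BaseException as exc:
            future.error = exc
            raise
        finally:
            future.event.set()                    # wake waiters even on error
            with self._lock:
                self._futures.pop(key, None)      # in-flight entry is no longer needed


if __name__ == "__main__":
    # Eight threads ask for the same key; compute runs only once.
    calls = []

    def compute(key):
        calls.append(key)
        time.sleep(0.1)                           # pretend this is a benchmark sweep
        return f"best-config-for-{key}"

    cache = KeyedCache(compute)
    threads = [threading.Thread(target=cache.get, args=("kernel-key",)) for _ in range(8)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    assert len(calls) == 1

As in the commit, the result is written before the event is set and the in-flight future is dropped in a finally block, so late arrivals either hit the global cache or become a fresh runner if the previous attempt failed.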