Commit cb06283

use a global cache with per-key future events for multi-thread sync on the same key/kernel autotune

1 parent a12e370 · commit cb06283
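For orientation, the scheme named in the commit message (one cache shared by all threads, plus a per-key future guarded by a lock, so only the first thread to miss on a key runs the expensive benchmarking while later threads block on an event) can be sketched in isolation. The sketch below is a minimal illustration, not code from this commit; the names KeyedFuture, KeyedCache, and get_or_compute are hypothetical.

import threading
from typing import Any, Callable, Dict, Hashable


class KeyedFuture:
    """Hypothetical per-key future: the first thread fills it in, later threads wait on the event."""

    def __init__(self):
        self.event = threading.Event()
        self.result: Any = None
        self.error: BaseException | None = None


class KeyedCache:
    """Hypothetical global cache where concurrent misses on the same key run the work only once."""

    def __init__(self):
        self._lock = threading.RLock()
        self._cache: Dict[Hashable, Any] = {}
        self._futures: Dict[Hashable, KeyedFuture] = {}

    def get_or_compute(self, key: Hashable, compute: Callable[[], Any]) -> Any:
        with self._lock:
            if key in self._cache:
                return self._cache[key]  # fast path: result already published
            future = self._futures.get(key)
            if future is None:  # this thread becomes the runner for the key
                future = KeyedFuture()
                self._futures[key] = future
                runner = True
            else:  # another thread is already computing this key
                runner = False
        if not runner:
            future.event.wait()  # block outside the lock until the runner finishes
            if future.error is not None:
                raise future.error
            return future.result
        try:
            result = compute()  # the expensive work happens exactly once per key
            with self._lock:
                self._cache[key] = result
            future.result = result
            return result
        except BaseException as exc:
            future.error = exc
            raise
        finally:
            future.event.set()  # wake every waiter, on success or failure
            with self._lock:
                self._futures.pop(key, None)  # later misses go through the cache again

The ordering detail that matters, and that the diff below also follows, is that waiters block on the event outside the lock, so the runner can still acquire the lock to publish its result.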

File tree

1 file changed: +78 −47 lines

python/triton/runtime/autotuner.py

Lines changed: 78 additions & 47 deletions
@@ -17,12 +17,12 @@
 from triton._C.libtriton import get_cache_invalidating_env_vars
 
 
-class AutotunerThreadState:
-    """Per-thread autotune cache and metadata."""
-
+class CacheFuture:
     def __init__(self):
-        self.cache: Dict[Tuple, Config] = {}
-        self.configs_timings: Dict[Config, List[float]] | None = None
+        self.event = threading.Event()
+        self.config: Config | None = None
+        self.error: BaseException | None = None
+        self.used_cached_result: bool = True
         self.bench_time: float | None = None
 
 
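Aside: the new CacheFuture above is built on threading.Event, which gives exactly the one-producer, many-waiters shape an autotune miss needs: one thread publishes a payload and then releases every waiter at once. A minimal, self-contained illustration (not part of this diff):

import threading

payload = {}
ready = threading.Event()
seen = []

def producer():
    payload["config"] = "BLOCK=128"  # fill the payload first...
    ready.set()                      # ...then release every waiter at once

def consumer():
    ready.wait()                     # blocks until producer() has called set()
    seen.append(payload["config"])

waiters = [threading.Thread(target=consumer) for _ in range(3)]
for w in waiters:
    w.start()
threading.Thread(target=producer).start()
for w in waiters:
    w.join()
assert seen == ["BLOCK=128"] * 3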
@@ -44,7 +44,9 @@ def __init__(self, fn, arg_names, configs, key, reset_to_zero, restore_value, pr
         else:
             self.configs = configs
         self.keys = key
-        self._thread_state = threading.local()
+        self._cache: Dict[Tuple, Config] = {}
+        self._cache_lock = threading.RLock()
+        self._cache_futures: Dict[Tuple, CacheFuture] = {}
         self.arg_names = arg_names
         self.cache_results = cache_results or (knobs.autotuning.cache and not knobs.runtime.interpret)
 
@@ -135,13 +137,6 @@ def do_bench(self):
             return driver.active.get_benchmarker()
         return self._do_bench
 
-    def _get_thread_state(self) -> AutotunerThreadState:
-        state = getattr(self._thread_state, "value", None)
-        if state is None:
-            state = AutotunerThreadState()
-            self._thread_state.value = state
-        return state
-
     def _bench(self, nargs, *args, config, **meta):
         from ..compiler.errors import CompileTimeAssertionFailure
 
@@ -184,11 +179,11 @@ def kernel_call():
                 print(f"Autotuning failed with {e}")
             return [float("inf"), float("inf"), float("inf")]
 
-    def check_disk_cache(self, tuning_key, configs, bench_fn, state: AutotunerThreadState):
+    def check_disk_cache(self, tuning_key, configs, bench_fn):
         # We can't serialize prehooks, so just give up and run the benchmarks.
         if not tuning_key or any(cfg.pre_hook for cfg in configs):
-            bench_fn()
-            return False
+            configs_timings, bench_time, best_config = bench_fn()
+            return False, bench_time, configs_timings, best_config
 
         from triton.compiler.compiler import make_backend
 
@@ -212,26 +207,82 @@ def check_disk_cache(self, tuning_key, configs, bench_fn, state: AutotunerThread
             with open(path, "r") as cached_configs:
                 timings = json.load(cached_configs)["configs_timings"]
                 timings = {Config(**config): timing for config, timing in timings}
-                state.cache[tuning_key] = builtins.min(timings, key=timings.get)
-                state.configs_timings = timings
-                return True
+                best_config = builtins.min(timings, key=timings.get)
+                return True, None, timings, best_config
 
-        bench_fn()
+        configs_timings, bench_time, best_config = bench_fn()
         cache.put(
             json.dumps({
                 "key":
                 tuning_key,
                 "configs_timings": [(config.__dict__, timings)
-                                    for config, timings in (state.configs_timings or {}).items()
+                                    for config, timings in (configs_timings or {}).items()
                                     if not config.pre_hook],
             }), file_name, binary=False)
-        return False
+        return False, bench_time, configs_timings, best_config
+
+    def _get_config_for_key(self, key, nargs, args, kwargs):
+        with self._cache_lock:
+            cached = self._cache.get(key)
+            if cached is not None:
+                return cached, True, None
+
+            future = self._cache_futures.get(key)
+            if future is None:
+                future = CacheFuture()
+                self._cache_futures[key] = future
+                runner = True
+            else:
+                runner = False
+
+        if not runner:
+            future.event.wait()
+            if future.error is not None:
+                raise future.error
+            return future.config, future.used_cached_result, future.bench_time
+
+        pruned_configs = self.prune_configs(kwargs, nargs)
+
+        def benchmark():
+            bench_start = time.time()
+            timings = {config: self._bench(nargs, *args, config=config, **kwargs) for config in pruned_configs}
+            bench_duration = time.time() - bench_start
+            best_config = builtins.min(timings, key=timings.get)
+            full_nargs_local = {**nargs, **kwargs, **best_config.all_kwargs()}
+            self.pre_hook(full_nargs_local, reset_only=True)
+            return timings, bench_duration, best_config
+
+        used_cached_result = False
+        bench_time = None
+
+        try:
+            if self.cache_results:
+                used_cached_result, bench_time, configs_timings, best_config = self.check_disk_cache(
+                    key, pruned_configs, benchmark)
+            else:
+                configs_timings, bench_time, best_config = benchmark()
+                used_cached_result = False
+
+            if best_config is not None:
+                with self._cache_lock:
+                    self._cache[key] = best_config
+
+            future.config = best_config
+            future.used_cached_result = used_cached_result
+            future.bench_time = bench_time
+            return best_config, used_cached_result, bench_time
+        except BaseException as exc:
+            future.error = exc
+            raise
+        finally:
+            future.event.set()
+            with self._cache_lock:
+                self._cache_futures.pop(key, None)
 
     def run(self, *args, **kwargs):
-        state = self._get_thread_state()
-        cache = state.cache
         nargs = dict(zip(self.arg_names, args))
         used_cached_result = True
+        bench_time = None
         key = None
         if len(self.configs) > 1:
             all_args = {**nargs, **kwargs}
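To make the waiter branch of _get_config_for_key above concrete: once future.event.wait() returns, a non-runner thread either re-raises the exception the runner stored or returns the runner's published result. The toy below only illustrates that contract; run_once and wait_for are hypothetical helpers, not code from this commit.

import threading

class _Future:
    def __init__(self):
        self.event = threading.Event()
        self.result = None
        self.error = None

def run_once(future, fn):
    """Runner side: publish a result or an error, then release any waiters."""
    try:
        future.result = fn()
    except BaseException as exc:
        future.error = exc
        raise
    finally:
        future.event.set()

def wait_for(future):
    """Waiter side: mirrors the non-runner branch in the diff above."""
    future.event.wait()
    if future.error is not None:
        raise future.error
    return future.result

# Success path: the waiter sees the runner's result.
ok = _Future()
threading.Thread(target=run_once, args=(ok, lambda: "best-config")).start()
assert wait_for(ok) == "best-config"

# Failure path: the runner's exception is re-raised in the waiter.
bad = _Future()

def failing_runner():
    try:
        run_once(bad, lambda: 1 / 0)
    except ZeroDivisionError:
        pass  # the runner itself also observes the error

threading.Thread(target=failing_runner).start()
try:
    wait_for(bad)
except ZeroDivisionError:
    pass  # the waiter observes the same failure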
@@ -241,34 +292,14 @@ def run(self, *args, **kwargs):
                 if hasattr(arg, "dtype"):
                     key_values.append(str(arg.dtype))
             key = tuple(key_values)
-            if key not in cache:
-                used_cached_result = False
-                pruned_configs = self.prune_configs(kwargs, nargs)
-
-                def benchmark():
-                    bench_start = time.time()
-                    timings = {config: self._bench(nargs, *args, config=config, **kwargs) for config in pruned_configs}
-                    bench_end = time.time()
-                    state.bench_time = bench_end - bench_start
-                    best_config = builtins.min(timings, key=timings.get)
-                    cache[key] = best_config
-                    full_nargs_local = {**nargs, **kwargs, **best_config.all_kwargs()}
-                    self.pre_hook(full_nargs_local, reset_only=True)
-                    state.configs_timings = timings
-
-                if self.cache_results:
-                    used_cached_result = self.check_disk_cache(key, pruned_configs, benchmark, state)
-                else:
-                    benchmark()
-
-            config = cache[key]
+            config, used_cached_result, bench_time = self._get_config_for_key(key, nargs, args, kwargs)
         else:
             config = self.configs[0]
         self.best_config = config
         if knobs.autotuning.print and key is not None and not used_cached_result:
-            bench_time = state.bench_time or 0.0
+            bench_time_value = bench_time or 0.0
             print(f"Triton autotuning for function {self.base_fn.__name__},\nwith key as {key},\n"
-                  f"finished after {bench_time:.2f}s,\nbest config selected: {self.best_config};")
+                  f"finished after {bench_time_value:.2f}s,\nbest config selected: {self.best_config};")
         full_nargs = {**nargs, **kwargs, **config.all_kwargs()}
         if config.pre_hook is not None:
             config.pre_hook(full_nargs)
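Finally, a sketch of the user-facing scenario this commit targets: two host threads launching the same @triton.autotune kernel with the same key. With the per-key future, one thread runs the benchmark sweep while the other waits for the published winner instead of re-benchmarking or racing on shared state. The example follows the standard Triton vector-add tutorial pattern, requires a CUDA-capable GPU, and is not a test from this commit.

import threading

import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 256}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 1024}, num_warps=8),
    ],
    key=["n_elements"],
)
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    tl.store(out_ptr + offsets,
             tl.load(x_ptr + offsets, mask=mask) + tl.load(y_ptr + offsets, mask=mask),
             mask=mask)


def launch():
    n = 1 << 20
    x = torch.rand(n, device="cuda")
    y = torch.rand(n, device="cuda")
    out = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
    # Both threads hit the same autotune key (same n_elements and dtypes),
    # so only one of them should pay the benchmarking cost.
    add_kernel[grid](x, y, out, n)


threads = [threading.Thread(target=launch) for _ in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()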
