
Commit de1a201

use global + per key lock to thread protect autotune
1 parent 0766464 commit de1a201

2 files changed: +108, -33 lines

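For context, the scheme named in the commit message, sketched in isolation: a global lock guards a table of per-key locks, and each per-key lock wraps a double-checked cache lookup, so threads racing on the same tuning key benchmark it only once while distinct keys can tune in parallel. The snippet below is a minimal, hedged illustration of that pattern, not Triton code; get_or_tune, _get_lock and tune are made-up names. The actual change is in python/triton/runtime/autotuner.py further down.

import threading
from typing import Callable, Dict, Hashable

_locks_guard = threading.Lock()                # global lock: guards the lock table
_locks: Dict[Hashable, threading.Lock] = {}    # one lock per tuning key
_cache: Dict[Hashable, object] = {}            # tuned result per key


def _get_lock(key: Hashable) -> threading.Lock:
    # Create-or-fetch the per-key lock under the global guard.
    with _locks_guard:
        lock = _locks.get(key)
        if lock is None:
            lock = threading.Lock()
            _locks[key] = lock
        return lock


def get_or_tune(key: Hashable, tune: Callable[[], object]) -> object:
    if key not in _cache:              # fast path once the key is cached
        with _get_lock(key):           # serialize threads racing on this key
            if key not in _cache:      # re-check: another thread may have won
                _cache[key] = tune()   # expensive benchmarking runs once
    return _cache[key]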

python/test/unit/runtime/test_autotuner.py

Lines changed: 58 additions & 0 deletions
@@ -1,3 +1,6 @@
+import threading
+import sys
+
 import torch
 
 import triton
@@ -450,6 +453,61 @@ def grid(meta):
                         exception_out_of_resource)
 
 
+def test_autotuner_thread_safety(device: str):
+    if getattr(sys, "_is_gil_enabled", lambda: True)():
+        pytest.skip("Requires running with the GIL disabled (PYTHON_GIL=0)")
+    if not is_cuda():
+        pytest.skip("CUDA backend is required for autotuner thread safety test")
+
+    num_threads = 8
+    N = 1024
+    src = torch.randn(N, device=device)
+    dst = torch.empty_like(src)
+
+    configs = [
+        triton.Config(kwargs={'BLOCK_SIZE': 32}),
+        triton.Config(kwargs={'BLOCK_SIZE': 64}),
+    ]
+
+    @triton.autotune(configs=configs, key=['N'], do_bench=do_bench)
+    @triton.jit
+    def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
+        pid = tl.program_id(0)
+        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < N
+        data = tl.load(src + offsets, mask=mask)
+        tl.store(dst + offsets, data, mask=mask)
+
+    grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']), )
+
+    ready_barrier = threading.Barrier(num_threads + 1)
+    start_barrier = threading.Barrier(num_threads + 1)
+    results = []
+
+    def worker():
+        ready_barrier.wait()
+        start_barrier.wait()
+        try:
+            _kernel[grid](dst, src, N)
+            results.append(None)
+        except Exception as exc:  # pragma: no cover - captured for assertions
+            results.append(exc)
+
+    threads = [threading.Thread(target=worker) for _ in range(num_threads)]
+
+    for thread in threads:
+        thread.start()
+
+    ready_barrier.wait()
+    start_barrier.wait()
+
+    for thread in threads:
+        thread.join()
+
+    assert all(result is None for result in results)
+    assert len(_kernel.cache) == 1
+
+
 def test_prune_all_configs(device):
     N = 1024
     src = torch.randn(N, device=device)
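
A note on the new test: the two barriers release all eight workers at once so they hit the cold autotuner cache simultaneously, and the assertions check that no worker raised and that exactly one cache entry exists afterwards (every thread computes the same key from N and the tensor dtypes). Real contention only occurs on a free-threaded interpreter, hence the skip guard; the snippet below is a small, hedged illustration of that probe (the printed strings are illustrative, not from the test).

import sys

# sys._is_gil_enabled() reports whether the GIL is active and was added in
# CPython 3.13; on interpreters without the attribute, the getattr fallback
# reports the GIL as enabled, so the test above is skipped there as well.
gil_enabled = getattr(sys, "_is_gil_enabled", lambda: True)()
print("free-threaded run, test executes" if not gil_enabled
      else "GIL enabled, test is skipped")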

python/triton/runtime/autotuner.py

Lines changed: 50 additions & 33 deletions
@@ -5,6 +5,7 @@
 import inspect
 import hashlib
 import json
+import threading
 from functools import cached_property
 from typing import Dict, Tuple, List, Optional
 
@@ -35,6 +36,8 @@ def __init__(self, fn, arg_names, configs, key, reset_to_zero, restore_value, pr
         self.configs = configs
         self.keys = key
         self.cache: Dict[Tuple, Config] = {}
+        self._tuning_locks: Dict[Tuple, threading.Lock] = {}
+        self._tuning_locks_guard = threading.Lock()
         self.arg_names = arg_names
         self.cache_results = cache_results or (knobs.autotuning.cache and not knobs.runtime.interpret)
 
@@ -125,7 +128,7 @@ def do_bench(self):
             return driver.active.get_benchmarker()
         return self._do_bench
 
-    def _bench(self, *args, config, **meta):
+    def _bench(self, nargs, *args, config, **meta):
         from ..compiler.errors import CompileTimeAssertionFailure
 
         verbose = knobs.autotuning.print
@@ -140,7 +143,7 @@ def _bench(self, *args, config, **meta):
                          " Make sure that you don't re-define auto-tuned symbols.")
         # augment meta-parameters with tunable ones
         current = dict(meta, **config.all_kwargs())
-        full_nargs = {**self.nargs, **current}
+        full_nargs = {**nargs, **current}
 
         def kernel_call():
             if config.pre_hook:
@@ -209,58 +212,73 @@ def check_disk_cache(self, tuning_key, configs, bench_fn):
            }), file_name, binary=False)
         return False
 
+    def _get_tuning_lock(self, key: Tuple):
+        with self._tuning_locks_guard:
+            lock = self._tuning_locks.get(key)
+            if lock is None:
+                lock = threading.Lock()
+                self._tuning_locks[key] = lock
+            return lock
+
     def run(self, *args, **kwargs):
-        self.nargs = dict(zip(self.arg_names, args))
+        nargs = dict(zip(self.arg_names, args))
         used_cached_result = True
+        key = None
         if len(self.configs) > 1:
-            all_args = {**self.nargs, **kwargs}
+            all_args = {**nargs, **kwargs}
             _args = {k: v for (k, v) in all_args.items() if k in self.arg_names}
-            key = [_args[key] for key in self.keys if key in _args]
+            key_values = [_args[key_name] for key_name in self.keys if key_name in _args]
             for _, arg in _args.items():
                 if hasattr(arg, "dtype"):
-                    key.append(str(arg.dtype))
-            key = tuple(key)
+                    key_values.append(str(arg.dtype))
+            key = tuple(key_values)
             if key not in self.cache:
-                used_cached_result = False
-                pruned_configs = self.prune_configs(kwargs)
-
-                def benchmark():
-                    bench_start = time.time()
-                    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
-                    bench_end = time.time()
-                    self.bench_time = bench_end - bench_start
-                    self.cache[key] = builtins.min(timings, key=timings.get)
-                    full_nargs = {**self.nargs, **kwargs, **self.cache[key].all_kwargs()}
-                    self.pre_hook(full_nargs, reset_only=True)
-                    self.configs_timings = timings
-
-                if self.cache_results:
-                    used_cached_result = self.check_disk_cache(key, pruned_configs, benchmark)
-                else:
-                    benchmark()
+                lock = self._get_tuning_lock(key)
+                with lock:
+                    if key not in self.cache:
+                        used_cached_result = False
+                        pruned_configs = self.prune_configs(kwargs, nargs)
+
+                        def benchmark():
+                            bench_start = time.time()
+                            timings = {
+                                config: self._bench(nargs, *args, config=config, **kwargs)
+                                for config in pruned_configs
+                            }
+                            bench_end = time.time()
+                            self.bench_time = bench_end - bench_start
+                            best_config = builtins.min(timings, key=timings.get)
+                            self.cache[key] = best_config
+                            full_nargs = {**nargs, **kwargs, **best_config.all_kwargs()}
+                            self.pre_hook(full_nargs, reset_only=True)
+                            self.configs_timings = timings
+
+                        if self.cache_results:
+                            used_cached_result = self.check_disk_cache(key, pruned_configs, benchmark)
+                        else:
+                            benchmark()
 
             config = self.cache[key]
         else:
             config = self.configs[0]
         self.best_config = config
-        if knobs.autotuning.print and not used_cached_result:
+        if knobs.autotuning.print and key is not None and not used_cached_result:
            print(f"Triton autotuning for function {self.base_fn.__name__},\nwith key as {key},\n"
                  f"finished after {self.bench_time:.2f}s,\nbest config selected: {self.best_config};")
+        full_nargs = {**nargs, **kwargs, **config.all_kwargs()}
        if config.pre_hook is not None:
-            full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()}
            config.pre_hook(full_nargs)
        ret = self.fn.run(
            *args,
            **kwargs,
            **config.all_kwargs(),
        )
-        self.nargs = None
        return ret
 
-    def prune_configs(self, kwargs: Dict) -> List[Config]:
+    def prune_configs(self, kwargs: Dict, nargs: Dict) -> List[Config]:
        pruned_configs = self.configs
        if self.early_config_prune:
-            pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs)
+            pruned_configs = self.early_config_prune(self.configs, nargs, **kwargs)
        if not pruned_configs:
            raise AutotunerError(
                "No valid autotuner configs after pruning. `early_config_prune` should return at least one config.")
@@ -275,7 +293,7 @@ def prune_configs(self, kwargs: Dict) -> List[Config]:
        if len(pruned_configs) > top_k:
            est_timing = {
                config: self.perf_model(
-                    **self.nargs,
+                    **nargs,
                    **kwargs,
                    **config.all_kwargs(),
                )
@@ -285,15 +303,14 @@ def prune_configs(self, kwargs: Dict) -> List[Config]:
        return pruned_configs
 
    def warmup(self, *args, **kwargs):
-        self.nargs = dict(zip(self.arg_names, args))
+        nargs = dict(zip(self.arg_names, args))
        ret = []
-        for autotune_config in self.prune_configs(kwargs):
+        for autotune_config in self.prune_configs(kwargs, nargs):
            ret.append(self.fn.warmup(
                *args,
                **kwargs,
                **autotune_config.all_kwargs(),
            ))
-        self.nargs = None
        return ret
 
 
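Besides the locks, the diff above removes the shared self.nargs attribute and threads an explicit nargs mapping through run, _bench, prune_configs and warmup, so per-call state never lives on the shared Autotuner instance. The toy classes below are a hedged sketch of why that matters under concurrency; SharedStateTuner and ExplicitStateTuner are illustrative names, not part of Triton.

class SharedStateTuner:
    # Racy variant: per-call arguments stored on the instance, as the old
    # self.nargs was. Two concurrent run() calls can overwrite each other's
    # state before _bench() reads it.

    def __init__(self):
        self.nargs = None

    def run(self, **kwargs):
        self.nargs = kwargs          # shared, mutable, visible to all threads
        return self._bench()

    def _bench(self):
        return dict(self.nargs)      # may observe another call's arguments


class ExplicitStateTuner:
    # Thread-safe variant matching the diff: arguments stay local to the call
    # and are passed down explicitly, so there is nothing shared to clobber.

    def run(self, **kwargs):
        nargs = dict(kwargs)
        return self._bench(nargs)

    def _bench(self, nargs):
        return dict(nargs)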