Skip to content

Commit 6923eec

Browse files
committed
switch to thread-local for both latency and accuracy
1 parent de1a201 commit 6923eec

File tree

1 file changed

+57
-43
lines changed

1 file changed

+57
-43
lines changed

python/triton/runtime/autotuner.py

Lines changed: 57 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,19 @@
1717
from triton._C.libtriton import get_cache_invalidating_env_vars
1818

1919

20+
class AutotunerThreadState:
21+
"""
22+
Per thread (thread-local) auto-tuner cache/configs
23+
"""
24+
25+
__slots__ = ("cache", "configs_timings", "bench_time")
26+
27+
def __init__(self):
28+
self.cache: Dict[Tuple, Config] = {}
29+
self.configs_timings: Dict[Config, List[float]] | None = None
30+
self.bench_time: float | None = None
31+
32+
2033
class Autotuner(KernelInterface):
2134

2235
def __init__(self, fn, arg_names, configs, key, reset_to_zero, restore_value, pre_hook=None, post_hook=None,
@@ -35,9 +48,7 @@ def __init__(self, fn, arg_names, configs, key, reset_to_zero, restore_value, pr
3548
else:
3649
self.configs = configs
3750
self.keys = key
38-
self.cache: Dict[Tuple, Config] = {}
39-
self._tuning_locks: Dict[Tuple, threading.Lock] = {}
40-
self._tuning_locks_guard = threading.Lock()
51+
self._thread_state = threading.local()
4152
self.arg_names = arg_names
4253
self.cache_results = cache_results or (knobs.autotuning.cache and not knobs.runtime.interpret)
4354

@@ -128,6 +139,13 @@ def do_bench(self):
128139
return driver.active.get_benchmarker()
129140
return self._do_bench
130141

142+
def _get_thread_state(self) -> AutotunerThreadState:
143+
state = getattr(self._thread_state, "value", None)
144+
if state is None:
145+
state = AutotunerThreadState()
146+
self._thread_state.value = state
147+
return state
148+
131149
def _bench(self, nargs, *args, config, **meta):
132150
from ..compiler.errors import CompileTimeAssertionFailure
133151

@@ -170,7 +188,7 @@ def kernel_call():
170188
print(f"Autotuning failed with {e}")
171189
return [float("inf"), float("inf"), float("inf")]
172190

173-
def check_disk_cache(self, tuning_key, configs, bench_fn):
191+
def check_disk_cache(self, tuning_key, configs, bench_fn, state: AutotunerThreadState):
174192
# We can't serialize prehooks, so just give up and run the benchmarks.
175193
if not tuning_key or any(cfg.pre_hook for cfg in configs):
176194
bench_fn()
@@ -198,8 +216,8 @@ def check_disk_cache(self, tuning_key, configs, bench_fn):
198216
with open(path, "r") as cached_configs:
199217
timings = json.load(cached_configs)["configs_timings"]
200218
timings = {Config(**config): timing for config, timing in timings}
201-
self.cache[tuning_key] = builtins.min(timings, key=timings.get)
202-
self.configs_timings = timings
219+
state.cache[tuning_key] = builtins.min(timings, key=timings.get)
220+
state.configs_timings = timings
203221
return True
204222

205223
bench_fn()
@@ -208,19 +226,17 @@ def check_disk_cache(self, tuning_key, configs, bench_fn):
208226
"key":
209227
tuning_key,
210228
"configs_timings":
211-
[(config.__dict__, timings) for config, timings in self.configs_timings.items() if not config.pre_hook],
229+
[
230+
(config.__dict__, timings)
231+
for config, timings in (state.configs_timings or {}).items()
232+
if not config.pre_hook
233+
],
212234
}), file_name, binary=False)
213235
return False
214236

215-
def _get_tuning_lock(self, key: Tuple):
216-
with self._tuning_locks_guard:
217-
lock = self._tuning_locks.get(key)
218-
if lock is None:
219-
lock = threading.Lock()
220-
self._tuning_locks[key] = lock
221-
return lock
222-
223237
def run(self, *args, **kwargs):
238+
state = self._get_thread_state()
239+
cache = state.cache
224240
nargs = dict(zip(self.arg_names, args))
225241
used_cached_result = True
226242
key = None
@@ -232,39 +248,37 @@ def run(self, *args, **kwargs):
232248
if hasattr(arg, "dtype"):
233249
key_values.append(str(arg.dtype))
234250
key = tuple(key_values)
235-
if key not in self.cache:
236-
lock = self._get_tuning_lock(key)
237-
with lock:
238-
if key not in self.cache:
239-
used_cached_result = False
240-
pruned_configs = self.prune_configs(kwargs, nargs)
241-
242-
def benchmark():
243-
bench_start = time.time()
244-
timings = {
245-
config: self._bench(nargs, *args, config=config, **kwargs)
246-
for config in pruned_configs
247-
}
248-
bench_end = time.time()
249-
self.bench_time = bench_end - bench_start
250-
best_config = builtins.min(timings, key=timings.get)
251-
self.cache[key] = best_config
252-
full_nargs = {**nargs, **kwargs, **best_config.all_kwargs()}
253-
self.pre_hook(full_nargs, reset_only=True)
254-
self.configs_timings = timings
255-
256-
if self.cache_results:
257-
used_cached_result = self.check_disk_cache(key, pruned_configs, benchmark)
258-
else:
259-
benchmark()
260-
261-
config = self.cache[key]
251+
if key not in cache:
252+
used_cached_result = False
253+
pruned_configs = self.prune_configs(kwargs, nargs)
254+
255+
def benchmark():
256+
bench_start = time.time()
257+
timings = {
258+
config: self._bench(nargs, *args, config=config, **kwargs)
259+
for config in pruned_configs
260+
}
261+
bench_end = time.time()
262+
state.bench_time = bench_end - bench_start
263+
best_config = builtins.min(timings, key=timings.get)
264+
cache[key] = best_config
265+
full_nargs_local = {**nargs, **kwargs, **best_config.all_kwargs()}
266+
self.pre_hook(full_nargs_local, reset_only=True)
267+
state.configs_timings = timings
268+
269+
if self.cache_results:
270+
used_cached_result = self.check_disk_cache(key, pruned_configs, benchmark, state)
271+
else:
272+
benchmark()
273+
274+
config = cache[key]
262275
else:
263276
config = self.configs[0]
264277
self.best_config = config
265278
if knobs.autotuning.print and key is not None and not used_cached_result:
279+
bench_time = state.bench_time or 0.0
266280
print(f"Triton autotuning for function {self.base_fn.__name__},\nwith key as {key},\n"
267-
f"finished after {self.bench_time:.2f}s,\nbest config selected: {self.best_config};")
281+
f"finished after {bench_time:.2f}s,\nbest config selected: {self.best_config};")
268282
full_nargs = {**nargs, **kwargs, **config.all_kwargs()}
269283
if config.pre_hook is not None:
270284
config.pre_hook(full_nargs)

0 commit comments

Comments
 (0)