Skip to content

Commit f5d62c3

Browse files
committed
use per-key locked global cache
1 parent f7e3464 commit f5d62c3

File tree

1 file changed

+50
-48
lines changed

1 file changed

+50
-48
lines changed

python/triton/runtime/autotuner.py

Lines changed: 50 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,6 @@
1717
from triton._C.libtriton import get_cache_invalidating_env_vars
1818

1919

20-
class AutotunerThreadState:
21-
"""Per-thread autotune cache and metadata."""
22-
23-
def __init__(self):
24-
self.cache: Dict[Tuple, Config] = {}
25-
self.configs_timings: Dict[Config, List[float]] | None = None
26-
self.bench_time: float | None = None
27-
28-
2920
class Autotuner(KernelInterface):
3021

3122
def __init__(self, fn, arg_names, configs, key, reset_to_zero, restore_value, pre_hook=None, post_hook=None,
@@ -44,7 +35,11 @@ def __init__(self, fn, arg_names, configs, key, reset_to_zero, restore_value, pr
4435
else:
4536
self.configs = configs
4637
self.keys = key
47-
self._thread_state = threading.local()
38+
self.cache: Dict[Tuple, Config] = {}
39+
self.configs_timings: Dict[Config, List[float]] = {}
40+
self.bench_time: Optional[float] = None
41+
self._tuning_locks: Dict[Tuple, threading.Lock] = {}
42+
self._tuning_locks_guard = threading.Lock()
4843
self.arg_names = arg_names
4944
self.cache_results = cache_results or (knobs.autotuning.cache and not knobs.runtime.interpret)
5045

@@ -135,13 +130,6 @@ def do_bench(self):
135130
return driver.active.get_benchmarker()
136131
return self._do_bench
137132

138-
def _get_thread_state(self) -> AutotunerThreadState:
139-
state = getattr(self._thread_state, "value", None)
140-
if state is None:
141-
state = AutotunerThreadState()
142-
self._thread_state.value = state
143-
return state
144-
145133
def _bench(self, nargs, *args, config, **meta):
146134
from ..compiler.errors import CompileTimeAssertionFailure
147135

@@ -184,7 +172,7 @@ def kernel_call():
184172
print(f"Autotuning failed with {e}")
185173
return [float("inf"), float("inf"), float("inf")]
186174

187-
def check_disk_cache(self, tuning_key, configs, bench_fn, state: AutotunerThreadState):
175+
def check_disk_cache(self, tuning_key, configs, bench_fn):
188176
# We can't serialize prehooks, so just give up and run the benchmarks.
189177
if not tuning_key or any(cfg.pre_hook for cfg in configs):
190178
bench_fn()
@@ -212,24 +200,32 @@ def check_disk_cache(self, tuning_key, configs, bench_fn, state: AutotunerThread
212200
with open(path, "r") as cached_configs:
213201
timings = json.load(cached_configs)["configs_timings"]
214202
timings = {Config(**config): timing for config, timing in timings}
215-
state.cache[tuning_key] = builtins.min(timings, key=timings.get)
216-
state.configs_timings = timings
203+
self.cache[tuning_key] = builtins.min(timings, key=timings.get)
204+
self.configs_timings = timings
217205
return True
218206

219207
bench_fn()
220208
cache.put(
221209
json.dumps({
222210
"key":
223211
tuning_key,
224-
"configs_timings": [(config.__dict__, timings)
225-
for config, timings in (state.configs_timings or {}).items()
226-
if not config.pre_hook],
227-
}), file_name, binary=False)
212+
"configs_timings":
213+
[(config.__dict__, timings) for config, timings in self.configs_timings.items() if not config.pre_hook],
214+
}),
215+
file_name,
216+
binary=False,
217+
)
228218
return False
229219

220+
def _get_tuning_lock(self, key: Tuple):
221+
with self._tuning_locks_guard:
222+
lock = self._tuning_locks.get(key)
223+
if lock is None:
224+
lock = threading.Lock()
225+
self._tuning_locks[key] = lock
226+
return lock
227+
230228
def run(self, *args, **kwargs):
231-
state = self._get_thread_state()
232-
cache = state.cache
233229
nargs = dict(zip(self.arg_names, args))
234230
used_cached_result = True
235231
key = None
@@ -241,32 +237,38 @@ def run(self, *args, **kwargs):
241237
if hasattr(arg, "dtype"):
242238
key_values.append(str(arg.dtype))
243239
key = tuple(key_values)
244-
if key not in cache:
245-
used_cached_result = False
246-
pruned_configs = self.prune_configs(kwargs, nargs)
247-
248-
def benchmark():
249-
bench_start = time.time()
250-
timings = {config: self._bench(nargs, *args, config=config, **kwargs) for config in pruned_configs}
251-
bench_end = time.time()
252-
state.bench_time = bench_end - bench_start
253-
best_config = builtins.min(timings, key=timings.get)
254-
cache[key] = best_config
255-
full_nargs_local = {**nargs, **kwargs, **best_config.all_kwargs()}
256-
self.pre_hook(full_nargs_local, reset_only=True)
257-
state.configs_timings = timings
258-
259-
if self.cache_results:
260-
used_cached_result = self.check_disk_cache(key, pruned_configs, benchmark, state)
261-
else:
262-
benchmark()
263-
264-
config = cache[key]
240+
if key not in self.cache:
241+
lock = self._get_tuning_lock(key)
242+
with lock:
243+
if key not in self.cache:
244+
used_cached_result = False
245+
pruned_configs = self.prune_configs(kwargs, nargs)
246+
247+
def benchmark():
248+
bench_start = time.time()
249+
timings = {
250+
config: self._bench(nargs, *args, config=config, **kwargs)
251+
for config in pruned_configs
252+
}
253+
bench_end = time.time()
254+
self.bench_time = bench_end - bench_start
255+
best_config = builtins.min(timings, key=timings.get)
256+
self.cache[key] = best_config
257+
full_nargs_local = {**nargs, **kwargs, **best_config.all_kwargs()}
258+
self.pre_hook(full_nargs_local, reset_only=True)
259+
self.configs_timings = timings
260+
261+
if self.cache_results:
262+
used_cached_result = self.check_disk_cache(key, pruned_configs, benchmark)
263+
else:
264+
benchmark()
265+
266+
config = self.cache[key]
265267
else:
266268
config = self.configs[0]
267269
self.best_config = config
268270
if knobs.autotuning.print and key is not None and not used_cached_result:
269-
bench_time = state.bench_time or 0.0
271+
bench_time = self.bench_time or 0.0
270272
print(f"Triton autotuning for function {self.base_fn.__name__},\nwith key as {key},\n"
271273
f"finished after {bench_time:.2f}s,\nbest config selected: {self.best_config};")
272274
full_nargs = {**nargs, **kwargs, **config.all_kwargs()}

0 commit comments

Comments
 (0)