
Commit 1f98379

Cache autotune timings to disk (#6261)

Some users had expressed a desire to cache autotune results, both to speed up local iteration, and to avoid re-tuning when scaling up to large numbers of GPUs.

This PR caches tuning timings in Triton's cache dir. Running locally on 03-matrix-multiplication.py, the time of later runs is greatly reduced:

```
% time python ./03-matrix-multiplication.py
...
real 1m59.055s

% time python ./03-matrix-multiplication.py
...
real 0m13.794s
```

The cache key consists of:

* system information (triton source, target info, env vars)
* kernel source code (with dependences)
* the values of the tuning keys (e.g. M/N/K in the matmul example)
* the set of configs requested for tuning (so that we'll re-tune if the user changes tunings)

If any configs have `pre_hook`s defined, we don't try caching at all, since the results could depend on arbitrary python code.

A sampling of one of the cache entries (from the matmul tutorial) is:

```
% jq . $TRITON_CACHE_DIR/X5O...Q/matmul_kernel.autotune.json
{
  "key": [
    3968,
    3968,
    3968,
    "torch.float8_e5m2",
    "torch.float8_e5m2",
    "torch.float16"
  ],
  "configs_timings": [
    [
      {
        "kwargs": {
          "BLOCK_SIZE_M": 128,
          "BLOCK_SIZE_N": 256,
          "BLOCK_SIZE_K": 64,
          "GROUP_SIZE_M": 8
        },
        "num_warps": 8,
        "num_ctas": 1,
        "num_stages": 3,
        "maxnreg": null,
        "pre_hook": null
      },
      [
        0.14316800236701965,
        0.1420159935951233,
        0.14431999623775482
      ]
    ],
...
```

It's not strictly necessary to encode the key, since it's part of the hashed path, but I think it makes it easier to understand the cache contents for any dev who needs to do so.

I considered a few different designs here:

* storing just the best config versus all timings (I like having the raw data available in the cache from a dev perspective, but I could relent on this, it's an easy change)
* allowing new configs to be added while re-using older cached ones (I got cold feet at the thought of mutating the cache)
* storing all key+config+timings in a single cache file (convenient for analysis, but also requires mutating the cache)
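
For reference, the change exposes this in two ways: a new `cache_results=` argument on `@triton.autotune` and a `TRITON_CACHE_AUTOTUNING=1` environment variable (see the diff below). A minimal sketch of opting in from user code; the kernel, its arguments, and the config values here are illustrative, not part of this PR:

```
import triton
import triton.language as tl


# Illustrative kernel and configs; only cache_results=True (or setting
# TRITON_CACHE_AUTOTUNING=1 in the environment) is what this PR adds.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 64}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 128}, num_warps=8),
    ],
    key=["n_elements"],   # tuning key: re-tune (or hit the disk cache) per distinct value
    cache_results=True,   # persist autotune timings in Triton's cache dir
)
@triton.jit
def scale_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * 2.0, mask=mask)
```

The first run still benchmarks every config; later runs with the same tuning key (and unchanged kernel, configs, and environment) load the timings from the cache file instead.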
1 parent 658b5b2 commit 1f98379

File tree

1 file changed (+84, -27 lines)


python/triton/runtime/autotuner.py

Lines changed: 84 additions & 27 deletions
```
@@ -4,6 +4,8 @@
 import os
 import time
 import inspect
+import hashlib
+import json
 from typing import Dict, Tuple, List, Optional
 
 from .jit import KernelInterface
@@ -13,22 +15,9 @@
 
 class Autotuner(KernelInterface):
 
-    def __init__(
-        self,
-        fn,
-        arg_names,
-        configs,
-        key,
-        reset_to_zero,
-        restore_value,
-        pre_hook=None,
-        post_hook=None,
-        prune_configs_by: Optional[Dict] = None,
-        warmup=None,
-        rep=None,
-        use_cuda_graph=False,
-        do_bench=None,
-    ):
+    def __init__(self, fn, arg_names, configs, key, reset_to_zero, restore_value, pre_hook=None, post_hook=None,
+                 prune_configs_by: Optional[Dict] = None, warmup=None, rep=None, use_cuda_graph=False, do_bench=None,
+                 cache_results=False):
         """
         :param prune_configs_by: a dict of functions that are used to prune configs, fields:
             'perf_model': performance model used to predicate running time with different configs, returns running time
@@ -42,6 +31,7 @@ def __init__(
         self.keys = key
         self.cache: Dict[Tuple, Config] = {}
         self.arg_names = arg_names
+        self.cache_results = cache_results or os.getenv("TRITON_CACHE_AUTOTUNING", None) == "1"
 
         # Reset to zero or restore values
         self.reset_to_zero = []
@@ -170,6 +160,50 @@ def kernel_call():
                print(f"Autotuning failed with {e}")
            return [float("inf"), float("inf"), float("inf")]
 
+    def check_disk_cache(self, tuning_key, configs, bench_fn):
+        # We can't serialize prehooks, so just give up and run the benchmarks.
+        if not tuning_key or any(cfg.pre_hook for cfg in configs):
+            bench_fn()
+            return
+
+        from triton._C.libtriton import get_cache_invalidating_env_vars
+        from triton.compiler.compiler import make_backend, triton_key
+        from triton.runtime.cache import get_cache_manager
+        from triton.runtime.jit import JITFunction
+
+        fn = self.fn
+        while not isinstance(fn, JITFunction):
+            fn = fn.fn
+
+        env_vars = get_cache_invalidating_env_vars()
+        cache_key = [
+            triton_key(),
+            make_backend(driver.active.get_current_target()).hash(),
+            fn.cache_key,
+            str(sorted(env_vars.items())),
+            str(tuning_key),
+        ] + [str(c) for c in configs]
+        cache_key = hashlib.sha256("-".join(cache_key).encode("utf-8")).hexdigest()
+        cache = get_cache_manager(cache_key)
+        file_name = f"{fn.__name__[:150]}.autotune.json"
+        path = cache.get_file(file_name)
+        if path:
+            with open(path, "r") as cached_configs:
+                timings = json.load(cached_configs)["configs_timings"]
+            timings = {Config(**config): timing for config, timing in timings}
+            self.cache[tuning_key] = builtins.min(timings, key=timings.get)
+            self.configs_timings = timings
+            return
+
+        bench_fn()
+        cache.put(
+            json.dumps({
+                "key":
+                tuning_key,
+                "configs_timings":
+                [(config.__dict__, timings) for config, timings in self.configs_timings.items() if not config.pre_hook],
+            }), file_name, binary=False)
+
     def run(self, *args, **kwargs):
         self.nargs = dict(zip(self.arg_names, args))
         used_cached_result = True
@@ -182,17 +216,24 @@ def run(self, *args, **kwargs):
                    key.append(str(arg.dtype))
            key = tuple(key)
            if key not in self.cache:
-                # prune configs
                used_cached_result = False
                pruned_configs = self.prune_configs(kwargs)
-                bench_start = time.time()
-                timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
-                bench_end = time.time()
-                self.bench_time = bench_end - bench_start
-                self.cache[key] = builtins.min(timings, key=timings.get)
-                full_nargs = {**self.nargs, **kwargs, **self.cache[key].all_kwargs()}
-                self.pre_hook(full_nargs, reset_only=True)
-                self.configs_timings = timings
+
+                def benchmark():
+                    bench_start = time.time()
+                    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
+                    bench_end = time.time()
+                    self.bench_time = bench_end - bench_start
+                    self.cache[key] = builtins.min(timings, key=timings.get)
+                    full_nargs = {**self.nargs, **kwargs, **self.cache[key].all_kwargs()}
+                    self.pre_hook(full_nargs, reset_only=True)
+                    self.configs_timings = timings
+
+                if self.cache_results:
+                    self.check_disk_cache(key, pruned_configs, benchmark)
+                else:
+                    benchmark()
+
            config = self.cache[key]
        else:
            config = self.configs[0]
@@ -300,9 +341,23 @@ def __str__(self):
            res.append(f"maxnreg: {self.maxnreg}")
        return ", ".join(res)
 
+    def __hash__(self):
+        return hash((*self.all_kwargs().items(), self.pre_hook))
+
+    def __eq__(self, other):
+        self_tuple = tuple((
+            *self.all_kwargs().items(),
+            self.pre_hook,
+        ))
+        other_tuple = tuple((
+            *other.all_kwargs().items(),
+            other.pre_hook,
+        ))
+        return self_tuple == other_tuple
+
 
 def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, pre_hook=None, post_hook=None,
-             warmup=None, rep=None, use_cuda_graph=False, do_bench=None):
+             warmup=None, rep=None, use_cuda_graph=False, do_bench=None, cache_results=False):
    """
    Decorator for auto-tuning a :code:`triton.jit`'d function.
 
@@ -356,12 +411,14 @@ def kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr):
    :type rep: int
    :param do_bench: a benchmark function to measure the time of each run.
    :type do_bench: lambda fn, quantiles
+    :param cache_results: whether to cache autotune timings to disk. Defaults to False.
+    :type cache_results: bool
    """
 
    def decorator(fn):
        return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook,
                         post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep,
-                        use_cuda_graph=use_cuda_graph, do_bench=do_bench)
+                        use_cuda_graph=use_cuda_graph, do_bench=do_bench, cache_results=cache_results)
 
    return decorator
 
```
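Because the cache entries written by `check_disk_cache` are plain JSON, they can also be inspected outside the autotuner. A short sketch of loading one entry and reporting its fastest config, assuming the `key`/`configs_timings` layout shown in the commit message (the file path is a placeholder passed on the command line):

```
import json
import sys

# Placeholder: pass the path to a cache entry, e.g.
# $TRITON_CACHE_DIR/<hash>/<kernel_name>.autotune.json
path = sys.argv[1]

with open(path) as f:
    entry = json.load(f)

print("tuning key:", entry["key"])

# "configs_timings" is a list of [config_dict, timings] pairs; pick the pair
# with the smallest first timing value, roughly mirroring how the autotuner
# chooses its best config.
best_config, best_timings = min(entry["configs_timings"], key=lambda ct: ct[1][0])
print("fastest config:", best_config["kwargs"],
      f"num_warps={best_config['num_warps']}",
      f"num_stages={best_config['num_stages']}")
print("timings:", best_timings)
```

On a cache hit, the autotuner itself does essentially the same thing: it rebuilds `Config` objects from the stored dicts and takes the minimum over their timings (see `check_disk_cache` above).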