58 changes: 58 additions & 0 deletions python/test/unit/runtime/test_autotuner.py
@@ -1,3 +1,6 @@
import threading
import sys

import torch

import triton
@@ -450,6 +453,61 @@ def grid(meta):
exception_out_of_resource)


def test_nogil_safety(device: str):
if getattr(sys, "_is_gil_enabled", lambda: True)():
pytest.skip("Requires running with the GIL disabled (PYTHON_GIL=0)")
if not is_cuda():
pytest.skip("CUDA backend is required for autotuner thread safety test")
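    # Hammer a single cold cache key from many threads at once; the autotuner
    # must benchmark exactly once and hand every thread the same result.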

num_threads = 8
N = 1024
src = torch.randn(N, device=device)
dst = torch.empty_like(src)

configs = [
triton.Config(kwargs={'BLOCK_SIZE': 32}),
triton.Config(kwargs={'BLOCK_SIZE': 64}),
]

@triton.autotune(configs=configs, key=['N'], do_bench=do_bench)
@triton.jit
def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(0)
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < N
data = tl.load(src + offsets, mask=mask)
tl.store(dst + offsets, data, mask=mask)

grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']), )

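    # Two barriers: the first confirms every worker has started, the second
    # releases them simultaneously to maximize contention on the cache key.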
ready_barrier = threading.Barrier(num_threads + 1)
start_barrier = threading.Barrier(num_threads + 1)
results = []

def worker():
ready_barrier.wait()
start_barrier.wait()
try:
_kernel[grid](dst, src, N)
results.append(None)
except Exception as exc:
results.append(exc)

threads = [threading.Thread(target=worker) for _ in range(num_threads)]

for thread in threads:
thread.start()

ready_barrier.wait()
start_barrier.wait()

for thread in threads:
thread.join()

assert all(result is None for result in results)
    assert len(_kernel._cache) == 1


def test_prune_all_configs(device):
N = 1024
src = torch.randn(N, device=device)
147 changes: 100 additions & 47 deletions python/triton/runtime/autotuner.py
@@ -5,6 +5,7 @@
import inspect
import hashlib
import json
import threading
from functools import cached_property
from typing import Dict, Tuple, List, Optional

@@ -16,6 +17,16 @@
from triton._C.libtriton import get_cache_invalidating_env_vars


class CacheFuture:
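    """Per-key rendezvous for concurrent autotuning.

    The first thread to request an untuned key becomes the runner and
    publishes the winning config (or the error) here; every other thread
    waits on `event` and reads the result instead of re-benchmarking.
    """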

def __init__(self):
self.event = threading.Event()
self.config: Config | None = None
self.error: BaseException | None = None
self.used_cached_result: bool = True
self.bench_time: float | None = None


class Autotuner(KernelInterface):

def __init__(self, fn, arg_names, configs, key, reset_to_zero, restore_value, pre_hook=None, post_hook=None,
@@ -34,7 +45,9 @@ def __init__(self, fn, arg_names, configs, key, reset_to_zero, restore_value, pr
else:
self.configs = configs
self.keys = key
self.cache: Dict[Tuple, Config] = {}
self._cache: Dict[Tuple, Config] = {}
self._cache_lock = threading.RLock()
self._cache_futures: Dict[Tuple, CacheFuture] = {}
self.arg_names = arg_names
self.cache_results = cache_results or (knobs.autotuning.cache and not knobs.runtime.interpret)

@@ -125,7 +138,7 @@ def do_bench(self):
return driver.active.get_benchmarker()
return self._do_bench

def _bench(self, *args, config, **meta):
def _bench(self, nargs, *args, config, **meta):
from ..compiler.errors import CompileTimeAssertionFailure

verbose = knobs.autotuning.print
@@ -140,7 +153,7 @@ def _bench(self, *args, config, **meta):
" Make sure that you don't re-define auto-tuned symbols.")
# augment meta-parameters with tunable ones
current = dict(meta, **config.all_kwargs())
full_nargs = {**self.nargs, **current}
full_nargs = {**nargs, **current}

def kernel_call():
if config.pre_hook:
@@ -170,8 +183,8 @@ def kernel_call():
def check_disk_cache(self, tuning_key, configs, bench_fn):
# We can't serialize prehooks, so just give up and run the benchmarks.
if not tuning_key or any(cfg.pre_hook for cfg in configs):
bench_fn()
return False
configs_timings, bench_time, best_config = bench_fn()
return False, bench_time, configs_timings, best_config

from triton.compiler.compiler import make_backend

@@ -195,72 +208,113 @@ def check_disk_cache(self, tuning_key, configs, bench_fn):
with open(path, "r") as cached_configs:
timings = json.load(cached_configs)["configs_timings"]
timings = {Config(**config): timing for config, timing in timings}
self.cache[tuning_key] = builtins.min(timings, key=timings.get)
self.configs_timings = timings
return True
best_config = builtins.min(timings, key=timings.get)
return True, None, timings, best_config

bench_fn()
configs_timings, bench_time, best_config = bench_fn()
cache.put(
json.dumps({
"key":
tuning_key,
"configs_timings":
[(config.__dict__, timings) for config, timings in self.configs_timings.items() if not config.pre_hook],
"configs_timings": [(config.__dict__, timings)
for config, timings in (configs_timings or {}).items()
if not config.pre_hook],
}), file_name, binary=False)
return False
return False, bench_time, configs_timings, best_config

def _get_config_for_key(self, key, nargs, args, kwargs):
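        # Fast path: a config was already tuned for this key.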
with self._cache_lock:
cached = self._cache.get(key)
if cached is not None:
return cached, True, None

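            # No tuned config yet: become the runner for this key, or join an
            # in-flight benchmark started by another thread.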
future = self._cache_futures.get(key)
if future is None:
future = CacheFuture()
self._cache_futures[key] = future
runner = True
else:
runner = False

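        # Waiters block until the runner publishes; the runner's error, if
        # any, is re-raised in every waiting thread.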
if not runner:
future.event.wait()
if future.error is not None:
raise future.error
return future.config, future.used_cached_result, future.bench_time

pruned_configs = self.prune_configs(kwargs, nargs)

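        # Runner path: benchmark every config that survives pruning and pick
        # the fastest one.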
def benchmark():
bench_start = time.time()
timings = {config: self._bench(nargs, *args, config=config, **kwargs) for config in pruned_configs}
bench_duration = time.time() - bench_start
best_config = builtins.min(timings, key=timings.get)
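            # Reset any mutated arguments (reset_to_zero / restore_value)
            # now that the benchmark sweep is done.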
full_nargs_local = {**nargs, **kwargs, **best_config.all_kwargs()}
self.pre_hook(full_nargs_local, reset_only=True)
return timings, bench_duration, best_config

used_cached_result = False
bench_time = None

try:
if self.cache_results:
used_cached_result, bench_time, configs_timings, best_config = self.check_disk_cache(
key, pruned_configs, benchmark)
else:
configs_timings, bench_time, best_config = benchmark()
used_cached_result = False

if best_config is not None:
with self._cache_lock:
self._cache[key] = best_config

future.config = best_config
future.used_cached_result = used_cached_result
future.bench_time = bench_time
return best_config, used_cached_result, bench_time
except BaseException as exc:
future.error = exc
raise
finally:
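            # Wake all waiters and retire the future; if tuning failed, the
            # next caller for this key retries instead of hanging.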
future.event.set()
with self._cache_lock:
self._cache_futures.pop(key, None)

def run(self, *args, **kwargs):
self.nargs = dict(zip(self.arg_names, args))
nargs = dict(zip(self.arg_names, args))
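        # nargs stays local to this call; the old `self.nargs` attribute let
        # concurrent run() invocations overwrite each other's arguments.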
used_cached_result = True
bench_time = None
key = None
if len(self.configs) > 1:
all_args = {**self.nargs, **kwargs}
all_args = {**nargs, **kwargs}
_args = {k: v for (k, v) in all_args.items() if k in self.arg_names}
key = [_args[key] for key in self.keys if key in _args]
key_values = [_args[key_name] for key_name in self.keys if key_name in _args]
for _, arg in _args.items():
if hasattr(arg, "dtype"):
key.append(str(arg.dtype))
key = tuple(key)
if key not in self.cache:
used_cached_result = False
pruned_configs = self.prune_configs(kwargs)

def benchmark():
bench_start = time.time()
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
bench_end = time.time()
self.bench_time = bench_end - bench_start
self.cache[key] = builtins.min(timings, key=timings.get)
full_nargs = {**self.nargs, **kwargs, **self.cache[key].all_kwargs()}
self.pre_hook(full_nargs, reset_only=True)
self.configs_timings = timings

if self.cache_results:
used_cached_result = self.check_disk_cache(key, pruned_configs, benchmark)
else:
benchmark()

config = self.cache[key]
key_values.append(str(arg.dtype))
key = tuple(key_values)
config, used_cached_result, bench_time = self._get_config_for_key(key, nargs, args, kwargs)
else:
config = self.configs[0]
self.best_config = config
if knobs.autotuning.print and not used_cached_result:
if knobs.autotuning.print and key is not None and not used_cached_result:
bench_time_value = bench_time or 0.0
print(f"Triton autotuning for function {self.base_fn.__name__},\nwith key as {key},\n"
f"finished after {self.bench_time:.2f}s,\nbest config selected: {self.best_config};")
f"finished after {bench_time_value:.2f}s,\nbest config selected: {self.best_config};")
full_nargs = {**nargs, **kwargs, **config.all_kwargs()}
if config.pre_hook is not None:
full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()}
config.pre_hook(full_nargs)
ret = self.fn.run(
*args,
**kwargs,
**config.all_kwargs(),
)
self.nargs = None
return ret

def prune_configs(self, kwargs: Dict) -> List[Config]:
def prune_configs(self, kwargs: Dict, nargs: Dict) -> List[Config]:
pruned_configs = self.configs
if self.early_config_prune:
pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs)
pruned_configs = self.early_config_prune(self.configs, nargs, **kwargs)
if not pruned_configs:
raise AutotunerError(
"No valid autotuner configs after pruning. `early_config_prune` should return at least one config.")
@@ -275,7 +329,7 @@ def prune_configs(self, kwargs: Dict) -> List[Config]:
if len(pruned_configs) > top_k:
est_timing = {
config: self.perf_model(
**self.nargs,
**nargs,
**kwargs,
**config.all_kwargs(),
)
@@ -285,15 +339,14 @@
return pruned_configs

def warmup(self, *args, **kwargs):
self.nargs = dict(zip(self.arg_names, args))
nargs = dict(zip(self.arg_names, args))
ret = []
for autotune_config in self.prune_configs(kwargs):
for autotune_config in self.prune_configs(kwargs, nargs):
ret.append(self.fn.warmup(
*args,
**kwargs,
**autotune_config.all_kwargs(),
))
self.nargs = None
return ret

