diff --git a/python/test/unit/runtime/test_autotuner.py b/python/test/unit/runtime/test_autotuner.py
index d9b972d6bfd2..2d237baf11c4 100644
--- a/python/test/unit/runtime/test_autotuner.py
+++ b/python/test/unit/runtime/test_autotuner.py
@@ -1,3 +1,6 @@
+import threading
+import sys
+
 import torch
 
 import triton
@@ -450,6 +453,61 @@ def grid(meta):
                           exception_out_of_resource)
 
+
+def test_nogil_safety(device: str):
+    if getattr(sys, "_is_gil_enabled", lambda: True)():
+        pytest.skip("Requires running with the GIL disabled (PYTHON_GIL=0)")
+    if not is_cuda():
+        pytest.skip("CUDA backend is required for autotuner thread safety test")
+
+    num_threads = 8
+    N = 1024
+    src = torch.randn(N, device=device)
+    dst = torch.empty_like(src)
+
+    configs = [
+        triton.Config(kwargs={'BLOCK_SIZE': 32}),
+        triton.Config(kwargs={'BLOCK_SIZE': 64}),
+    ]
+
+    @triton.autotune(configs=configs, key=['N'], do_bench=do_bench)
+    @triton.jit
+    def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
+        pid = tl.program_id(0)
+        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < N
+        data = tl.load(src + offsets, mask=mask)
+        tl.store(dst + offsets, data, mask=mask)
+
+    grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']), )
+
+    ready_barrier = threading.Barrier(num_threads + 1)
+    start_barrier = threading.Barrier(num_threads + 1)
+    results = []
+
+    def worker():
+        ready_barrier.wait()
+        start_barrier.wait()
+        try:
+            _kernel[grid](dst, src, N)
+            results.append(None)
+        except Exception as exc:
+            results.append(exc)
+
+    threads = [threading.Thread(target=worker) for _ in range(num_threads)]
+
+    for thread in threads:
+        thread.start()
+
+    ready_barrier.wait()
+    start_barrier.wait()
+
+    for thread in threads:
+        thread.join()
+
+    assert all(result is None for result in results)
+    assert len(_kernel._cache) == 1
+
+
 def test_prune_all_configs(device):
     N = 1024
     src = torch.randn(N, device=device)
     dst = torch.empty_like(src)
diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py
index e12509f4f263..302943431615 100644
--- a/python/triton/runtime/autotuner.py
+++ b/python/triton/runtime/autotuner.py
@@ -5,6 +5,7 @@
 import inspect
 import hashlib
 import json
+import threading
 from functools import cached_property
 from typing import Dict, Tuple, List, Optional
 
@@ -16,6 +17,16 @@
 from triton._C.libtriton import get_cache_invalidating_env_vars
 
 
+class CacheFuture:
+
+    def __init__(self):
+        self.event = threading.Event()
+        self.config: Config | None = None
+        self.error: BaseException | None = None
+        self.used_cached_result: bool = True
+        self.bench_time: float | None = None
+
+
 class Autotuner(KernelInterface):
 
     def __init__(self, fn, arg_names, configs, key, reset_to_zero, restore_value, pre_hook=None, post_hook=None,
@@ -34,7 +45,9 @@ def __init__(self, fn, arg_names, configs, key, reset_to_zero, restore_value, pr
         else:
             self.configs = configs
         self.keys = key
-        self.cache: Dict[Tuple, Config] = {}
+        self._cache: Dict[Tuple, Config] = {}
+        self._cache_lock = threading.RLock()
+        self._cache_futures: Dict[Tuple, CacheFuture] = {}
         self.arg_names = arg_names
         self.cache_results = cache_results or (knobs.autotuning.cache and not knobs.runtime.interpret)
@@ -125,7 +138,7 @@ def do_bench(self):
             return driver.active.get_benchmarker()
         return self._do_bench
 
-    def _bench(self, *args, config, **meta):
+    def _bench(self, nargs, *args, config, **meta):
         from ..compiler.errors import CompileTimeAssertionFailure
 
         verbose = knobs.autotuning.print
@@ -140,7 +153,7 @@ def _bench(self, *args, config, **meta):
                              " Make sure that you don't re-define auto-tuned symbols.")
         # augment meta-parameters with tunable ones
         current = dict(meta, **config.all_kwargs())
-        full_nargs = {**self.nargs, **current}
+        full_nargs = {**nargs, **current}
 
         def kernel_call():
             if config.pre_hook:
@@ -170,8 +183,8 @@ def kernel_call():
 
     def check_disk_cache(self, tuning_key, configs, bench_fn):
         # We can't serialize prehooks, so just give up and run the benchmarks.
         if not tuning_key or any(cfg.pre_hook for cfg in configs):
-            bench_fn()
-            return False
+            configs_timings, bench_time, best_config = bench_fn()
+            return False, bench_time, configs_timings, best_config
 
         from triton.compiler.compiler import make_backend
@@ -195,72 +208,113 @@ def check_disk_cache(self, tuning_key, configs, bench_fn):
             with open(path, "r") as cached_configs:
                 timings = json.load(cached_configs)["configs_timings"]
                 timings = {Config(**config): timing for config, timing in timings}
-                self.cache[tuning_key] = builtins.min(timings, key=timings.get)
-                self.configs_timings = timings
-                return True
+                best_config = builtins.min(timings, key=timings.get)
+                return True, None, timings, best_config
 
-        bench_fn()
+        configs_timings, bench_time, best_config = bench_fn()
         cache.put(
             json.dumps({
                 "key": tuning_key,
-                "configs_timings":
-                [(config.__dict__, timings) for config, timings in self.configs_timings.items() if not config.pre_hook],
+                "configs_timings": [(config.__dict__, timings)
+                                    for config, timings in (configs_timings or {}).items()
+                                    if not config.pre_hook],
             }), file_name, binary=False)
-        return False
+        return False, bench_time, configs_timings, best_config
+
+    def _get_config_for_key(self, key, nargs, args, kwargs):
+        with self._cache_lock:
+            cached = self._cache.get(key)
+            if cached is not None:
+                return cached, True, None
+
+            future = self._cache_futures.get(key)
+            if future is None:
+                future = CacheFuture()
+                self._cache_futures[key] = future
+                runner = True
+            else:
+                runner = False
+
+        if not runner:
+            future.event.wait()
+            if future.error is not None:
+                raise future.error
+            return future.config, future.used_cached_result, future.bench_time
+
+        pruned_configs = self.prune_configs(kwargs, nargs)
+
+        def benchmark():
+            bench_start = time.time()
+            timings = {config: self._bench(nargs, *args, config=config, **kwargs) for config in pruned_configs}
+            bench_duration = time.time() - bench_start
+            best_config = builtins.min(timings, key=timings.get)
+            full_nargs_local = {**nargs, **kwargs, **best_config.all_kwargs()}
+            self.pre_hook(full_nargs_local, reset_only=True)
+            return timings, bench_duration, best_config
+
+        used_cached_result = False
+        bench_time = None
+
+        try:
+            if self.cache_results:
+                used_cached_result, bench_time, configs_timings, best_config = self.check_disk_cache(
+                    key, pruned_configs, benchmark)
+            else:
+                configs_timings, bench_time, best_config = benchmark()
+                used_cached_result = False
+
+            if best_config is not None:
+                with self._cache_lock:
+                    self._cache[key] = best_config
+
+            future.config = best_config
+            future.used_cached_result = used_cached_result
+            future.bench_time = bench_time
+            return best_config, used_cached_result, bench_time
+        except BaseException as exc:
+            future.error = exc
+            raise
+        finally:
+            future.event.set()
+            with self._cache_lock:
+                self._cache_futures.pop(key, None)
 
     def run(self, *args, **kwargs):
-        self.nargs = dict(zip(self.arg_names, args))
+        nargs = dict(zip(self.arg_names, args))
         used_cached_result = True
+        bench_time = None
+        key = None
         if len(self.configs) > 1:
-            all_args = {**self.nargs, **kwargs}
+            all_args = {**nargs, **kwargs}
             _args = {k: v for (k, v) in all_args.items() if k in self.arg_names}
-            key = [_args[key] for key in self.keys if key in _args]
+            key_values = [_args[key_name] for key_name in self.keys if key_name in _args]
             for _, arg in _args.items():
                 if hasattr(arg, "dtype"):
-                    key.append(str(arg.dtype))
-            key = tuple(key)
-            if key not in self.cache:
-                used_cached_result = False
-                pruned_configs = self.prune_configs(kwargs)
-
-                def benchmark():
-                    bench_start = time.time()
-                    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
-                    bench_end = time.time()
-                    self.bench_time = bench_end - bench_start
-                    self.cache[key] = builtins.min(timings, key=timings.get)
-                    full_nargs = {**self.nargs, **kwargs, **self.cache[key].all_kwargs()}
-                    self.pre_hook(full_nargs, reset_only=True)
-                    self.configs_timings = timings
-
-                if self.cache_results:
-                    used_cached_result = self.check_disk_cache(key, pruned_configs, benchmark)
-                else:
-                    benchmark()
-
-            config = self.cache[key]
+                    key_values.append(str(arg.dtype))
+            key = tuple(key_values)
+            config, used_cached_result, bench_time = self._get_config_for_key(key, nargs, args, kwargs)
         else:
             config = self.configs[0]
         self.best_config = config
-        if knobs.autotuning.print and not used_cached_result:
+        if knobs.autotuning.print and key is not None and not used_cached_result:
+            bench_time_value = bench_time or 0.0
             print(f"Triton autotuning for function {self.base_fn.__name__},\nwith key as {key},\n"
-                  f"finished after {self.bench_time:.2f}s,\nbest config selected: {self.best_config};")
+                  f"finished after {bench_time_value:.2f}s,\nbest config selected: {self.best_config};")
         if config.pre_hook is not None:
-            full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()}
+            full_nargs = {**nargs, **kwargs, **config.all_kwargs()}
             config.pre_hook(full_nargs)
         ret = self.fn.run(
             *args,
             **kwargs,
             **config.all_kwargs(),
         )
-        self.nargs = None
         return ret
 
-    def prune_configs(self, kwargs: Dict) -> List[Config]:
+    def prune_configs(self, kwargs: Dict, nargs: Dict) -> List[Config]:
         pruned_configs = self.configs
         if self.early_config_prune:
-            pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs)
+            pruned_configs = self.early_config_prune(self.configs, nargs, **kwargs)
         if not pruned_configs:
             raise AutotunerError(
                 "No valid autotuner configs after pruning. `early_config_prune` should return at least one config.")
@@ -275,7 +329,7 @@ def prune_configs(self, kwargs: Dict) -> List[Config]:
             if len(pruned_configs) > top_k:
                 est_timing = {
                     config: self.perf_model(
-                        **self.nargs,
+                        **nargs,
                         **kwargs,
                         **config.all_kwargs(),
                     )
@@ -285,15 +339,14 @@ def prune_configs(self, kwargs: Dict) -> List[Config]:
         return pruned_configs
 
     def warmup(self, *args, **kwargs):
-        self.nargs = dict(zip(self.arg_names, args))
+        nargs = dict(zip(self.arg_names, args))
         ret = []
-        for autotune_config in self.prune_configs(kwargs):
+        for autotune_config in self.prune_configs(kwargs, nargs):
             ret.append(self.fn.warmup(
                 *args,
                 **kwargs,
                 **autotune_config.all_kwargs(),
             ))
-        self.nargs = None
         return ret