Skip to content

Commit 6693ddd

Browse files
authored
Ignore autotune runs failed with PTXAS error (#5017)
1 parent ef319c8 commit 6693ddd

File tree

3 files changed

+16
-5
lines changed

3 files changed

+16
-5
lines changed

python/triton/runtime/autotuner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import Dict
88

99
from .jit import KernelInterface
10-
from .errors import OutOfResources
10+
from .errors import OutOfResources, PTXASError
1111
from .driver import driver
1212

1313

@@ -157,7 +157,7 @@ def kernel_call():
157157

158158
try:
159159
return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
160-
except (OutOfResources, CompileTimeAssertionFailure):
160+
except (OutOfResources, CompileTimeAssertionFailure, PTXASError):
161161
return [float("inf"), float("inf"), float("inf")]
162162

163163
def run(self, *args, **kwargs):

python/triton/runtime/errors.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,13 @@ def __str__(self) -> str:
2424
def __reduce__(self):
2525
# this is necessary to make CompilationError picklable
2626
return (type(self), (self.required, self.limit, self.name))
27+
28+
29+
class PTXASError(TritonError):
30+
31+
def __init__(self, error_message: Optional[str] = None):
32+
self.error_message = error_message
33+
34+
def __str__(self) -> str:
35+
error_message = self.error_message or ""
36+
return f"PTXAS error: {error_message}"

third_party/nvidia/backend/compiler.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from triton.backends.compiler import BaseBackend, GPUTarget
22
from triton._C.libtriton import ir, passes, llvm, nvidia
3+
from triton.runtime.errors import PTXASError
34

45
from dataclasses import dataclass
56
import functools
@@ -361,9 +362,9 @@ def make_cubin(src, metadata, opt, capability):
361362
else:
362363
error = f'`ptxas` failed with error code {e.returncode}'
363364

364-
raise RuntimeError(f'{error}\n'
365-
f'`ptxas` stderr:\n{log}\n'
366-
f'Repro command: {" ".join(ptxas_cmd)}\n')
365+
raise PTXASError(f"{error}\n"
366+
f"`ptxas` stderr:\n{log}\n"
367+
f'Repro command: {" ".join(ptxas_cmd)}\n')
367368

368369
with open(fbin, 'rb') as f:
369370
cubin = f.read()

0 commit comments

Comments
 (0)