diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 8768ca028b1a..f9055d86b99d 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -10,7 +10,7 @@ from dataclasses import dataclass import builtins from .. import knobs -from ..runtime.jit import JITCallable +from ..runtime.jit import constexpr_function, JITCallable import inspect from .._C.libtriton import ir @@ -810,6 +810,7 @@ def __init__(self): pi32_t = pointer_type(int32) +@triton.constexpr_function def get_int_dtype(bitwidth: int, signed: bool) -> dtype: if bitwidth == 1: return int1