diff --git a/tilelang/language/proxy.py b/tilelang/language/proxy.py index 4f854ba27..6f6ea8857 100644 --- a/tilelang/language/proxy.py +++ b/tilelang/language/proxy.py @@ -8,6 +8,7 @@ from tvm.tir import Var, PrimExpr from tvm.script.ir_builder.tir import buffer, handle, match_buffer from tilelang.utils import deprecated +from tilelang.utils.tensor import torch_dtype_to_str class BufferProxy: @@ -28,6 +29,13 @@ def __call__( buffer_type="", axis_separators=None, ) -> tir.Buffer: + # If dtype is a torch.dtype, convert to short string + try: + import torch # type: ignore + if isinstance(dtype, torch.dtype): # type: ignore[attr-defined] + dtype = torch_dtype_to_str(dtype) + except Exception: + pass return buffer( shape, dtype=dtype, @@ -65,6 +73,12 @@ def from_ptr(self, Returns: A buffer created from the given parameters """ + try: + import torch # type: ignore + if isinstance(dtype, torch.dtype): # type: ignore[attr-defined] + dtype = torch_dtype_to_str(dtype) + except Exception: + pass return match_buffer(pointer_var, shape, dtype=dtype, strides=strides) @@ -96,6 +110,13 @@ def __call__( scope = scope or self.default_scope align = align or self.default_align offset_factor = offset_factor or self.default_offset_factor + # Convert torch.dtype to string if needed + try: + import torch # type: ignore + if isinstance(dtype, torch.dtype): # type: ignore[attr-defined] + dtype = torch_dtype_to_str(dtype) + except Exception: + pass return buffer( shape, @@ -132,6 +153,12 @@ def from_ptr(self, Returns: A buffer created from the given parameters """ + try: + import torch # type: ignore + if isinstance(dtype, torch.dtype): # type: ignore[attr-defined] + dtype = torch_dtype_to_str(dtype) + except Exception: + pass return match_buffer(pointer_var, shape, dtype=dtype, strides=strides) @@ -302,4 +329,10 @@ def make_tensor(ptr: Var, shape: tuple[PrimExpr, ...], dtype: str = "float32", strides: tuple[PrimExpr, ...] = None) -> tir.Buffer: + try: + import torch # type: ignore + if isinstance(dtype, torch.dtype): # type: ignore[attr-defined] + dtype = torch_dtype_to_str(dtype) + except Exception: + pass return Tensor.from_ptr(ptr, shape, dtype, strides) diff --git a/tilelang/utils/tensor.py b/tilelang/utils/tensor.py index 07a34cc44..f8bc5ccaa 100644 --- a/tilelang/utils/tensor.py +++ b/tilelang/utils/tensor.py @@ -38,6 +38,13 @@ def map_torch_type(intype: str) -> torch.dtype: return getattr(torch, intype) +def torch_dtype_to_str(dt) -> str: + + if isinstance(dt, torch.dtype): + return str(dt).split('.')[-1] + return str(dt) + + def adapt_torch2tvm(arg): float8_dtype_map = { torch.float8_e4m3fn: "float8_e4m3", @@ -154,8 +161,7 @@ def _compare_attributes( def raise_mismatch_error(attribute_name: str, actual_value, expected_value): raise AssertionError( - f"The values for attribute '{attribute_name}' do not match: {actual_value} != {expected_value}." - ) + f"The values for attribute '{attribute_name}' do not match: {actual_value} != {expected_value}.") if actual.shape != expected.shape: raise_mismatch_error("shape", actual.shape, expected.shape)