Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions tilelang/language/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from tvm.tir import Var, PrimExpr
from tvm.script.ir_builder.tir import buffer, handle, match_buffer
from tilelang.utils import deprecated
from tilelang.utils.tensor import torch_dtype_to_str


class BufferProxy:
Expand All @@ -28,6 +29,13 @@ def __call__(
buffer_type="",
axis_separators=None,
) -> tir.Buffer:
# If dtype is a torch.dtype, convert to short string
try:
import torch # type: ignore
if isinstance(dtype, torch.dtype): # type: ignore[attr-defined]
dtype = torch_dtype_to_str(dtype)
except Exception:
pass
return buffer(
shape,
dtype=dtype,
Expand Down Expand Up @@ -65,6 +73,12 @@ def from_ptr(self,
Returns:
A buffer created from the given parameters
"""
try:
import torch # type: ignore
if isinstance(dtype, torch.dtype): # type: ignore[attr-defined]
dtype = torch_dtype_to_str(dtype)
except Exception:
pass
return match_buffer(pointer_var, shape, dtype=dtype, strides=strides)


Expand Down Expand Up @@ -96,6 +110,13 @@ def __call__(
scope = scope or self.default_scope
align = align or self.default_align
offset_factor = offset_factor or self.default_offset_factor
# Convert torch.dtype to string if needed
try:
import torch # type: ignore
if isinstance(dtype, torch.dtype): # type: ignore[attr-defined]
dtype = torch_dtype_to_str(dtype)
except Exception:
pass

return buffer(
shape,
Expand Down Expand Up @@ -132,6 +153,12 @@ def from_ptr(self,
Returns:
A buffer created from the given parameters
"""
try:
import torch # type: ignore
if isinstance(dtype, torch.dtype): # type: ignore[attr-defined]
dtype = torch_dtype_to_str(dtype)
except Exception:
pass
return match_buffer(pointer_var, shape, dtype=dtype, strides=strides)


Expand Down Expand Up @@ -302,4 +329,10 @@ def make_tensor(ptr: Var,
shape: tuple[PrimExpr, ...],
dtype: str = "float32",
strides: tuple[PrimExpr, ...] = None) -> tir.Buffer:
try:
import torch # type: ignore
if isinstance(dtype, torch.dtype): # type: ignore[attr-defined]
dtype = torch_dtype_to_str(dtype)
except Exception:
pass
return Tensor.from_ptr(ptr, shape, dtype, strides)
10 changes: 8 additions & 2 deletions tilelang/utils/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ def map_torch_type(intype: str) -> torch.dtype:
return getattr(torch, intype)


def torch_dtype_to_str(dt) -> str:

if isinstance(dt, torch.dtype):
return str(dt).split('.')[-1]
return str(dt)


def adapt_torch2tvm(arg):
float8_dtype_map = {
torch.float8_e4m3fn: "float8_e4m3",
Expand Down Expand Up @@ -154,8 +161,7 @@ def _compare_attributes(

def raise_mismatch_error(attribute_name: str, actual_value, expected_value):
raise AssertionError(
f"The values for attribute '{attribute_name}' do not match: {actual_value} != {expected_value}."
)
f"The values for attribute '{attribute_name}' do not match: {actual_value} != {expected_value}.")

if actual.shape != expected.shape:
raise_mismatch_error("shape", actual.shape, expected.shape)
Expand Down