
Commit dabff12

[Bugfix][ROCm][Dynamo][DS 3.1][FP8] fix unsupported hasattr call when Dynamo tracing for ROCm device (#31149)
Signed-off-by: zejunchen-zejun <[email protected]>
Parent: 3bb9561

1 file changed: +44 −13 lines


vllm/_aiter_ops.py

Lines changed: 44 additions & 13 deletions
@@ -380,6 +380,31 @@ def _rocm_aiter_gemm_a8w8_fake(
     return Y
 
 
+def _rocm_aiter_triton_gemm_a8w8_blockscale_impl(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+    from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
+
+    return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
+
+
+def _rocm_aiter_triton_gemm_a8w8_blockscale_fake(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+    m = A.shape[0]
+    n = B.shape[0]
+    Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
+    return Y
+
+
 def _rocm_aiter_gemm_a8w8_blockscale_impl(
     A: torch.Tensor,
     B: torch.Tensor,
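
Note: the impl/fake pairing added above is PyTorch's standard custom-op pattern. The real implementation imports and calls the aiter Triton kernel; the fake implementation only allocates a correctly shaped output, so Dynamo/torch.compile can trace the op without ever touching the aiter module. Below is a minimal, generic sketch of that pattern using torch.library directly; the "demo" names are illustrative and this is not vLLM's direct_register_custom_op helper.

import torch

@torch.library.custom_op("demo::triton_gemm_a8w8_blockscale", mutates_args=())
def demo_triton_gemm_a8w8_blockscale(
    A: torch.Tensor, B: torch.Tensor, As: torch.Tensor, Bs: torch.Tensor
) -> torch.Tensor:
    # The real op would launch the Triton kernel here; this placeholder only
    # mimics the output shape so the sketch stays self-contained.
    return torch.empty(A.shape[0], B.shape[0], dtype=torch.float16, device=A.device)

@demo_triton_gemm_a8w8_blockscale.register_fake
def _(A, B, As, Bs):
    # Shape/dtype-only version used during tracing; no kernel is launched.
    return A.new_empty((A.shape[0], B.shape[0]), dtype=torch.float16)
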
@@ -964,6 +989,12 @@ def register_ops_once() -> None:
         dispatch_key=current_platform.dispatch_key,
     )
 
+    direct_register_custom_op(
+        op_name="rocm_aiter_triton_gemm_a8w8_blockscale",
+        op_func=_rocm_aiter_triton_gemm_a8w8_blockscale_impl,
+        fake_impl=_rocm_aiter_triton_gemm_a8w8_blockscale_fake,
+    )
+
     direct_register_custom_op(
         op_name="rocm_aiter_gemm_a8w8_blockscale",
         op_func=_rocm_aiter_gemm_a8w8_blockscale_impl,
@@ -1102,6 +1133,19 @@ def gemm_a8w8(
     ) -> torch.Tensor:
         return torch.ops.vllm.rocm_aiter_gemm_a8w8(A, B, As, Bs, bias, output_dtype)
 
+    @staticmethod
+    def triton_gemm_a8w8_blockscale(
+        A: torch.Tensor,
+        B: torch.Tensor,
+        As: torch.Tensor,
+        Bs: torch.Tensor,
+        block_size: list[int],
+        output_dtype: torch.dtype = torch.float16,
+    ) -> torch.Tensor:
+        return torch.ops.vllm.rocm_aiter_triton_gemm_a8w8_blockscale(
+            A, B, As, Bs, output_dtype
+        )
+
     @staticmethod
     def gemm_a8w8_blockscale(
         A: torch.Tensor,
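
With the static wrapper above, call sites go through the dispatcher (torch.ops.vllm...) rather than importing aiter directly, so a compiled graph sees a single opaque op. A hedged usage sketch follows; the compile wrapper and argument choices are illustrative, and the call only works on a ROCm build of vLLM with aiter installed, since otherwise the op is never registered.

import torch

def fp8_blockscale_matmul(A, B, As, Bs):
    # Dispatches to the registered custom op; under Dynamo tracing only the
    # fake (shape-only) implementation runs.
    return torch.ops.vllm.rocm_aiter_triton_gemm_a8w8_blockscale(
        A, B, As, Bs, torch.bfloat16
    )

compiled = torch.compile(fp8_blockscale_matmul)
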
@@ -1373,19 +1417,6 @@ def triton_fp8_bmm(
             config=config,
         )
 
-    @staticmethod
-    def triton_gemm_a8w8_blockscale(
-        A: torch.Tensor,
-        B: torch.Tensor,
-        As: torch.Tensor,
-        Bs: torch.Tensor,
-        block_size: list[int],
-        output_dtype: torch.dtype = torch.float16,
-    ) -> torch.Tensor:
-        from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
-
-        return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
-
     @staticmethod
     def group_fp8_quant(
         input_2d: torch.Tensor,
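
For reference, here is a plain-PyTorch sketch of what an a8w8 blockscale GEMM computes, under an assumed DeepSeek-style layout (per-row 1x128 activation scale groups in As, 128x128 weight scale blocks in Bs). The actual aiter kernel's scale layout and numerics may differ, so treat this only as a reading aid.

import torch

def ref_gemm_a8w8_blockscale(
    A: torch.Tensor,   # (M, K) quantized activations
    B: torch.Tensor,   # (N, K) quantized weights
    As: torch.Tensor,  # (M, ceil(K / blk)) activation scales (assumed layout)
    Bs: torch.Tensor,  # (ceil(N / blk), ceil(K / blk)) weight scales (assumed layout)
    blk: int = 128,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    M, K = A.shape
    N = B.shape[0]
    # Broadcast the per-block scales back to element granularity, dequantize,
    # then run a regular matmul in float32.
    a = A.to(torch.float32) * As.repeat_interleave(blk, dim=1)[:, :K]
    b = B.to(torch.float32) * (
        Bs.repeat_interleave(blk, dim=0)[:N].repeat_interleave(blk, dim=1)[:, :K]
    )
    return (a @ b.T).to(output_dtype)
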
