@@ -380,6 +380,31 @@ def _rocm_aiter_gemm_a8w8_fake(
     return Y
 
 
+def _rocm_aiter_triton_gemm_a8w8_blockscale_impl(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+    from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
+
+    return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
+
+
+def _rocm_aiter_triton_gemm_a8w8_blockscale_fake(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+    m = A.shape[0]
+    n = B.shape[0]
+    Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
+    return Y
+
+
 def _rocm_aiter_gemm_a8w8_blockscale_impl(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -964,6 +989,12 @@ def register_ops_once() -> None:
         dispatch_key=current_platform.dispatch_key,
     )
 
+    direct_register_custom_op(
+        op_name="rocm_aiter_triton_gemm_a8w8_blockscale",
+        op_func=_rocm_aiter_triton_gemm_a8w8_blockscale_impl,
+        fake_impl=_rocm_aiter_triton_gemm_a8w8_blockscale_fake,
+    )
+
     direct_register_custom_op(
         op_name="rocm_aiter_gemm_a8w8_blockscale",
         op_func=_rocm_aiter_gemm_a8w8_blockscale_impl,
@@ -1102,6 +1133,19 @@ def gemm_a8w8(
     ) -> torch.Tensor:
         return torch.ops.vllm.rocm_aiter_gemm_a8w8(A, B, As, Bs, bias, output_dtype)
 
+    @staticmethod
+    def triton_gemm_a8w8_blockscale(
+        A: torch.Tensor,
+        B: torch.Tensor,
+        As: torch.Tensor,
+        Bs: torch.Tensor,
+        block_size: list[int],
+        output_dtype: torch.dtype = torch.float16,
+    ) -> torch.Tensor:
+        return torch.ops.vllm.rocm_aiter_triton_gemm_a8w8_blockscale(
+            A, B, As, Bs, output_dtype
+        )
+
     @staticmethod
     def gemm_a8w8_blockscale(
         A: torch.Tensor,
@@ -1373,19 +1417,6 @@ def triton_fp8_bmm(
             config=config,
         )
 
-    @staticmethod
-    def triton_gemm_a8w8_blockscale(
-        A: torch.Tensor,
-        B: torch.Tensor,
-        As: torch.Tensor,
-        Bs: torch.Tensor,
-        block_size: list[int],
-        output_dtype: torch.dtype = torch.float16,
-    ) -> torch.Tensor:
-        from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
-
-        return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
-
     @staticmethod
     def group_fp8_quant(
         input_2d: torch.Tensor,
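
For context, a minimal sketch of how the newly registered op could be invoked once this change lands. The fp8 dtype, the 128x128 scale-block layout, and the shapes below are illustrative assumptions, not taken from this diff; the diff itself only fixes the output shape to (A.shape[0], B.shape[0]) via the fake impl.

import torch

# Assumed shapes and scale granularity (128x128 blocks); note that the
# wrapper accepts block_size but does not forward it to the op.
m, n, k = 32, 256, 512
A = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fnuz)  # assumed fp8 dtype
B = torch.randn(n, k, device="cuda").to(torch.float8_e4m3fnuz)
As = torch.rand(m, k // 128, device="cuda", dtype=torch.float32)          # assumed per-block scales
Bs = torch.rand(n // 128, k // 128, device="cuda", dtype=torch.float32)

# Dispatch through the registered custom op rather than importing aiter directly.
Y = torch.ops.vllm.rocm_aiter_triton_gemm_a8w8_blockscale(A, B, As, Bs, torch.bfloat16)
assert Y.shape == (m, n) and Y.dtype == torch.bfloat16

Routing the call through torch.ops.vllm is what the _fake variant enables: during torch.compile graph capture, the fake impl only allocates an (m, n) output in output_dtype, so tracing never has to import aiter or launch the Triton kernel.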