 import torch.nn.functional as F

 from vllm import envs
+from vllm._aiter_ops import rocm_aiter_ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
     dequant_mxfp4,
@@ -49,7 +50,10 @@ def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool:
 
 try:
     from aiter.ops.shuffle import shuffle_weight
-    from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
+    from aiter.ops.triton.gemm_afp4wfp4 import (
+        gemm_afp4wfp4,
+        gemm_afp4wfp4_preshuffled_weight_scales,
+    )
     from aiter.ops.triton.quant import dynamic_mxfp4_quant
 
     from vllm.utils.torch_utils import direct_register_custom_op
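The new `gemm_afp4wfp4_preshuffled_weight_scales` import sits inside the same `try:` block as the other aiter imports, so builds of aiter that lack the kernel still fail over cleanly. A minimal sketch of that guarded-import pattern, assuming an `except ImportError` fallback outside this hunk; the availability flag is hypothetical and only for illustration:

```python
# Sketch of the optional-import guard; HAS_AITER_TRITON_FP4 is illustrative only.
try:
    from aiter.ops.triton.gemm_afp4wfp4 import (
        gemm_afp4wfp4,
        gemm_afp4wfp4_preshuffled_weight_scales,
    )

    HAS_AITER_TRITON_FP4 = True
except ImportError:
    HAS_AITER_TRITON_FP4 = False
```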
@@ -66,23 +70,56 @@ def gemm_with_dynamic_quant(
         x_scales: torch.Tensor | None = None,
     ) -> torch.Tensor:
         M = x.shape[0]
+        N = weight.shape[0]
+        K = weight.shape[1]
         if rocm_use_aiter_fp4_asm_gemm:
-            if x_scales is None:
-                # use hip quant kernel for performance
-                x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
+            if M <= 64 and rocm_aiter_ops.is_triton_gemm_afp4wfp4_presh_ws_tuned(N, K):
+                if x_scales is None:
+                    # use hip quant kernel for performance
+                    if M >= 32:
+                        x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
+                    else:
+                        x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=False)
+                else:
+                    x_q = x
+                    x_s = x_scales
+
+                if M >= 32:
+                    x_s = x_s.view(torch.uint8).view(x_s.shape[0] // 32, -1)
+                else:
+                    x_s = x_s[:M, ...].view(torch.uint8)
+
+                y = torch.empty(M, N, device=x_q.device, dtype=out_dtype)
+                gemm_afp4wfp4_preshuffled_weight_scales(
+                    x_q.view(torch.uint8),
+                    weight.view(torch.uint8).view(weight.shape[0] // 16, -1),
+                    x_s,
+                    weight_scale.view(torch.uint8).view(
+                        weight_scale.shape[0] // 32, -1
+                    ),
+                    out_dtype,
+                    y,
+                )
             else:
-                x_q = x
-                x_s = x_scales
-
-            # 32 alignment is enough for dim0 padding of output for
-            # gemm_a4w4 kernel
-            y = torch.empty(
-                (M + 31) // 32 * 32, weight.shape[0], device=x_q.device, dtype=out_dtype
-            )
+                if x_scales is None:
+                    # use hip quant kernel for performance
+                    x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
+                else:
+                    x_q = x
+                    x_s = x_scales
+
+                # 32 alignment is enough for dim0 padding of output for
+                # gemm_a4w4 kernel
+                y = torch.empty(
+                    (M + 31) // 32 * 32,
+                    weight.shape[0],
+                    device=x_q.device,
+                    dtype=out_dtype,
+                )
 
-            gemm_a4w4(
-                x_q, weight, x_s, weight_scale.view(x_s.dtype), y, bpreshuffle=True
-            )
+                gemm_a4w4(
+                    x_q, weight, x_s, weight_scale.view(x_s.dtype), y, bpreshuffle=True
+                )
             return y[:M]
         else:
             if x_scales is None:
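The new branch routes small batches (`M <= 64`) with a tuned `(N, K)` shape to the Triton `gemm_afp4wfp4_preshuffled_weight_scales` kernel and keeps the asm `gemm_a4w4` path for everything else. As a toy, standalone illustration of the scale and weight folding the small-M path performs, with invented shapes and only the view mechanics mirroring the diff:

```python
import torch

# Toy shapes, chosen only so the folds below divide evenly.
M, K, N = 64, 256, 128
x_s = torch.zeros(M, K // 32, dtype=torch.uint8)  # per-1x32 activation scales
w = torch.zeros(N, K // 2, dtype=torch.uint8)     # packed FP4 weight, 2 values per byte
w_s = torch.zeros(N, K // 32, dtype=torch.uint8)  # per-1x32 weight scales

# Activation scales: fold 32 rows into one (the M >= 32 branch above).
x_s_folded = x_s.view(torch.uint8).view(x_s.shape[0] // 32, -1)
# Packed weight: fold 16 rows into one; weight scales: fold 32 rows into one.
w_folded = w.view(torch.uint8).view(w.shape[0] // 16, -1)
w_s_folded = w_s.view(torch.uint8).view(w_s.shape[0] // 32, -1)

print(x_s_folded.shape, w_folded.shape, w_s_folded.shape)
# torch.Size([2, 256]) torch.Size([8, 2048]) torch.Size([4, 256])
```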