
Commit b7f1f49

Authored by Aleksandr Malyshev (maleksan85) and co-author
Upstream triton fp4 weight preshuffle (#28888)
Signed-off-by: Aleksandr Malyshev <[email protected]>
Co-authored-by: Aleksandr Malyshev <[email protected]>
1 parent: 30b44a1

2 files changed: +77, -15 lines

vllm/_aiter_ops.py

Lines changed: 25 additions & 0 deletions
@@ -948,6 +948,31 @@ def is_triton_gemm_w8a8_tuned(n: int, k: int) -> bool:
             (8192, 32768),
         ]
 
+    @staticmethod
+    def is_triton_gemm_afp4wfp4_presh_ws_tuned(n: int, k: int) -> bool:
+        return (n, k) in [
+            (8192, 4096),
+            (1280, 8192),
+            (16384, 53248),
+            (106496, 16384),
+            (57344, 8192),
+            (8192, 2048),
+            (2560, 8192),
+            (10240, 8192),
+            (16384, 16384),
+            (8192, 28672),
+            (28672, 8192),
+            (18432, 16384),
+            (8192, 1024),
+            (7168, 8192),
+            (5120, 8192),
+            (8192, 8192),
+            (8192, 7168),
+            (14336, 8192),
+            (8192, 14336),
+            (8192, 3584),
+        ]
+
     @staticmethod
     def shuffle_weight(
         self, tensor: torch.Tensor, layout: tuple[int, int] = (16, 16)
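
In this code the activation x has shape (M, K) and the weight has shape (N, K), so the new check keys on the weight's output and reduction dimensions. A minimal sketch of how a caller can gate kernel selection on this list, using the M <= 64 small-batch cutoff from the dispatch added in quark_ocp_mx.py below; the helper name pick_fp4_gemm_path is illustrative and not part of the commit:

import torch

from vllm._aiter_ops import rocm_aiter_ops


def pick_fp4_gemm_path(x: torch.Tensor, weight: torch.Tensor) -> str:
    # Illustrative helper, not part of this commit: mirrors the new dispatch.
    M = x.shape[0]                          # tokens in the current batch
    N, K = weight.shape[0], weight.shape[1]
    if M <= 64 and rocm_aiter_ops.is_triton_gemm_afp4wfp4_presh_ws_tuned(N, K):
        return "triton_gemm_afp4wfp4_preshuffled_weight_scales"
    return "asm_gemm_a4w4"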

vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py

Lines changed: 52 additions & 15 deletions
@@ -10,6 +10,7 @@
 import torch.nn.functional as F
 
 from vllm import envs
+from vllm._aiter_ops import rocm_aiter_ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
     dequant_mxfp4,
@@ -49,7 +50,10 @@ def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool:
 
 try:
     from aiter.ops.shuffle import shuffle_weight
-    from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
+    from aiter.ops.triton.gemm_afp4wfp4 import (
+        gemm_afp4wfp4,
+        gemm_afp4wfp4_preshuffled_weight_scales,
+    )
     from aiter.ops.triton.quant import dynamic_mxfp4_quant
 
     from vllm.utils.torch_utils import direct_register_custom_op
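
The new kernel import lands inside the module's existing try block (visible in the context lines above), so environments whose aiter build lacks the preshuffled kernel can fall back gracefully. A minimal sketch of that optional-import pattern; the flag name and fallback values below are assumptions for illustration, not vLLM's actual handling:

# Sketch only: guarded optional import; flag name and fallbacks are assumed.
try:
    from aiter.ops.triton.gemm_afp4wfp4 import (
        gemm_afp4wfp4,
        gemm_afp4wfp4_preshuffled_weight_scales,
    )
    HAS_AITER_TRITON_FP4 = True
except ImportError:
    gemm_afp4wfp4 = gemm_afp4wfp4_preshuffled_weight_scales = None
    HAS_AITER_TRITON_FP4 = False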
@@ -66,23 +70,56 @@ def gemm_with_dynamic_quant(
     x_scales: torch.Tensor | None = None,
 ) -> torch.Tensor:
     M = x.shape[0]
+    N = weight.shape[0]
+    K = weight.shape[1]
     if rocm_use_aiter_fp4_asm_gemm:
-        if x_scales is None:
-            # use hip quant kernel for performance
-            x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
+        if M <= 64 and rocm_aiter_ops.is_triton_gemm_afp4wfp4_presh_ws_tuned(N, K):
+            if x_scales is None:
+                # use hip quant kernel for performance
+                if M >= 32:
+                    x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
+                else:
+                    x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=False)
+            else:
+                x_q = x
+                x_s = x_scales
+
+            if M >= 32:
+                x_s = x_s.view(torch.uint8).view(x_s.shape[0] // 32, -1)
+            else:
+                x_s = x_s[:M, ...].view(torch.uint8)
+
+            y = torch.empty(M, N, device=x_q.device, dtype=out_dtype)
+            gemm_afp4wfp4_preshuffled_weight_scales(
+                x_q.view(torch.uint8),
+                weight.view(torch.uint8).view(weight.shape[0] // 16, -1),
+                x_s,
+                weight_scale.view(torch.uint8).view(
+                    weight_scale.shape[0] // 32, -1
+                ),
+                out_dtype,
+                y,
+            )
         else:
-            x_q = x
-            x_s = x_scales
-
-        # 32 alignment is enough for dim0 padding of output for
-        # gemm_a4w4 kernel
-        y = torch.empty(
-            (M + 31) // 32 * 32, weight.shape[0], device=x_q.device, dtype=out_dtype
-        )
+            if x_scales is None:
+                # use hip quant kernel for performance
+                x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
+            else:
+                x_q = x
+                x_s = x_scales
+
+            # 32 alignment is enough for dim0 padding of output for
+            # gemm_a4w4 kernel
+            y = torch.empty(
+                (M + 31) // 32 * 32,
+                weight.shape[0],
+                device=x_q.device,
+                dtype=out_dtype,
+            )
 
-        gemm_a4w4(
-            x_q, weight, x_s, weight_scale.view(x_s.dtype), y, bpreshuffle=True
-        )
+            gemm_a4w4(
+                x_q, weight, x_s, weight_scale.view(x_s.dtype), y, bpreshuffle=True
+            )
         return y[:M]
     else:
         if x_scales is None:
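
Before launching the Triton kernel, the preshuffled path re-views the packed buffers: activation scales are grouped 32 rows at a time, the preshuffled weight 16 rows at a time, and the weight scales 32 rows at a time. A shape-only sketch of those re-views, assuming the usual MXFP4 packing of two fp4 values per byte and one scale per 32 elements (the dummy zero tensors stand in for the real quantized buffers):

import torch

M, N, K = 32, 8192, 4096  # (N, K) is one of the tuned shapes added above

# Assumed MXFP4 layout for illustration: 2 fp4 values per byte, 1 scale per 32 values.
x_s = torch.zeros(M, K // 32, dtype=torch.uint8)            # activation scales
weight = torch.zeros(N, K // 2, dtype=torch.uint8)          # packed, preshuffled weight
weight_scale = torch.zeros(N, K // 32, dtype=torch.uint8)   # weight scales

# Re-views performed by the new branch (M >= 32 case):
x_s_v = x_s.view(torch.uint8).view(x_s.shape[0] // 32, -1)
w_v = weight.view(torch.uint8).view(weight.shape[0] // 16, -1)
ws_v = weight_scale.view(torch.uint8).view(weight_scale.shape[0] // 32, -1)

print(x_s_v.shape, w_v.shape, ws_v.shape)
# torch.Size([1, 4096]) torch.Size([512, 32768]) torch.Size([256, 4096])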
