Commit cc7ae5e

[BugFix][AMD][Quantization] Fix torch.compile issue where wvSplitKQ is not called when it should be when using a quantized FP8 model (#22281)
Signed-off-by: Randall Smith <[email protected]>
1 parent 0313cf8 commit cc7ae5e

File tree

1 file changed: +33 -7 lines changed


vllm/model_executor/layers/quantization/utils/w8a8_utils.py

Lines changed: 33 additions & 7 deletions
@@ -13,6 +13,7 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape)
 from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
 
 # Input scaling factors are no longer optional in _scaled_mm starting
 # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
@@ -156,13 +157,10 @@ def cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor,
     return output.view(*output_shape)
 
 
-def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
-                                      weight: torch.Tensor,
-                                      out_dtype: torch.dtype,
-                                      scale_a: torch.Tensor,
-                                      scale_b: torch.Tensor, bias: torch.Tensor,
-                                      input_2d: torch.Tensor,
-                                      output_shape: list) -> torch.Tensor:
+def rocm_per_tensor_w8a8_scaled_mm_impl(
+        qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype,
+        scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor,
+        input_2d: torch.Tensor) -> torch.Tensor:
     from vllm.platforms.rocm import on_mi3xx
     if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx(
     ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
@@ -175,10 +173,38 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
                                   scale_a=scale_a,
                                   scale_b=scale_b,
                                   bias=bias)
+    return output
+
+
+def rocm_per_tensor_w8a8_scaled_mm_fake(
+        qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype,
+        scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor,
+        input_2d: torch.Tensor) -> torch.Tensor:
+    return qinput.new_empty((*qinput.shape[:-1], weight.shape[1]),
+                            dtype=out_dtype)
 
+
+def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
+                                      weight: torch.Tensor,
+                                      out_dtype: torch.dtype,
+                                      scale_a: torch.Tensor,
+                                      scale_b: torch.Tensor, bias: torch.Tensor,
+                                      input_2d: torch.Tensor,
+                                      output_shape: list) -> torch.Tensor:
+    output = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl(
+        qinput, weight, out_dtype, scale_a, scale_b, bias, input_2d)
     return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
 
 
+direct_register_custom_op(
+    op_name="rocm_per_tensor_w8a8_scaled_mm_impl",
+    op_func=rocm_per_tensor_w8a8_scaled_mm_impl,
+    mutates_args=[],
+    fake_impl=rocm_per_tensor_w8a8_scaled_mm_fake,
+    dispatch_key=current_platform.dispatch_key,
+)
+
+
 def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
                                        weight: torch.Tensor,
                                        out_dtype: torch.dtype,
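
The change wraps the ROCm per-tensor FP8 scaled-mm kernel in a registered custom op (rocm_per_tensor_w8a8_scaled_mm_impl) with a matching fake implementation, presumably so that torch.compile treats the op as opaque and the skinny-GEMM (wvSplitKQ) dispatch on qinput.shape[0] == 1 inside the impl is still decided at runtime rather than being traced away. Below is a minimal, self-contained sketch of that registration pattern using plain torch.library (PyTorch 2.4+) instead of vLLM's direct_register_custom_op helper; the namespace mylib, the op name my_scaled_mm, and the toy matmul body are illustrative assumptions, not part of the commit.

import torch


# Real implementation: registered as an opaque custom op, so torch.compile
# does not trace into the body and any kernel selection inside it (e.g. a
# skinny-GEMM path when the batch dimension is 1) happens at execution time.
@torch.library.custom_op("mylib::my_scaled_mm", mutates_args=())
def my_scaled_mm(qinput: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    return qinput.float() @ weight.float()


# Fake (meta) implementation: only describes the output shape and dtype so
# the compiler can build a graph around the op without running the kernel.
@my_scaled_mm.register_fake
def _(qinput: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    return qinput.new_empty((*qinput.shape[:-1], weight.shape[1]),
                            dtype=torch.float32)


x = torch.randn(1, 16)
w = torch.randn(16, 8)
compiled = torch.compile(lambda a, b: torch.ops.mylib.my_scaled_mm(a, b))
print(compiled(x, w).shape)  # torch.Size([1, 8])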
