
Commit 62c2c0a

committed
more pre-commit
Signed-off-by: Duncan Moss <[email protected]>
1 parent 97dd50c commit 62c2c0a

File tree

1 file changed: +22 −9 lines changed

  • vllm/model_executor/layers/quantization/mxfp4.py


vllm/model_executor/layers/quantization/mxfp4.py

Lines changed: 22 additions & 9 deletions
@@ -43,8 +43,9 @@ def _should_use_flashinfer_mxfp4_bf16():
             or current_platform.is_device_capability(90) and has_flashinfer()
             and not envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8")):
         logger.info_once(
-            "Enabling FlashInfer MXFP4 BF16 backend by default for Blackwell and Hopper. "
-            "For faster performance, consider setting "
+            "Enabling FlashInfer MXFP4 BF16 backend by "
+            "default for Blackwell and Hopper. "
+            "For faster performance on Blackwell, consider setting "
             "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, "
             "though this may impact accuracy.")
         return True
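
Aside (not part of the commit): adjacent Python string literals are concatenated at parse time, so re-wrapping the log message over shorter lines changes only the source layout, not the logged text; the one wording change is the added "on Blackwell" qualifier. A minimal standalone check:

# Illustration only: the re-wrapped literals concatenate into one string.
message = (
    "Enabling FlashInfer MXFP4 BF16 backend by "
    "default for Blackwell and Hopper. "
    "For faster performance on Blackwell, consider setting "
    "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, "
    "though this may impact accuracy.")
print(message)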
@@ -392,12 +393,24 @@ def swap_every_two_rows(x, axis=-1):
                                   requires_grad=False)
         elif _should_use_flashinfer_mxfp4_bf16(
         ) and current_platform.is_device_capability(90):
-            assert layer.w13_weight.dtype == torch.uint8, f"layer.w13_weight.dtype: {layer.w13_weight.dtype}, expected: {torch.uint8}"
-            assert layer.w2_weight.dtype == torch.uint8, f"layer.w2_weight.dtype: {layer.w2_weight.dtype}, expected: {torch.uint8}"
-            assert layer.w13_weight_scale.dtype == torch.uint8, f"layer.w13_weight_scale.dtype: {layer.w13_weight_scale.dtype}, expected: {torch.uint8}"
-            assert layer.w2_weight_scale.dtype == torch.uint8, f"layer.w2_weight_scale.dtype: {layer.w2_weight_scale.dtype}, expected: {torch.uint8}"
-            assert layer.w13_bias.dtype == torch.bfloat16, f"layer.w13_bias.dtype: {layer.w13_bias.dtype}, expected: {torch.bfloat16}"
-            assert layer.w2_bias.dtype == torch.bfloat16, f"layer.w2_bias.dtype: {layer.w2_bias.dtype}, expected: {torch.bfloat16}"
+            assert layer.w13_weight.dtype == torch.uint8, (
+                f"layer.w13_weight.dtype: {layer.w13_weight.dtype}, "
+                f"expected: {torch.uint8}")
+            assert layer.w2_weight.dtype == torch.uint8, (
+                f"layer.w2_weight.dtype: {layer.w2_weight.dtype}, "
+                f"expected: {torch.uint8}")
+            assert layer.w13_weight_scale.dtype == torch.uint8, (
+                f"layer.w13_weight_scale.dtype: {layer.w13_weight_scale.dtype}, "  # noqa: E501
+                f"expected: {torch.uint8}")
+            assert layer.w2_weight_scale.dtype == torch.uint8, (
+                f"layer.w2_weight_scale.dtype: {layer.w2_weight_scale.dtype}, "  # noqa: E501
+                f"expected: {torch.uint8}")
+            assert layer.w13_bias.dtype == torch.bfloat16, (
+                f"layer.w13_bias.dtype: {layer.w13_bias.dtype}, "
+                f"expected: {torch.bfloat16}")
+            assert layer.w2_bias.dtype == torch.bfloat16, (
+                f"layer.w2_bias.dtype: {layer.w2_bias.dtype}, "
+                f"expected: {torch.bfloat16}")
 
         layer.gemm1_alpha = Parameter(torch.tensor(
             [1.702] * self.num_experts, dtype=torch.float32).cuda(),
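
Aside (not part of the commit): this hunk only reformats the asserts; wrapping the f-string message in parentheses keeps each source line within the length limit while still producing a single failure message. A toy example of the same pattern, using a made-up tensor:

import torch

w = torch.zeros(4, 8, dtype=torch.uint8)  # stand-in for layer.w13_weight
assert w.dtype == torch.uint8, (
    f"w.dtype: {w.dtype}, "
    f"expected: {torch.uint8}")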
@@ -435,7 +448,7 @@ def swap_every_two_rows(x, axis=-1):
                     and layer.w2_bias.shape[0] == self.num_experts
                     and layer.w2_bias.shape[1] == self.hidden_size)
 
-        # De-interleave weights, scales, and biases for gate and up projections
+        # De-interleave weights, scales, and biases
         w13_weight_data = layer.w13_weight.data
         gate_w, up_w = w13_weight_data[:, ::2, :], w13_weight_data[:,
                                                                    1::2, :]
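
Aside (not part of the commit): the shortened comment still describes a gate/up de-interleave; even rows of the fused w13 tensor go to the gate projection and odd rows to the up projection via strided slicing. A small sketch with made-up shapes:

import torch

# (num_experts, 2 * intermediate_size, hidden_size); values are arbitrary
w13 = torch.arange(2 * 8 * 4).reshape(2, 8, 4)
gate_w, up_w = w13[:, ::2, :], w13[:, 1::2, :]  # even rows -> gate, odd -> up
print(gate_w.shape, up_w.shape)  # torch.Size([2, 4, 4]) torch.Size([2, 4, 4])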
