@@ -43,8 +43,9 @@ def _should_use_flashinfer_mxfp4_bf16():
             or current_platform.is_device_capability(90) and has_flashinfer()
             and not envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8")):
         logger.info_once(
-            "Enabling FlashInfer MXFP4 BF16 backend by default for Blackwell and Hopper. "
-            "For faster performance, consider setting "
+            "Enabling FlashInfer MXFP4 BF16 backend by "
+            "default for Blackwell and Hopper. "
+            "For faster performance on Blackwell, consider setting "
             "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, "
             "though this may impact accuracy.")
         return True
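As a side note, the opt-in the new log text recommends can be exercised by exporting the flag before the engine is constructed; a minimal sketch, assuming the variable is read from the environment at initialization time (model name is illustrative only):

import os

# Opt into the MXFP8 activation path on Blackwell, as the updated log
# message suggests; per the same message this may impact accuracy.
os.environ["VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8"] = "1"

from vllm import LLM  # set the variable before constructing the engine

llm = LLM(model="openai/gpt-oss-20b")  # illustrative MXFP4-quantized model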
@@ -392,12 +393,24 @@ def swap_every_two_rows(x, axis=-1):
                                                requires_grad=False)
         elif _should_use_flashinfer_mxfp4_bf16(
         ) and current_platform.is_device_capability(90):
-            assert layer.w13_weight.dtype == torch.uint8, f"layer.w13_weight.dtype: {layer.w13_weight.dtype}, expected: {torch.uint8}"
-            assert layer.w2_weight.dtype == torch.uint8, f"layer.w2_weight.dtype: {layer.w2_weight.dtype}, expected: {torch.uint8}"
-            assert layer.w13_weight_scale.dtype == torch.uint8, f"layer.w13_weight_scale.dtype: {layer.w13_weight_scale.dtype}, expected: {torch.uint8}"
-            assert layer.w2_weight_scale.dtype == torch.uint8, f"layer.w2_weight_scale.dtype: {layer.w2_weight_scale.dtype}, expected: {torch.uint8}"
-            assert layer.w13_bias.dtype == torch.bfloat16, f"layer.w13_bias.dtype: {layer.w13_bias.dtype}, expected: {torch.bfloat16}"
-            assert layer.w2_bias.dtype == torch.bfloat16, f"layer.w2_bias.dtype: {layer.w2_bias.dtype}, expected: {torch.bfloat16}"
+            assert layer.w13_weight.dtype == torch.uint8, (
+                f"layer.w13_weight.dtype: {layer.w13_weight.dtype}, "
+                f"expected: {torch.uint8}")
+            assert layer.w2_weight.dtype == torch.uint8, (
+                f"layer.w2_weight.dtype: {layer.w2_weight.dtype}, "
+                f"expected: {torch.uint8}")
+            assert layer.w13_weight_scale.dtype == torch.uint8, (
+                f"layer.w13_weight_scale.dtype: {layer.w13_weight_scale.dtype}, "  # noqa: E501
+                f"expected: {torch.uint8}")
+            assert layer.w2_weight_scale.dtype == torch.uint8, (
+                f"layer.w2_weight_scale.dtype: {layer.w2_weight_scale.dtype}, "  # noqa: E501
+                f"expected: {torch.uint8}")
+            assert layer.w13_bias.dtype == torch.bfloat16, (
+                f"layer.w13_bias.dtype: {layer.w13_bias.dtype}, "
+                f"expected: {torch.bfloat16}")
+            assert layer.w2_bias.dtype == torch.bfloat16, (
+                f"layer.w2_bias.dtype: {layer.w2_bias.dtype}, "
+                f"expected: {torch.bfloat16}")
 
             layer.gemm1_alpha = Parameter(torch.tensor(
                 [1.702] * self.num_experts, dtype=torch.float32).cuda(),
@@ -435,7 +448,7 @@ def swap_every_two_rows(x, axis=-1):
                     and layer.w2_bias.shape[0] == self.num_experts
                     and layer.w2_bias.shape[1] == self.hidden_size)
 
-            # De-interleave weights, scales, and biases for gate and up projections
+            # De-interleave weights, scales, and biases
             w13_weight_data = layer.w13_weight.data
             gate_w, up_w = w13_weight_data[:, ::2, :], w13_weight_data[:,
                                                                        1::2, :]
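The de-interleave step above relies on even/odd striding along the row dimension: even rows hold the gate projection and odd rows the up projection, matching how the diff names the slices. A minimal standalone sketch of the same pattern, with made-up shapes:

import torch

# Toy tensor standing in for w13: (num_experts, 2 * intermediate_size, hidden)
# with gate and up projection rows interleaved along dim 1.
num_experts, intermediate_size, hidden_size = 2, 4, 8
w13 = torch.arange(
    num_experts * 2 * intermediate_size * hidden_size,
    dtype=torch.float32).reshape(num_experts, 2 * intermediate_size,
                                 hidden_size)

# Even rows -> gate projection, odd rows -> up projection.
gate_w, up_w = w13[:, ::2, :], w13[:, 1::2, :]
assert gate_w.shape == (num_experts, intermediate_size, hidden_size)
assert up_w.shape == (num_experts, intermediate_size, hidden_size)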