Commit fcba05c
[Bug] Fix Layer weight_block_size Assertion Issue (#24674)
Signed-off-by: yewentao256 <[email protected]>
1 parent 7a30fa8 commit fcba05c

File tree

  • vllm/model_executor/layers/quantization/fp8.py

1 file changed: +3, -3 lines

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 3 additions & 3 deletions

@@ -450,10 +450,10 @@ def process_weights_after_loading(self, layer: Module) -> None:
             # Activations not quantized for marlin.
             del layer.input_scale
 
-        # On B200, if E8M0 for DeepGemm is used, we need to
+        # On Blackwell or Hopper, if E8M0 for DeepGemm is used, we need to
         # requantize the weight and input to the specific scale
         # at the same time.
-        if is_deep_gemm_e8m0_used():
+        if is_deep_gemm_e8m0_used() and self.block_quant:
             assert layer.weight_block_size is not None
             block_sz = tuple(layer.weight_block_size)
             requant_weight_ue8m0_inplace(
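The added guard matters because layer.weight_block_size is only meaningful for block-quantized FP8 checkpoints; for per-tensor FP8 it is None, so the previously unconditional branch tripped the assertion. Below is a minimal, self-contained sketch of that control flow, not vLLM's actual code: the two helpers are stand-ins for vLLM's DeepGEMM utilities named in the diff, and the SimpleNamespace objects stand in for real layer modules.

from types import SimpleNamespace


def is_deep_gemm_e8m0_used() -> bool:
    # Stand-in: vLLM's helper checks whether DeepGEMM with E8M0 scales is in
    # use (e.g. on Blackwell or Hopper). Hard-coded here for illustration.
    return True


def requant_weight_ue8m0_inplace(weight, weight_scale, block_sz):
    # Stand-in: vLLM's helper re-quantizes block-wise FP8 weights so their
    # scales are UE8M0. The sketch only shows when it would be invoked.
    print(f"requantizing with block size {block_sz}")


def process_weights_after_loading(layer, block_quant: bool) -> None:
    # Without the `block_quant` check, this branch also ran for per-tensor
    # FP8 layers, where layer.weight_block_size is None, and the assertion
    # below fired.
    if is_deep_gemm_e8m0_used() and block_quant:
        assert layer.weight_block_size is not None
        block_sz = tuple(layer.weight_block_size)
        requant_weight_ue8m0_inplace(layer.weight, layer.weight_scale, block_sz)


# Per-tensor FP8 layer: no block size, so the E8M0 branch is skipped entirely.
per_tensor = SimpleNamespace(weight=None, weight_scale=None, weight_block_size=None)
process_weights_after_loading(per_tensor, block_quant=False)

# Block-quantized FP8 layer (e.g. 128x128 blocks): requantization proceeds.
blockwise = SimpleNamespace(weight=None, weight_scale=None, weight_block_size=[128, 128])
process_weights_after_loading(blockwise, block_quant=True)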
@@ -905,7 +905,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
             del layer.w13_input_scale
             del layer.w2_input_scale
 
-        if is_deep_gemm_e8m0_used():
+        if is_deep_gemm_e8m0_used() and self.block_quant:
             assert layer.weight_block_size is not None
             # Re-quantise the expert weights so their scales are UE8M0.
             block_sz = tuple(layer.weight_block_size)
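The second hunk applies the same guard in the MoE path, where the fused expert weights (the w13/w2 tensors referenced in the surrounding code) would otherwise be re-quantized. A short sketch of that variant follows; the helper and the scale attribute names are stand-ins, not vLLM's actual module layout.

from types import SimpleNamespace


def requant_weight_ue8m0_inplace(weight, weight_scale, block_sz):
    # Stand-in for vLLM's UE8M0 re-quantization of block-wise FP8 scales.
    print(f"requantizing expert weights with block size {block_sz}")


def process_moe_weights_after_loading(layer, block_quant: bool, e8m0_used: bool) -> None:
    # The added `block_quant` check keeps per-tensor FP8 MoE checkpoints
    # (weight_block_size is None) away from the assertion below.
    if e8m0_used and block_quant:
        assert layer.weight_block_size is not None
        block_sz = tuple(layer.weight_block_size)
        requant_weight_ue8m0_inplace(layer.w13_weight, layer.w13_weight_scale, block_sz)
        requant_weight_ue8m0_inplace(layer.w2_weight, layer.w2_weight_scale, block_sz)


# Per-tensor FP8 experts: the branch is skipped and no assertion is hit.
moe_layer = SimpleNamespace(w13_weight=None, w13_weight_scale=None,
                            w2_weight=None, w2_weight_scale=None,
                            weight_block_size=None)
process_moe_weights_after_loading(moe_layer, block_quant=False, e8m0_used=True)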

0 commit comments