diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 6ae50245ed3a..8495d2c4b759 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -465,7 +465,7 @@ def call_trtllm_fused_allreduce_norm( quant_out=quant_out, scale_out=scale_out, # in vllm we only support swizzled layout - layout_code=flashinfer_comm.FP4QuantizationSFLayout.SWIZZLED, + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED, scale_factor=scale_factor, ) else: