
Commit 072be89

dsikka authored and zhewenl committed
[Quantization] Expand compressed-tensors MoE matching logic to support NFP4 + FP8 MoEs (vllm-project#22674)
Signed-off-by: Dipika Sikka <[email protected]>
Signed-off-by: Dipika <[email protected]>
1 parent 5da9022 commit 072be89

File tree: 2 files changed (+40, -9 lines)


vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 7 additions & 6 deletions
@@ -425,6 +425,10 @@ def _get_scheme_from_parts(
             weight_quant: BaseModel,
             input_quant: BaseModel,
             format: Optional[str] = None) -> "CompressedTensorsScheme":
+
+        # use the per-layer format if defined, otherwise, use global format
+        format = format if format is not None else self.quant_format
+
         # Detect If Mixed Precision
         if self._is_fp4a16_nvfp4(weight_quant, input_quant):
             return CompressedTensorsW4A16Fp4()
@@ -437,14 +441,14 @@ def _get_scheme_from_parts(
                 actorder=weight_quant.actorder)
 
         if self._is_wNa16_group_channel(weight_quant, input_quant):
-            if (self.quant_format == CompressionFormat.marlin_24.value
+            if (format == CompressionFormat.marlin_24.value
                     and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
                 assert weight_quant.symmetric
                 return CompressedTensorsW4A16Sparse24(
                     strategy=weight_quant.strategy,
                     num_bits=weight_quant.num_bits,
                     group_size=weight_quant.group_size)
-            if (self.quant_format == CompressionFormat.pack_quantized.value
+            if (format == CompressionFormat.pack_quantized.value
                     and weight_quant.num_bits in WNA16_SUPPORTED_BITS):
                 return CompressedTensorsWNA16(
                     num_bits=weight_quant.num_bits,
@@ -453,10 +457,7 @@ def _get_scheme_from_parts(
                 group_size=weight_quant.group_size,
                 actorder=weight_quant.actorder)
 
-        act_quant_format = is_activation_quantization_format(
-            format
-        ) if format is not None else is_activation_quantization_format(
-            self.quant_format)
+        act_quant_format = is_activation_quantization_format(format)
         if act_quant_format:
             if self._is_fp4a4_nvfp4(weight_quant, input_quant):
                 if cutlass_fp4_supported(
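
The net effect of this hunk is a simple precedence rule: a per-layer compression format, when passed in, overrides the config-wide quant_format. Below is a minimal, self-contained sketch of that fallback; the class is a hypothetical stand-in for CompressedTensorsConfig, and the format strings are only illustrative values.

from typing import Optional


class FormatFallbackSketch:
    """Hypothetical stand-in for CompressedTensorsConfig (illustration only)."""

    def __init__(self, quant_format: str):
        # Global compression format declared at the top of the quant config.
        self.quant_format = quant_format

    def resolve_format(self, format: Optional[str] = None) -> str:
        # Mirrors the new line in _get_scheme_from_parts: prefer the
        # per-layer format when one is supplied, else fall back globally.
        return format if format is not None else self.quant_format


cfg = FormatFallbackSketch(quant_format="pack-quantized")
assert cfg.resolve_format() == "pack-quantized"        # global fallback
assert cfg.resolve_format("marlin-24") == "marlin-24"  # per-layer override

This is what lets the later scheme checks compare against format directly instead of branching on whether a per-layer format was supplied.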

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 33 additions & 3 deletions
@@ -22,6 +22,8 @@
     is_valid_flashinfer_cutlass_fused_moe)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
     WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+    find_matched_target)
 from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
     build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1,
@@ -65,12 +67,40 @@ def __init_(self, moe: FusedMoEConfig):
     @staticmethod
     def get_moe_method(
         quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
-        layer: torch.nn.Module,
+        layer: torch.nn.Module
     ) -> "CompressedTensorsMoEMethod":
         # TODO: @dsikka: refactor this to use schemes as other kernels
         # are supported + check if the layer is being ignored.
-        weight_quant = quant_config.target_scheme_map["Linear"].get("weights")
-        input_quant = quant_config.target_scheme_map["Linear"].get(
+        # Check if "Linear" is being used to select schemes
+        if "Linear" in quant_config.target_scheme_map:
+            matched_target = "Linear"
+        else:
+            # May have instead defined the linear layers in the fused model
+
+            fused_layers = [
+                "re:.*down_proj.*", "re:.*gate_proj.*", "re:.*up_proj.*"
+            ]
+            current_scheme = None
+            for fused_layer in fused_layers:
+                # Check if one of the fused layers is defined in quant_config
+                matched_target = find_matched_target(
+                    layer_name=fused_layer,
+                    module=layer,
+                    targets=quant_config.target_scheme_map.keys(),
+                    fused_mapping=quant_config.packed_modules_mapping)
+
+                # Only valid if down_proj, gate_proj, and up_proj
+                # are mapped to the same quant scheme in the quant_config
+                if current_scheme is None:
+                    current_scheme = quant_config.target_scheme_map.get(
+                        matched_target)
+                else:
+                    assert current_scheme == quant_config.target_scheme_map.get(
+                        matched_target)
+
+        weight_quant = quant_config.target_scheme_map[matched_target].get(
+            "weights")
+        input_quant = quant_config.target_scheme_map[matched_target].get(
             "input_activations")
 
         if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
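
When the checkpoint's quantization config targets the fused projection names directly instead of a generic "Linear" entry, this new branch resolves the scheme by matching each projection pattern and insisting they all agree. Here is a simplified, runnable sketch of that resolution logic; the matcher is a hypothetical stand-in for compressed-tensors' find_matched_target, and the target_scheme_map contents are illustrative values loosely modeled on an NVFP4-weight / FP8-activation MoE.

import re
from typing import Iterable, Optional


def find_matched_target_sketch(layer_name: str,
                               targets: Iterable[str]) -> Optional[str]:
    # Hypothetical, simplified matcher: return the first target whose
    # "re:" pattern matches the layer name (exact names are escaped).
    for target in targets:
        pattern = (target[len("re:"):]
                   if target.startswith("re:") else re.escape(target))
        if re.match(pattern, layer_name):
            return target
    return None


# Illustrative scheme map keyed by fused projection patterns, not "Linear".
target_scheme_map = {
    "re:.*gate_proj.*": {"weights": "nvfp4", "input_activations": "fp8"},
    "re:.*up_proj.*": {"weights": "nvfp4", "input_activations": "fp8"},
    "re:.*down_proj.*": {"weights": "nvfp4", "input_activations": "fp8"},
}

current_scheme = None
for layer_name in ["model.layers.0.mlp.experts.0.down_proj",
                   "model.layers.0.mlp.experts.0.gate_proj",
                   "model.layers.0.mlp.experts.0.up_proj"]:
    matched_target = find_matched_target_sketch(layer_name,
                                                target_scheme_map.keys())
    scheme = target_scheme_map.get(matched_target)
    # The three projections must share one scheme, exactly as the new
    # assert in get_moe_method requires.
    if current_scheme is None:
        current_scheme = scheme
    else:
        assert current_scheme == scheme

print(current_scheme)  # -> {'weights': 'nvfp4', 'input_activations': 'fp8'}

The dict values stand in for the scheme objects vllm actually stores in target_scheme_map; the point is only the invariant that down_proj, gate_proj, and up_proj must all map to one scheme before a single MoE method can be selected.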
