     is_valid_flashinfer_cutlass_fused_moe)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
     WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+    find_matched_target)
 from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
     build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1,
@@ -65,12 +67,40 @@ def __init__(self, moe: FusedMoEConfig):
     @staticmethod
     def get_moe_method(
         quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
-        layer: torch.nn.Module,
+        layer: torch.nn.Module
     ) -> "CompressedTensorsMoEMethod":
         # TODO: @dsikka: refactor this to use schemes as other kernels
         # are supported + check if the layer is being ignored.
-        weight_quant = quant_config.target_scheme_map["Linear"].get("weights")
-        input_quant = quant_config.target_scheme_map["Linear"].get(
+        # Check if "Linear" is used to select schemes
+        if "Linear" in quant_config.target_scheme_map:
+            matched_target = "Linear"
+        else:
+            # The config may instead define the linear layers of the fused module
+
+            fused_layers = [
+                "re:.*down_proj.*", "re:.*gate_proj.*", "re:.*up_proj.*"
+            ]
+            current_scheme = None
+            for fused_layer in fused_layers:
+                # Check if one of the fused layers is defined in quant_config
+                matched_target = find_matched_target(
+                    layer_name=fused_layer,
+                    module=layer,
+                    targets=quant_config.target_scheme_map.keys(),
+                    fused_mapping=quant_config.packed_modules_mapping)
+
+                # Only valid if down_proj, gate_proj, and up_proj
+                # are mapped to the same quant scheme in the quant_config
+                if current_scheme is None:
+                    current_scheme = quant_config.target_scheme_map.get(
+                        matched_target)
+                else:
+                    assert current_scheme == quant_config.target_scheme_map.get(
+                        matched_target)
+
+        weight_quant = quant_config.target_scheme_map[matched_target].get(
+            "weights")
+        input_quant = quant_config.target_scheme_map[matched_target].get(
             "input_activations")

         if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
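For context, here is a minimal, self-contained sketch of the selection logic this hunk adds: use the `"Linear"` target directly when the config defines one, otherwise resolve each fused projection (`down_proj`, `gate_proj`, `up_proj`) and require all of them to share one scheme. The `target_scheme_map` contents, layer names, and the `match_target` helper below are illustrative assumptions, not vLLM's actual config format or the real `find_matched_target` implementation.

```python
import re

# Hypothetical target_scheme_map, standing in for what
# CompressedTensorsConfig would parse out of a checkpoint's config.json.
target_scheme_map = {
    "re:.*down_proj.*": {"weights": "int4-group", "input_activations": None},
    "re:.*gate_proj.*": {"weights": "int4-group", "input_activations": None},
    "re:.*up_proj.*": {"weights": "int4-group", "input_activations": None},
}


def match_target(layer_name: str, targets) -> str:
    """Toy stand-in for find_matched_target: match a layer name against
    plain-string or "re:"-prefixed regex targets."""
    for target in targets:
        if target.startswith("re:") and re.match(target[3:], layer_name):
            return target
        if target == layer_name:
            return target
    raise ValueError(f"no target matches {layer_name}")


if "Linear" in target_scheme_map:
    matched_target = "Linear"
else:
    # Fall back to the fused projections; all three must share one scheme.
    current_scheme = None
    for fused_layer in ("model.layers.0.mlp.down_proj",
                        "model.layers.0.mlp.gate_proj",
                        "model.layers.0.mlp.up_proj"):
        matched_target = match_target(fused_layer, target_scheme_map.keys())
        scheme = target_scheme_map[matched_target]
        assert current_scheme is None or current_scheme == scheme
        current_scheme = scheme

print(target_scheme_map[matched_target]["weights"])  # -> "int4-group"
```

As the in-diff comment notes, the fallback is only valid when all three projections map to the same quant scheme; the assert enforces that before the scheme is used to pick the MoE method.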