From 7da895765e32c926de78cbbad53848aeb1380361 Mon Sep 17 00:00:00 2001 From: Matrix_K Date: Fri, 6 Mar 2026 13:30:30 +0000 Subject: [PATCH 1/3] refactor(modelslim_config.py): Optimize the prefix mapping logic of the quantization layer name Signed-off-by: Matrix_K --- vllm_ascend/quantization/modelslim_config.py | 105 +++++++++++++------ 1 file changed, 73 insertions(+), 32 deletions(-) diff --git a/vllm_ascend/quantization/modelslim_config.py b/vllm_ascend/quantization/modelslim_config.py index 3678c6f3a9f..38ba5580af0 100644 --- a/vllm_ascend/quantization/modelslim_config.py +++ b/vllm_ascend/quantization/modelslim_config.py @@ -50,32 +50,12 @@ # key: model_type # value: orig_to_new_prefix QUANT_MODEL_PREFIX_MAPPINGS: dict[str, dict[str, str]] = { - "qwen3_vl_moe": { - "visual.": "model.visual.", - "language_model.lm_head.": "lm_head.", - "language_model.model.": "model.language_model.", - }, - "qwen3_vl_text": { - "visual.": "model.visual.", - "language_model.lm_head.": "lm_head.", - "language_model.model.": "model.language_model.", - }, - "kimi_k25": { - "mm_projector.linear_1": "mm_projector.proj.0", - "mm_projector.linear_2": "mm_projector.proj.2", - }, - "qwen3_omni_moe": { - "language_model.lm_head.": "thinker.lm_head.", - "language_model.model.": "thinker.model.", - "visual.": "thinker.visual.", - }, - "qwen2_5_omni": { - "language_model.lm_head.": "thinker.lm_head.", - "language_model.model.": "thinker.model.", - "visual.": "thinker.visual.", - }, - "qwen2_5_omni_text": { - "language_model.": "thinker.", + "qwen3_omni_moe_thinker": { + "thinker.lm_head.": "language_model.lm_head.", + "thinker.model.": "language_model.model.", + "thinker.": "", + "lm_head.": "language_model.lm_head.", + "model.": "language_model.model.", }, } @@ -381,6 +361,7 @@ def __init__(self, quant_config: dict[str, Any] | None = None): new_k = k.replace("weight_packed", "weight") extra_quant_dict[new_k] = self.quant_description[k] self.quant_description.update(extra_quant_dict) + 
self.model_type = None def __repr__(self) -> str: return "AscendModelSlimConfig:\n" + super().__repr__() @@ -418,12 +399,75 @@ def override_quantization_method(cls, hf_quant_cfg, user_quant) -> str | None: return ASCEND_QUANTIZATION_METHOD return None + def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): + """Apply the vLLM model-specific mapper to this quantization config. + + This method is called by vLLM to apply the model-specific weight mapper + to the quantization configuration. It creates a reverse mapper to convert + vLLM prefixes back to HF format for looking up keys in quant_config.json. + + Args: + hf_to_vllm_mapper: The WeightsMapper instance provided by vLLM + that contains model-specific prefix mappings (HF to vLLM). + """ + # Check if we already have a valid vllm_to_hf_mapper for this hf_to_vllm_mapper + if hasattr(self, 'hf_to_vllm_mapper') and self.hf_to_vllm_mapper is hf_to_vllm_mapper: + # Same mapper instance, no need to recreate + return + + # Store the original mapper + self.hf_to_vllm_mapper = hf_to_vllm_mapper + + # Create a reverse mapper (vLLM to HF) + new_to_orig_prefix = {} + + # Try different ways to get the mapping based on WeightsMapper implementation + # Method 1: Check for public or private attribute with possible names + mapping_attrs = ['orig_to_new_prefix', '_orig_to_new_prefix'] # Check both public and private variants + orig_to_new_prefix = {} + + for attr_name in mapping_attrs: + if hasattr(hf_to_vllm_mapper, attr_name): + orig_to_new_prefix = getattr(hf_to_vllm_mapper, attr_name) + break + + # Create reverse mapping + for orig_prefix, new_prefix in orig_to_new_prefix.items(): + # Reverse the mapping direction + new_to_orig_prefix[new_prefix] = orig_prefix + + # Combine with manual mappings if available + combined_mapping = new_to_orig_prefix.copy() + if hasattr(self, 'model_type') and self.model_type in QUANT_MODEL_PREFIX_MAPPINGS: + manual_mapping = QUANT_MODEL_PREFIX_MAPPINGS[self.model_type] + # Manual mapping 
is already in vLLM to HF direction, no need to reverse + combined_mapping.update(manual_mapping) + + # Create and store the reverse WeightsMapper instance + self.vllm_to_hf_mapper = WeightsMapper(orig_to_new_prefix=combined_mapping) + + # Debug info + if not new_to_orig_prefix: + logger.warning("No reverse mapping found for WeightsMapper. Using manual mappings if available.") + else: + logger.debug(f"Created reverse mapping: {combined_mapping}") + def quant_prefix_mapper(self, model_type: str, prefix: str) -> str: - # TODO (Levi-JQ): will be removed when QuantizationConfig.apply_vllm_mapper is implemented + # Store model_type for backward compatibility mappings + self.model_type = model_type + + # Use the reverse mapper (vLLM to HF) if available + if hasattr(self, 'vllm_to_hf_mapper') and self.vllm_to_hf_mapper: + return self.vllm_to_hf_mapper._map_name(prefix) + + # Fall back to manual mapping for backward compatibility (simplified) + # This is only used if apply_vllm_mapper wasn't called or failed prefix_mapping = QUANT_MODEL_PREFIX_MAPPINGS.get(model_type) if prefix_mapping: - hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix=prefix_mapping) - return hf_to_vllm_mapper._map_name(prefix) + # Create a simple mapper on the fly (no caching since this should be rare) + mapper = WeightsMapper(orig_to_new_prefix=prefix_mapping) + return mapper._map_name(prefix) + return prefix def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]: @@ -454,9 +498,6 @@ def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["Qua from vllm.model_executor.layers.attention import Attention - if model_type != "kimi_k2": - if prefix.startswith("language_model"): - prefix = prefix.split(".", 1)[-1] if isinstance(layer, LinearBase): if self.is_layer_skipped_ascend(prefix, self.packed_modules_mapping): # Delayed import to avoid circular import From 465da169398d82cfea8943a237e94e4a879d0172 Mon Sep 17 00:00:00 2001 From: Matrix_K 
Date: Mon, 9 Mar 2026 02:12:37 +0000 Subject: [PATCH 2/3] fix pre-commit error of quantization refactor Signed-off-by: Matrix_K --- vllm_ascend/quantization/modelslim_config.py | 41 +++++++++++--------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/vllm_ascend/quantization/modelslim_config.py b/vllm_ascend/quantization/modelslim_config.py index 38ba5580af0..fa4485f0c03 100644 --- a/vllm_ascend/quantization/modelslim_config.py +++ b/vllm_ascend/quantization/modelslim_config.py @@ -361,7 +361,10 @@ def __init__(self, quant_config: dict[str, Any] | None = None): new_k = k.replace("weight_packed", "weight") extra_quant_dict[new_k] = self.quant_description[k] self.quant_description.update(extra_quant_dict) - self.model_type = None + # Initialize attributes for type checking + self.model_type: str | None = None + self.hf_to_vllm_mapper: WeightsMapper | None = None + self.vllm_to_hf_mapper: WeightsMapper | None = None def __repr__(self) -> str: return "AscendModelSlimConfig:\n" + super().__repr__() @@ -401,65 +404,65 @@ def override_quantization_method(cls, hf_quant_cfg, user_quant) -> str | None: def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): """Apply the vLLM model-specific mapper to this quantization config. - + This method is called by vLLM to apply the model-specific weight mapper to the quantization configuration. It creates a reverse mapper to convert vLLM prefixes back to HF format for looking up keys in quant_config.json. - + Args: hf_to_vllm_mapper: The WeightsMapper instance provided by vLLM that contains model-specific prefix mappings (HF to vLLM). 
""" # Check if we already have a valid vllm_to_hf_mapper for this hf_to_vllm_mapper - if hasattr(self, 'hf_to_vllm_mapper') and self.hf_to_vllm_mapper is hf_to_vllm_mapper: + if hasattr(self, "hf_to_vllm_mapper") and self.hf_to_vllm_mapper is hf_to_vllm_mapper: # Same mapper instance, no need to recreate return - + # Store the original mapper self.hf_to_vllm_mapper = hf_to_vllm_mapper - + # Create a reverse mapper (vLLM to HF) new_to_orig_prefix = {} - + # Try different ways to get the mapping based on WeightsMapper implementation # Method 1: Check for public or private attribute with possible names - mapping_attrs = ['orig_to_new_prefix', '_orig_to_new_prefix'] # Check both public and private variants + mapping_attrs = ["orig_to_new_prefix", "_orig_to_new_prefix"] orig_to_new_prefix = {} - + for attr_name in mapping_attrs: if hasattr(hf_to_vllm_mapper, attr_name): orig_to_new_prefix = getattr(hf_to_vllm_mapper, attr_name) break - + # Create reverse mapping for orig_prefix, new_prefix in orig_to_new_prefix.items(): # Reverse the mapping direction new_to_orig_prefix[new_prefix] = orig_prefix - + # Combine with manual mappings if available combined_mapping = new_to_orig_prefix.copy() - if hasattr(self, 'model_type') and self.model_type in QUANT_MODEL_PREFIX_MAPPINGS: + if hasattr(self, "model_type") and self.model_type in QUANT_MODEL_PREFIX_MAPPINGS: manual_mapping = QUANT_MODEL_PREFIX_MAPPINGS[self.model_type] # Manual mapping is already in vLLM to HF direction, no need to reverse combined_mapping.update(manual_mapping) - + # Create and store the reverse WeightsMapper instance self.vllm_to_hf_mapper = WeightsMapper(orig_to_new_prefix=combined_mapping) - + # Debug info if not new_to_orig_prefix: logger.warning("No reverse mapping found for WeightsMapper. 
Using manual mappings if available.") else: logger.debug(f"Created reverse mapping: {combined_mapping}") - + def quant_prefix_mapper(self, model_type: str, prefix: str) -> str: # Store model_type for backward compatibility mappings self.model_type = model_type - + # Use the reverse mapper (vLLM to HF) if available - if hasattr(self, 'vllm_to_hf_mapper') and self.vllm_to_hf_mapper: + if hasattr(self, "vllm_to_hf_mapper") and self.vllm_to_hf_mapper: return self.vllm_to_hf_mapper._map_name(prefix) - + # Fall back to manual mapping for backward compatibility (simplified) # This is only used if apply_vllm_mapper wasn't called or failed prefix_mapping = QUANT_MODEL_PREFIX_MAPPINGS.get(model_type) @@ -467,7 +470,7 @@ def quant_prefix_mapper(self, model_type: str, prefix: str) -> str: # Create a simple mapper on the fly (no caching since this should be rare) mapper = WeightsMapper(orig_to_new_prefix=prefix_mapping) return mapper._map_name(prefix) - + return prefix def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]: From 44f6a1768462e76a656b0a45cd11ae001a10e690 Mon Sep 17 00:00:00 2001 From: Feng-xiaosuo Date: Tue, 17 Mar 2026 19:59:24 +0800 Subject: [PATCH 3/3] Update vllm_ascend/quantization/modelslim_config.py Co-authored-by: Wang Kunpeng <1289706727@qq.com> Signed-off-by: Feng-xiaosuo --- vllm_ascend/quantization/modelslim_config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_ascend/quantization/modelslim_config.py b/vllm_ascend/quantization/modelslim_config.py index 19c91d783c0..570b50736d6 100644 --- a/vllm_ascend/quantization/modelslim_config.py +++ b/vllm_ascend/quantization/modelslim_config.py @@ -467,6 +467,7 @@ def override_quantization_method(cls, hf_quant_cfg, user_quant) -> str | None: return ASCEND_QUANTIZATION_METHOD return None + # TODO: Modify the key values in self.quant_description instead of flipping the hf_to_vllm_mapper def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): """Apply 
the vLLM model-specific mapper to this quantization config.