diff --git a/vllm_ascend/quantization/modelslim_config.py b/vllm_ascend/quantization/modelslim_config.py index 151109c53c1..455cda47256 100644 --- a/vllm_ascend/quantization/modelslim_config.py +++ b/vllm_ascend/quantization/modelslim_config.py @@ -47,57 +47,6 @@ # The config filename that ModelSlim generates after quantizing a model. MODELSLIM_CONFIG_FILENAME = "quant_model_description.json" -# key: model_type -# value: vLLM prefix -> HF prefix mapping (used to convert vLLM layer names to HF format -# for looking up keys in quant_model_description.json) -QUANT_MODEL_PREFIX_MAPPINGS: dict[str, dict[str, str]] = { - "qwen3_vl_moe": { - "visual.": "model.visual.", - "language_model.lm_head.": "lm_head.", - "language_model.model.": "model.language_model.", - }, - "qwen3_vl": { - "visual.": "model.visual.", - "language_model.lm_head.": "lm_head.", - "language_model.model.": "model.language_model.", - }, - "kimi_k25": { - "mm_projector.linear_1": "mm_projector.proj.0", - "mm_projector.linear_2": "mm_projector.proj.2", - }, - "qwen3_omni_moe": { - "language_model.lm_head.": "thinker.lm_head.", - "language_model.model.": "thinker.model.", - "visual.": "thinker.visual.", - }, - "qwen2_5_omni": { - "language_model.lm_head.": "thinker.lm_head.", - "language_model.model.": "thinker.model.", - "visual.": "thinker.visual.", - }, - "qwen2_5_omni_text": { - "language_model.": "thinker.", - "language_model.lm_head.": "thinker.lm_head.", - "language_model.model.": "thinker.model.", - }, - "glm4v_moe": { - "visual.": "model.visual.", - "language_model.lm_head.": "lm_head.", - "language_model.model.": "model.language_model.", - }, - "glm4v_moe_text": { - "visual.": "model.visual.", - "language_model.lm_head.": "lm_head.", - "language_model.model.": "model.language_model.", - }, - "kimi_k2": { - "language_model.layers.": "language_model.model.layers.", - # mm projector - "mm_projector.proj.0": "mm_projector.linear_1", - "mm_projector.proj.2": "mm_projector.linear_2", - }, -} - # key: model_type # value: dict of fused module name -> list of original module names packed_modules_model_mapping: dict[str, dict[str, list[str]]] = { @@ -309,19 +258,6 @@ def get_packed_modules_mapping(model_type: str) -> dict[str, list[str]]: return packed_modules_model_mapping.get(model_type, {}) -def get_prefix_mapping(model_type: str) -> dict[str, str]: - """Get prefix mapping for a model type. - - Args: - model_type: The model type string (e.g., "qwen3_vl_moe"). - - Returns: - Dictionary mapping original prefixes to new prefixes. - Returns empty dict if model_type is not found. - """ - return QUANT_MODEL_PREFIX_MAPPINGS.get(model_type, {}) - - def get_linear_quant_type( quant_description: dict[str, Any], prefix: str, packed_modules_mapping: dict[str, Any] ) -> str | None: @@ -426,21 +362,10 @@ class AscendModelSlimConfig(QuantizationConfig): def __init__(self, quant_config: dict[str, Any] | None = None): super().__init__() self.quant_description = quant_config if quant_config is not None else {} - # TODO(whx): remove this adaptation after adding "shared_head" - # to prefix of DeepSeekShareHead in vLLM. - extra_quant_dict = {} - for k in self.quant_description: - if "shared_head" in k: - new_k = k.replace(".shared_head.", ".") - extra_quant_dict[new_k] = self.quant_description[k] - if "weight_packed" in k: - new_k = k.replace("weight_packed", "weight") - extra_quant_dict[new_k] = self.quant_description[k] - self.quant_description.update(extra_quant_dict) - # Initialize attributes for type checking + self._apply_extra_quant_adaptations() self.model_type: str | None = None self.hf_to_vllm_mapper: WeightsMapper | None = None - self.vllm_to_hf_mapper: WeightsMapper | None = None + self._mapper_applied = False self._add_kvcache_quant_metadata() def __repr__(self) -> str: @@ -479,73 +404,31 @@ def override_quantization_method(cls, hf_quant_cfg, user_quant) -> str | None: return ASCEND_QUANTIZATION_METHOD return None - # TODO: Modify the key values in self.quant_description instead of flipping the hf_to_vllm_mapper def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): """Apply the vLLM model-specific mapper to this quantization config. This method is called by vLLM to apply the model-specific weight mapper - to the quantization configuration. It creates a reverse mapper to convert - vLLM prefixes back to HF format for looking up keys in quant_config.json. + to the quantization configuration. It directly uses the forward mapping + (HF -> vLLM) to transform keys in quant_description from HF format to + vLLM format. Args: hf_to_vllm_mapper: The WeightsMapper instance provided by vLLM that contains model-specific prefix mappings (HF to vLLM). """ - # Check if we already have a valid vllm_to_hf_mapper for this hf_to_vllm_mapper - if hasattr(self, "hf_to_vllm_mapper") and self.hf_to_vllm_mapper is hf_to_vllm_mapper: - # Same mapper instance, no need to recreate + if self._mapper_applied and self.hf_to_vllm_mapper is hf_to_vllm_mapper: return - # Store the original mapper self.hf_to_vllm_mapper = hf_to_vllm_mapper + self._mapper_applied = True - # Try different ways to get the mapping based on WeightsMapper implementation - mapping_attrs = ["orig_to_new_prefix"] - orig_to_new_prefix = {} - - for attr_name in mapping_attrs: - if hasattr(hf_to_vllm_mapper, attr_name): - orig_to_new_prefix = getattr(hf_to_vllm_mapper, attr_name) - break - - # Create reverse mapping (vLLM -> HF), skipping empty values - vllm_to_hf_mapping = {} - for orig_prefix, new_prefix in orig_to_new_prefix.items(): - # Skip empty values to avoid invalid keys in reverse mapping - if new_prefix: - vllm_to_hf_mapping[new_prefix] = orig_prefix - - # Create and store the reverse WeightsMapper instance - if vllm_to_hf_mapping: - self.vllm_to_hf_mapper = WeightsMapper(orig_to_new_prefix=vllm_to_hf_mapping) - logger.debug(f"Created reverse mapping from hf_to_vllm_mapper: {vllm_to_hf_mapping}") - else: - logger.info("No valid reverse mapping found for WeightsMapper.") + if self.quant_description: + self.quant_description = hf_to_vllm_mapper.apply_dict(self.quant_description) + self._add_kvcache_quant_metadata() + logger.info("Applied hf_to_vllm_mapper to quant_description keys") def quant_prefix_mapper(self, model_type: str, prefix: str) -> str: - # Store model_type for reference self.model_type = model_type - - # Check if manual mapping exists for this model type - # Manual mapping takes priority and is used exclusively to avoid conflicts - if model_type in QUANT_MODEL_PREFIX_MAPPINGS: - manual_mapping = QUANT_MODEL_PREFIX_MAPPINGS[model_type] - # Manual mapping is already in vLLM -> HF direction, use directly - mapper = WeightsMapper(orig_to_new_prefix=manual_mapping) - return mapper._map_name(prefix) - - # Use the reverse mapper (vLLM to HF) if available - if hasattr(self, "vllm_to_hf_mapper") and self.vllm_to_hf_mapper: - return self.vllm_to_hf_mapper._map_name(prefix) - - # Fall back to manual mapping for backward compatibility (simplified) - # This is only used if apply_vllm_mapper wasn't called or failed - prefix_mapping = QUANT_MODEL_PREFIX_MAPPINGS.get(model_type) - if prefix_mapping: - # Manual mapping is already in vLLM -> HF direction, use directly - mapper = WeightsMapper(orig_to_new_prefix=prefix_mapping) - return mapper._map_name(prefix) - return prefix def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]: