Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 11 additions & 128 deletions vllm_ascend/quantization/modelslim_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,57 +47,6 @@
# The config filename that ModelSlim generates after quantizing a model.
MODELSLIM_CONFIG_FILENAME = "quant_model_description.json"

# key: model_type
# value: vLLM prefix -> HF prefix mapping (used to convert vLLM layer names to HF format
# for looking up keys in quant_model_description.json)
# Hand-curated prefix translations, keyed by HF ``model_type``.
# Each entry maps a vLLM-side layer-name prefix to the HF-side prefix used as
# the key in ModelSlim's quant_model_description.json, so vLLM layer names can
# be looked up in the quant description.
# NOTE(review): entries are vLLM -> HF, but "kimi_k25" maps
# mm_projector.linear_N -> mm_projector.proj.N while "kimi_k2" maps the
# opposite way (proj.N -> linear_N) — presumably intentional per-model, but
# worth confirming the direction for each model.
# NOTE(review): "qwen2_5_omni_text" contains overlapping prefixes
# ("language_model." and "language_model.model.") — which one wins depends on
# WeightsMapper's matching order; verify longest-prefix-first behavior.
QUANT_MODEL_PREFIX_MAPPINGS: dict[str, dict[str, str]] = {
    "qwen3_vl_moe": {
        "visual.": "model.visual.",
        "language_model.lm_head.": "lm_head.",
        "language_model.model.": "model.language_model.",
    },
    "qwen3_vl": {
        "visual.": "model.visual.",
        "language_model.lm_head.": "lm_head.",
        "language_model.model.": "model.language_model.",
    },
    "kimi_k25": {
        "mm_projector.linear_1": "mm_projector.proj.0",
        "mm_projector.linear_2": "mm_projector.proj.2",
    },
    "qwen3_omni_moe": {
        "language_model.lm_head.": "thinker.lm_head.",
        "language_model.model.": "thinker.model.",
        "visual.": "thinker.visual.",
    },
    "qwen2_5_omni": {
        "language_model.lm_head.": "thinker.lm_head.",
        "language_model.model.": "thinker.model.",
        "visual.": "thinker.visual.",
    },
    "qwen2_5_omni_text": {
        "language_model.": "thinker.",
        "language_model.lm_head.": "thinker.lm_head.",
        "language_model.model.": "thinker.model.",
    },
    "glm4v_moe": {
        "visual.": "model.visual.",
        "language_model.lm_head.": "lm_head.",
        "language_model.model.": "model.language_model.",
    },
    "glm4v_moe_text": {
        "visual.": "model.visual.",
        "language_model.lm_head.": "lm_head.",
        "language_model.model.": "model.language_model.",
    },
    "kimi_k2": {
        "language_model.layers.": "language_model.model.layers.",
        # mm projector
        "mm_projector.proj.0": "mm_projector.linear_1",
        "mm_projector.proj.2": "mm_projector.linear_2",
    },
}

# key: model_type
# value: dict of fused module name -> list of original module names
packed_modules_model_mapping: dict[str, dict[str, list[str]]] = {
Expand Down Expand Up @@ -309,19 +258,6 @@ def get_packed_modules_mapping(model_type: str) -> dict[str, list[str]]:
return packed_modules_model_mapping.get(model_type, {})


def get_prefix_mapping(model_type: str) -> dict[str, str]:
    """Look up the registered prefix mapping for a model type.

    Args:
        model_type: The model type string (e.g., "qwen3_vl_moe").

    Returns:
        The prefix-translation dict registered for ``model_type`` in
        ``QUANT_MODEL_PREFIX_MAPPINGS``, or an empty dict when the model
        type has no registered mapping.
    """
    mapping = QUANT_MODEL_PREFIX_MAPPINGS.get(model_type)
    return {} if mapping is None else mapping


def get_linear_quant_type(
quant_description: dict[str, Any], prefix: str, packed_modules_mapping: dict[str, Any]
) -> str | None:
Expand Down Expand Up @@ -426,21 +362,10 @@ class AscendModelSlimConfig(QuantizationConfig):
def __init__(self, quant_config: dict[str, Any] | None = None):
    """Build the Ascend ModelSlim quantization config.

    Args:
        quant_config: The parsed contents of ModelSlim's
            quant_model_description.json; ``None`` is treated as an
            empty description.
    """
    super().__init__()
    # Normalize a missing config to an empty dict so lookups never fail.
    self.quant_description = quant_config if quant_config is not None else {}
    # TODO(whx): remove this adaptation after adding "shared_head"
    # to prefix of DeepSeekShareHead in vLLM.
    # Add aliased keys so lookups succeed under both naming schemes:
    #   *.shared_head.*  -> key without the ".shared_head." segment
    #   *weight_packed*  -> same key with "weight" instead
    extra_quant_dict = {}
    for k in self.quant_description:
        if "shared_head" in k:
            new_k = k.replace(".shared_head.", ".")
            extra_quant_dict[new_k] = self.quant_description[k]
        if "weight_packed" in k:
            new_k = k.replace("weight_packed", "weight")
            extra_quant_dict[new_k] = self.quant_description[k]
    self.quant_description.update(extra_quant_dict)
    # NOTE(review): the inline loop above and _apply_extra_quant_adaptations()
    # appear to duplicate the same key-aliasing work (likely a merge/diff
    # artifact) — confirm only one of the two should remain.
    self._apply_extra_quant_adaptations()
    # Populated later by quant_prefix_mapper() / apply_vllm_mapper().
    self.model_type: str | None = None
    self.hf_to_vllm_mapper: WeightsMapper | None = None
    self.vllm_to_hf_mapper: WeightsMapper | None = None
    # Guards against applying a weights mapper more than once.
    self._mapper_applied = False
    self._add_kvcache_quant_metadata()

def __repr__(self) -> str:
Expand Down Expand Up @@ -479,73 +404,31 @@ def override_quantization_method(cls, hf_quant_cfg, user_quant) -> str | None:
return ASCEND_QUANTIZATION_METHOD
return None

# TODO: Modify the key values in self.quant_description instead of flipping the hf_to_vllm_mapper
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
"""Apply the vLLM model-specific mapper to this quantization config.

This method is called by vLLM to apply the model-specific weight mapper
to the quantization configuration. It creates a reverse mapper to convert
vLLM prefixes back to HF format for looking up keys in quant_config.json.
to the quantization configuration. It directly uses the forward mapping
(HF -> vLLM) to transform keys in quant_description from HF format to
vLLM format.

Args:
hf_to_vllm_mapper: The WeightsMapper instance provided by vLLM
that contains model-specific prefix mappings (HF to vLLM).
"""
# Check if we already have a valid vllm_to_hf_mapper for this hf_to_vllm_mapper
if hasattr(self, "hf_to_vllm_mapper") and self.hf_to_vllm_mapper is hf_to_vllm_mapper:
# Same mapper instance, no need to recreate
if self._mapper_applied and self.hf_to_vllm_mapper is hf_to_vllm_mapper:
return
Comment on lines +419 to 420
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current logic for checking if the mapper has been applied is not fully robust. If apply_vllm_mapper is called a second time with a different mapper instance, the condition self.hf_to_vllm_mapper is hf_to_vllm_mapper will be false. The code would then proceed to apply the new mapping on the quant_description keys which have already been transformed, leading to incorrect behavior.

Since apply_dict modifies the quant_description state, re-applying a mapping is not safe. To prevent this potential bug, you should add a more robust check to error out if a different mapper is provided after the first application.

Suggested change
if self._mapper_applied and self.hf_to_vllm_mapper is hf_to_vllm_mapper:
return
if self._mapper_applied:
if self.hf_to_vllm_mapper is not hf_to_vllm_mapper:
raise RuntimeError(
"apply_vllm_mapper() has already been called with a different "
"mapper. Re-applying the mapping is not supported as it "
"modifies the quantization description in-place."
)
return


# Store the original mapper
self.hf_to_vllm_mapper = hf_to_vllm_mapper
self._mapper_applied = True

# Try different ways to get the mapping based on WeightsMapper implementation
mapping_attrs = ["orig_to_new_prefix"]
orig_to_new_prefix = {}

for attr_name in mapping_attrs:
if hasattr(hf_to_vllm_mapper, attr_name):
orig_to_new_prefix = getattr(hf_to_vllm_mapper, attr_name)
break

# Create reverse mapping (vLLM -> HF), skipping empty values
vllm_to_hf_mapping = {}
for orig_prefix, new_prefix in orig_to_new_prefix.items():
# Skip empty values to avoid invalid keys in reverse mapping
if new_prefix:
vllm_to_hf_mapping[new_prefix] = orig_prefix

# Create and store the reverse WeightsMapper instance
if vllm_to_hf_mapping:
self.vllm_to_hf_mapper = WeightsMapper(orig_to_new_prefix=vllm_to_hf_mapping)
logger.debug(f"Created reverse mapping from hf_to_vllm_mapper: {vllm_to_hf_mapping}")
else:
logger.info("No valid reverse mapping found for WeightsMapper.")
if self.quant_description:
self.quant_description = hf_to_vllm_mapper.apply_dict(self.quant_description)
self._add_kvcache_quant_metadata()
logger.info("Applied hf_to_vllm_mapper to quant_description keys")

def quant_prefix_mapper(self, model_type: str, prefix: str) -> str:
    """Translate a vLLM layer prefix to the HF-format key used in the
    ModelSlim quant description.

    Args:
        model_type: The HF model type string (e.g., "qwen3_vl_moe").
        prefix: The vLLM-side layer name/prefix to translate.

    Returns:
        The translated prefix, or ``prefix`` unchanged when no mapping
        applies for this model type.
    """
    # Store model_type for reference
    self.model_type = model_type

    # Manual mapping takes priority and is used exclusively to avoid
    # conflicts with the mapper derived from vLLM's WeightsMapper.
    manual_mapping = QUANT_MODEL_PREFIX_MAPPINGS.get(model_type)
    if manual_mapping is not None:
        # Manual mapping is already in vLLM -> HF direction, use directly.
        return WeightsMapper(orig_to_new_prefix=manual_mapping)._map_name(prefix)

    # Otherwise use the reverse mapper (vLLM -> HF) built by
    # apply_vllm_mapper(), if one exists.
    if getattr(self, "vllm_to_hf_mapper", None):
        return self.vllm_to_hf_mapper._map_name(prefix)

    # No mapping known for this model type: leave the prefix unchanged.
    # (The original trailing re-check of QUANT_MODEL_PREFIX_MAPPINGS was
    # unreachable dead code — a hit there would already have returned in
    # the first branch — so it has been removed.)
    return prefix

def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]:
Expand Down
Loading