diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 53423b4d6b2..1cb8b6c3d53 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from vllm.logger import logger from vllm.utils.math_utils import cdiv @@ -158,13 +158,66 @@ def __init__(self, vllm_config: "VllmConfig"): and use_sparse and get_ascend_device_type() != AscendDeviceType.A5 ) - + quant_config = getattr(vllm_config, "quant_config", None) + self._sparse_c8_layer_ids, self._sparse_c8_layer_names = self._parse_sparse_c8_layers_from_quant_config( + quant_config + ) + self._sparse_c8_layer_filter_enabled = self._has_sparse_c8_layer_config(quant_config) self.enable_sp_by_pass = ( vllm_config.model_config is not None and not vllm_config.model_config.enforce_eager and vllm_config.compilation_config.pass_config.enable_sp ) + @staticmethod + def _has_sparse_c8_layer_config(quant_config: Any) -> bool: + quant_description = getattr(quant_config, "quant_description", None) + if not isinstance(quant_description, dict): + return False + return any(isinstance(key, str) and key.endswith(".indexer.quant_type") for key in quant_description) + + @classmethod + def _parse_sparse_c8_layers_from_quant_config(cls, quant_config: Any) -> tuple[set[int], set[str]]: + quant_description = getattr(quant_config, "quant_description", None) + if not isinstance(quant_description, dict): + return set(), set() + + layer_ids: set[int] = set() + layer_names: set[str] = set() + suffix = ".indexer.quant_type" + from vllm.model_executor.models.utils import extract_layer_index + + for key, value in quant_description.items(): + if not isinstance(key, str) or not key.endswith(suffix): + continue + if value != "INT8_DYNAMIC": + continue + layer_name = key[: -len(suffix)].rstrip(".") + if not 
layer_name: + continue + layer_names.add(layer_name) + layer_ids.add(extract_layer_index(layer_name)) + return layer_ids, layer_names + + def is_sparse_c8_layer(self, layer_name: str | None) -> bool: + if not self.enable_sparse_c8: + return False + if not self._sparse_c8_layer_filter_enabled: + return True + if layer_name is None: + return False + + normalized_layer_name = layer_name.rstrip(".") + if any( + normalized_layer_name == candidate or normalized_layer_name.startswith(f"{candidate}.") + for candidate in self._sparse_c8_layer_names + ): + return True + from vllm.model_executor.models.utils import extract_layer_index + + layer_id = extract_layer_index(normalized_layer_name) + return layer_id in self._sparse_c8_layer_ids + @staticmethod def _get_compile_ranges(compilation_config): return compilation_config.compile_ranges_endpoints or [] diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index 7d787648385..3b608b6b70c 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -415,6 +415,7 @@ def __init__( self.is_kv_producer = ( self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer ) + self.layer_name = kwargs.get("layer_name") # indexer param self.n_head: int = self.indexer.n_head # 64 @@ -431,7 +432,7 @@ def __init__( self.use_torch_npu_lightning_indexer = True # dsa c8 - self.use_sparse_c8_indexer = ascend_config.enable_sparse_c8 + self.use_sparse_c8_indexer = ascend_config.is_sparse_c8_layer(self.layer_name) if self.use_sparse_c8_indexer: self.c8_k_cache_dtype = torch.int8 self.c8_k_scale_cache_dtype = torch.float16 diff --git a/vllm_ascend/patch/platform/patch_kv_cache_interface.py b/vllm_ascend/patch/platform/patch_kv_cache_interface.py index 3719a3c54a3..059a3b1ba26 100644 --- a/vllm_ascend/patch/platform/patch_kv_cache_interface.py +++ b/vllm_ascend/patch/platform/patch_kv_cache_interface.py @@ -124,6 +124,10 
@@ def merge(cls, specs: list[Self]) -> Self: assert len(cache_dtype_str_set) == 1, ( "All attention layers in the same KV cache group must use the same quantization method." ) + cache_sparse_c8_set = {spec.cache_sparse_c8 for spec in specs} + assert len(cache_sparse_c8_set) == 1, ( + "All attention layers in the same KV cache group must use the same sparse C8 setting." + ) return cls( block_size=specs[0].block_size, num_kv_heads=specs[0].num_kv_heads, @@ -131,7 +135,7 @@ def merge(cls, specs: list[Self]) -> Self: sparse_head_dim=specs[0].sparse_head_dim, dtype=specs[0].dtype, cache_dtype_str=cache_dtype_str_set.pop(), - cache_sparse_c8=cache_sparse_c8_set.pop(), + cache_sparse_c8=cache_sparse_c8_set.pop(), ) diff --git a/vllm_ascend/quantization/modelslim_config.py b/vllm_ascend/quantization/modelslim_config.py index 0ad0bde4538..6badfe59a78 100644 --- a/vllm_ascend/quantization/modelslim_config.py +++ b/vllm_ascend/quantization/modelslim_config.py @@ -377,10 +377,14 @@ def get_quant_type_for_layer( if packed_modules_mapping is None: packed_modules_mapping = dict() # Attention - if layer_type == "attention" and "fa_quant_type" in quant_description: - return quant_description["fa_quant_type"] - if layer_type == "attention" and "indexer_quant_type" in quant_description: - return quant_description["indexer_quant_type"] + if layer_type == "attention": + layer_indexer_quant_type = quant_description.get(f"{prefix}.indexer.quant_type") + if layer_indexer_quant_type is not None: + return layer_indexer_quant_type + if "fa_quant_type" in quant_description: + return quant_description["fa_quant_type"] + if "indexer_quant_type" in quant_description: + return quant_description["indexer_quant_type"] # Linear / MoE return get_linear_quant_type(quant_description, prefix, packed_modules_mapping) @@ -640,13 +644,6 @@ def is_fa_quant_layer(self, prefix): return True return False - def is_indexer_quant_layer(self, prefix): - if self.enable_indexer_quant: - layer_id_str = 
"".join(re.findall(r"\.(\d+)\.", prefix)) - if layer_id_str.isdigit() and int(layer_id_str) in self.indexer_quant_layers: - return True - return False - def enabling_fa_quant(self, vllm_config, layer_name) -> bool: is_decode_instance = ( vllm_config.kv_transfer_config is not None @@ -655,6 +652,13 @@ def enabling_fa_quant(self, vllm_config, layer_name) -> bool: ) return bool(is_decode_instance and self.is_fa_quant_layer(layer_name)) + def is_indexer_quant_layer(self, prefix): + if self.enable_indexer_quant: + layer_id_str = "".join(re.findall(r"\.(\d+)\.", prefix)) + if layer_id_str.isdigit() and int(layer_id_str) in self.indexer_quant_layers: + return True + return False + def get_kv_quant_dtype(self, layer_name, cache_dtype, model_config): if self.enable_fa_quant and self.is_fa_quant_layer(layer_name): ori_dtype = model_config.dtype diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 4db163c7e0b..1070e4e849a 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -1272,3 +1272,9 @@ def parse_layer_idx(prefix: str) -> int | None: """Extract the layer index from a module prefix string like 'model.layers.0.self_attn'.""" match = re.search(r"layers\.(\d+)", prefix) return int(match.group(1)) if match else None + + +def kv_cache_spec_uses_sparse_c8(kv_cache_spec) -> bool: + from vllm.v1.kv_cache_interface import MLAAttentionSpec + + return isinstance(kv_cache_spec, MLAAttentionSpec) and bool(getattr(kv_cache_spec, "cache_sparse_c8", False)) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 677e9925a82..b42ea4bd0e9 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -124,6 +124,7 @@ global_stream, is_drafter_moe_model, is_moe_model, + kv_cache_spec_uses_sparse_c8, lmhead_tp_enable, set_weight_prefetch_method, ) @@ -2751,11 +2752,12 @@ def _allocate_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str if self.use_sparse: # for deepseek v3.2, we 
split the kv cache according to the corresponding ratio kv_cache_spec = layer_kv_cache_spec[layer_name] + current_sparse_c8 = kv_cache_spec_uses_sparse_c8(kv_cache_spec) sparse_kv_cache_ratio = kv_cache_spec.sparse_kv_cache_ratio k_tensor_split_factor = sparse_kv_cache_ratio[0] v_tensor_split_factor = sparse_kv_cache_ratio[1] dsa_k_tensor_split_factor = sparse_kv_cache_ratio[2] - dsa_k_scale_tensor_split_factor = sparse_kv_cache_ratio[3] + dsa_k_scale_tensor_split_factor = sparse_kv_cache_ratio[3] if current_sparse_c8 else None else: k_dim, v_dim = self._get_attention_kv_cache_dims(layer_name, current_kv_cache_spec) assert k_dim > 0 and v_dim > 0 @@ -2777,7 +2779,7 @@ def _allocate_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str #### for deepseek sparse attention if self.use_sparse: dsa_k_tensor_size = int(kv_cache_tensor.size // dsa_k_tensor_split_factor) - if self.use_sparse_c8_indexer: + if self.use_sparse and current_sparse_c8: dsa_k_scale_tensor_size = int(kv_cache_tensor.size // dsa_k_scale_tensor_split_factor) # for other attentions, e.g., self_attn, sliding window attn @@ -2814,7 +2816,7 @@ def _allocate_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str # shared the attn kvcache for all shared layers if "attn" in layer_name_inner and "linear_attn" not in layer_name_inner: if self.use_sparse: - if self.use_sparse_c8_indexer: + if current_sparse_c8: kv_cache_raw_tensors[layer_name_inner] = ( k_tensor, v_tensor, dsa_k_tensor, dsa_k_scale_tensor ) @@ -2862,7 +2864,8 @@ def _reshape_kv_cache_tensors( # encounter OOM issue if isinstance(current_kv_cache_spec, AttentionSpec): if self.use_sparse: - if self.use_sparse_c8_indexer: + current_sparse_c8 = kv_cache_spec_uses_sparse_c8(current_kv_cache_spec) + if current_sparse_c8: raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor, raw_dsa_k_scale_tensor = kv_cache_raw_tensors[ # type: ignore layer_name] assert raw_dsa_k_tensor is not None @@ -2964,7 +2967,7 @@ def _reshape_kv_cache_tensors( 
current_kv_cache_spec.num_kv_heads, self.model_config.hf_text_config.index_head_dim, ) - if self.use_sparse_c8_indexer: + if current_sparse_c8: # dsa_k dsa_k_cache = raw_dsa_k_tensor.view(self.c8_k_cache_dtype).view(dsa_k_cache_shape) # dsa_k_scale @@ -3267,7 +3270,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: sparse_head_dim=self.sparse_head_dim, dtype=self.kv_cache_dtype, cache_dtype_str=self.vllm_config.cache_config.cache_dtype, - cache_sparse_c8=self.use_sparse_c8_indexer, + cache_sparse_c8=self.ascend_config.is_sparse_c8_layer(layer_name), ) elif spec := attn_module.get_kv_cache_spec(self.vllm_config): assert isinstance(spec, MLAAttentionSpec)