Skip to content
61 changes: 59 additions & 2 deletions vllm_ascend/ascend_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import TYPE_CHECKING
import re
from typing import TYPE_CHECKING, Any

from vllm.logger import logger
from vllm.utils.math_utils import cdiv
Expand Down Expand Up @@ -158,13 +159,69 @@ def __init__(self, vllm_config: "VllmConfig"):
and use_sparse
and get_ascend_device_type() != AscendDeviceType.A5
)

quant_config = getattr(vllm_config, "quant_config", None)
self._sparse_c8_layer_ids, self._sparse_c8_layer_names = self._parse_sparse_c8_layers_from_quant_config(
quant_config
)
self._sparse_c8_layer_filter_enabled = self._has_sparse_c8_layer_config(quant_config)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add an e2e test in tests/e2e/multicard/2-cards/test_offline_inference_distributed.py.

self.enable_sp_by_pass = (
vllm_config.model_config is not None
and not vllm_config.model_config.enforce_eager
and vllm_config.compilation_config.pass_config.enable_sp
)

@staticmethod
def _has_sparse_c8_layer_config(quant_config: Any) -> bool:
    """Return True when the quant description carries at least one per-layer
    ``.indexer.quant_type`` entry, i.e. a layer-wise sparse-C8 filter exists."""
    description = getattr(quant_config, "quant_description", None)
    if not isinstance(description, dict):
        return False
    for key in description:
        if isinstance(key, str) and key.endswith(".indexer.quant_type"):
            return True
    return False

@classmethod
def _parse_sparse_c8_layers_from_quant_config(cls, quant_config: Any) -> tuple[set[int], set[str]]:
    """Collect layer ids and layer-name prefixes whose indexer is quantized
    as dynamic INT8 (sparse C8) in the quant description.

    Returns ``(layer_ids, layer_names)``; both empty when no usable
    quant description is attached to *quant_config*.
    """
    description = getattr(quant_config, "quant_description", None)
    if not isinstance(description, dict):
        return set(), set()

    suffix = ".indexer.quant_type"
    ids: set[int] = set()
    names: set[str] = set()
    for key, quant_type in description.items():
        # Only per-layer indexer entries with dynamic INT8 count.
        if not isinstance(key, str) or not key.endswith(suffix):
            continue
        if quant_type != "INT8_DYNAMIC":
            continue
        name = key[: -len(suffix)].rstrip(".")
        if name:
            names.add(name)
            ids.update(cls._extract_layer_ids(name))
    return ids, names

@staticmethod
def _extract_layer_ids(layer_name: str) -> set[int]:
    """Extract every dot-delimited integer component of *layer_name*.

    E.g. ``"model.layers.7.self_attn"`` -> ``{7}``.

    The trailing delimiter is matched with a lookahead ``(?=\.|$)`` instead of
    being consumed: a consuming pattern would skip every second id in names
    with adjacent numeric components (``"a.3.4.b"`` would yield only ``{3}``).
    """
    return {int(match) for match in re.findall(r"(?:^|\.)(\d+)(?=\.|$)", layer_name)}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use `extract_layer_index` instead.


def is_sparse_c8_layer(self, layer_name: str | None) -> bool:
if not self.enable_sparse_c8:
return False
if not self._sparse_c8_layer_filter_enabled:
return True
if layer_name is None:
return False

normalized_layer_name = layer_name.rstrip(".")
if any(
normalized_layer_name == candidate or normalized_layer_name.startswith(f"{candidate}.")
for candidate in self._sparse_c8_layer_names
):
return True

layer_ids = self._extract_layer_ids(normalized_layer_name)
return any(layer_id in self._sparse_c8_layer_ids for layer_id in layer_ids)



@staticmethod
def _get_compile_ranges(compilation_config):
    """Return the configured compile-range endpoints, or ``[]`` when unset or empty."""
    endpoints = compilation_config.compile_ranges_endpoints
    if endpoints:
        return endpoints
    return []
Expand Down
3 changes: 2 additions & 1 deletion vllm_ascend/attention/sfa_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,7 @@ def __init__(
self.is_kv_producer = (
self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
)
self.layer_name = kwargs.get("layer_name")

# indexer param
self.n_head: int = self.indexer.n_head # 64
Expand All @@ -431,7 +432,7 @@ def __init__(
self.use_torch_npu_lightning_indexer = True

# dsa c8
self.use_sparse_c8_indexer = ascend_config.enable_sparse_c8
self.use_sparse_c8_indexer = ascend_config.is_sparse_c8_layer(self.layer_name)
if self.use_sparse_c8_indexer:
self.c8_k_cache_dtype = torch.int8
self.c8_k_scale_cache_dtype = torch.float16
Expand Down
6 changes: 5 additions & 1 deletion vllm_ascend/patch/platform/patch_kv_cache_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,18 @@ def merge(cls, specs: list[Self]) -> Self:
assert len(cache_dtype_str_set) == 1, (
"All attention layers in the same KV cache group must use the same quantization method."
)
cache_sparse_c8_set = set(spec.cache_sparse_c8 for spec in specs)
assert len(cache_sparse_c8_set) == 1, (
"All attention layers in the same KV cache group must use the same sparse C8 setting."
)
return cls(
block_size=specs[0].block_size,
num_kv_heads=specs[0].num_kv_heads,
head_size=specs[0].head_size,
sparse_head_dim=specs[0].sparse_head_dim,
dtype=specs[0].dtype,
cache_dtype_str=cache_dtype_str_set.pop(),
cache_sparse_c8=specs[0].cache_sparse_c8,
cache_sparse_c8=cache_sparse_c8_set.pop(),
)


Expand Down
14 changes: 9 additions & 5 deletions vllm_ascend/quantization/modelslim_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,10 +377,14 @@ def get_quant_type_for_layer(
if packed_modules_mapping is None:
packed_modules_mapping = dict()
# Attention
if layer_type == "attention" and "fa_quant_type" in quant_description:
return quant_description["fa_quant_type"]
if layer_type == "attention" and "indexer_quant_type" in quant_description:
return quant_description["indexer_quant_type"]
if layer_type == "attention":
layer_indexer_quant_type = quant_description.get(f"{prefix}.indexer.quant_type")
if layer_indexer_quant_type is not None:
return layer_indexer_quant_type
Comment on lines +381 to +383
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Prioritizing layer_indexer_quant_type over general fa_quant_type or indexer_quant_type allows for more granular and layer-specific quantization configurations. This improves flexibility and precision in applying quantization settings.

if "fa_quant_type" in quant_description:
return quant_description["fa_quant_type"]
if "indexer_quant_type" in quant_description:
return quant_description["indexer_quant_type"]
# Linear / MoE
return get_linear_quant_type(quant_description, prefix, packed_modules_mapping)

Expand Down Expand Up @@ -646,7 +650,7 @@ def is_indexer_quant_layer(self, prefix):
if layer_id_str.isdigit() and int(layer_id_str) in self.indexer_quant_layers:
return True
return False

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove this

def enabling_fa_quant(self, vllm_config, layer_name) -> bool:
is_decode_instance = (
vllm_config.kv_transfer_config is not None
Expand Down
21 changes: 15 additions & 6 deletions vllm_ascend/worker/model_runner_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -2673,6 +2673,13 @@ def _get_layer_kv_cache_specs(self, kv_cache_config: KVCacheConfig) -> dict[str,
layer_kv_cache_spec[layer_name] = group_spec
return layer_kv_cache_spec

def _is_sparse_c8_layer(self, layer_name: str) -> bool:
    """True when sparse attention is in use and the ascend config marks
    *layer_name* for the sparse-C8 indexer cache."""
    if not self.use_sparse:
        return False
    return bool(self.ascend_config.is_sparse_c8_layer(layer_name))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just self.ascend_config.is_sparse_c8_layer(layer_name) is enough.


@staticmethod
def _kv_cache_spec_uses_sparse_c8(kv_cache_spec: KVCacheSpec) -> bool:
    """True when *kv_cache_spec* is an MLA attention spec whose
    ``cache_sparse_c8`` flag is set (missing flag counts as False)."""
    if not isinstance(kv_cache_spec, MLAAttentionSpec):
        return False
    return bool(getattr(kv_cache_spec, "cache_sparse_c8", False))

def _get_attention_kv_cache_dims(self, layer_name: str, kv_cache_spec: AttentionSpec) -> tuple[int, int]:
if isinstance(kv_cache_spec, MLAAttentionSpec):
attn_layers = get_layers_from_vllm_config(
Expand Down Expand Up @@ -2751,11 +2758,12 @@ def _allocate_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str
if self.use_sparse:
# for deepseek v3.2, we split the kv cache according to the corresponding ratio
kv_cache_spec = layer_kv_cache_spec[layer_name]
current_sparse_c8 = self._kv_cache_spec_uses_sparse_c8(kv_cache_spec)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Introducing current_sparse_c8 to conditionally set dsa_k_scale_tensor_split_factor is a critical correctness improvement. This ensures that the scale tensor split factor is only considered when sparse C8 is actively enabled for the current KV cache specification, preventing potential errors or incorrect memory allocation.

sparse_kv_cache_ratio = kv_cache_spec.sparse_kv_cache_ratio
k_tensor_split_factor = sparse_kv_cache_ratio[0]
v_tensor_split_factor = sparse_kv_cache_ratio[1]
dsa_k_tensor_split_factor = sparse_kv_cache_ratio[2]
dsa_k_scale_tensor_split_factor = sparse_kv_cache_ratio[3]
dsa_k_scale_tensor_split_factor = sparse_kv_cache_ratio[3] if current_sparse_c8 else None
else:
k_dim, v_dim = self._get_attention_kv_cache_dims(layer_name, current_kv_cache_spec)
assert k_dim > 0 and v_dim > 0
Expand All @@ -2777,7 +2785,7 @@ def _allocate_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str
#### for deepseek sparse attention
if self.use_sparse:
dsa_k_tensor_size = int(kv_cache_tensor.size // dsa_k_tensor_split_factor)
if self.use_sparse_c8_indexer:
if self.use_sparse and current_sparse_c8:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The condition for calculating dsa_k_scale_tensor_size has been updated to if self.use_sparse and current_sparse_c8:. This change correctly links the calculation to the current_sparse_c8 flag, ensuring that dsa_k_scale_tensor_size is only computed when sparse C8 is enabled and relevant, which is essential for accurate memory management.

dsa_k_scale_tensor_size = int(kv_cache_tensor.size // dsa_k_scale_tensor_split_factor)

# for other attentions, e.g., self_attn, sliding window attn
Expand Down Expand Up @@ -2814,7 +2822,7 @@ def _allocate_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str
# shared the attn kvcache for all shared layers
if "attn" in layer_name_inner and "linear_attn" not in layer_name_inner:
if self.use_sparse:
if self.use_sparse_c8_indexer:
if current_sparse_c8:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Updating the condition to if current_sparse_c8: for assigning kv_cache_raw_tensors ensures that the 4-element tuple (including dsa_k_scale_tensor) is only used when sparse C8 is enabled. This prevents potential TypeError or ValueError if sparse_kv_cache_ratio[3] is None or if the tuple structure is not expected, which is a critical correctness fix.

kv_cache_raw_tensors[layer_name_inner] = (
k_tensor, v_tensor, dsa_k_tensor, dsa_k_scale_tensor
)
Expand Down Expand Up @@ -2862,7 +2870,8 @@ def _reshape_kv_cache_tensors(
# encounter OOM issue
if isinstance(current_kv_cache_spec, AttentionSpec):
if self.use_sparse:
if self.use_sparse_c8_indexer:
current_sparse_c8 = self._kv_cache_spec_uses_sparse_c8(current_kv_cache_spec)
if current_sparse_c8:
Comment on lines +2873 to +2874
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The introduction of current_sparse_c8 and its use in the conditional if current_sparse_c8: statement is critical. This ensures that raw_dsa_k_scale_tensor is only unpacked when sparse C8 is enabled for the current KV cache specification. Without this, an attempt to unpack a 3-element tuple as 4 elements would result in a ValueError, leading to a crash.

raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor, raw_dsa_k_scale_tensor = kv_cache_raw_tensors[ # type: ignore
layer_name]
assert raw_dsa_k_tensor is not None
Expand Down Expand Up @@ -2964,7 +2973,7 @@ def _reshape_kv_cache_tensors(
current_kv_cache_spec.num_kv_heads,
self.model_config.hf_text_config.index_head_dim,
)
if self.use_sparse_c8_indexer:
if current_sparse_c8:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Changing the condition to if current_sparse_c8: ensures that dsa_k_cache and dsa_k_scale_cache are only initialized when sparse C8 is active for the current KV cache specification. This maintains consistency with the new sparse C8 logic and prevents unnecessary resource allocation or incorrect data handling.

# dsa_k
dsa_k_cache = raw_dsa_k_tensor.view(self.c8_k_cache_dtype).view(dsa_k_cache_shape)
# dsa_k_scale
Expand Down Expand Up @@ -3267,7 +3276,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
sparse_head_dim=self.sparse_head_dim,
dtype=self.kv_cache_dtype,
cache_dtype_str=self.vllm_config.cache_config.cache_dtype,
cache_sparse_c8=self.use_sparse_c8_indexer,
cache_sparse_c8=self._is_sparse_c8_layer(layer_name),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Updating cache_sparse_c8 to dynamically use self._is_sparse_c8_layer(layer_name) ensures that the MLAAttentionSpec accurately reflects the sparse C8 status for each specific layer. This is a crucial functional change for the correct implementation of layer-wise sparse C8 quantization.

)
elif spec := attn_module.get_kv_cache_spec(self.vllm_config):
assert isinstance(spec, MLAAttentionSpec)
Expand Down
Loading