-
Notifications
You must be signed in to change notification settings - Fork 975
[Bugfix] Fix DeepSeek 3.2 C8 precision by reverting quantization layers #7628
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 17 commits
8411e39
123c17c
22b4215
cdac595
cf4faf8
60339c8
3cceeb3
4a8393d
5c0c566
706e904
20c055e
56a949d
d7390ed
55cc8ba
0953171
4c7dbd9
455bc6b
ff51fc6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,7 +14,8 @@ | |
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| import os | ||
| from typing import TYPE_CHECKING | ||
| import re | ||
| from typing import TYPE_CHECKING, Any | ||
|
|
||
| from vllm.logger import logger | ||
| from vllm.utils.math_utils import cdiv | ||
|
|
@@ -158,13 +159,69 @@ def __init__(self, vllm_config: "VllmConfig"): | |
| and use_sparse | ||
| and get_ascend_device_type() != AscendDeviceType.A5 | ||
| ) | ||
|
|
||
| quant_config = getattr(vllm_config, "quant_config", None) | ||
| self._sparse_c8_layer_ids, self._sparse_c8_layer_names = self._parse_sparse_c8_layers_from_quant_config( | ||
| quant_config | ||
| ) | ||
| self._sparse_c8_layer_filter_enabled = self._has_sparse_c8_layer_config(quant_config) | ||
| self.enable_sp_by_pass = ( | ||
| vllm_config.model_config is not None | ||
| and not vllm_config.model_config.enforce_eager | ||
| and vllm_config.compilation_config.pass_config.enable_sp | ||
| ) | ||
|
|
||
| @staticmethod | ||
| def _has_sparse_c8_layer_config(quant_config: Any) -> bool: | ||
| quant_description = getattr(quant_config, "quant_description", None) | ||
| if not isinstance(quant_description, dict): | ||
| return False | ||
| return any(isinstance(key, str) and key.endswith(".indexer.quant_type") for key in quant_description) | ||
|
|
||
| @classmethod | ||
| def _parse_sparse_c8_layers_from_quant_config(cls, quant_config: Any) -> tuple[set[int], set[str]]: | ||
| quant_description = getattr(quant_config, "quant_description", None) | ||
| if not isinstance(quant_description, dict): | ||
| return set(), set() | ||
|
|
||
| layer_ids: set[int] = set() | ||
| layer_names: set[str] = set() | ||
| suffix = ".indexer.quant_type" | ||
| for key, value in quant_description.items(): | ||
| if not isinstance(key, str) or not key.endswith(suffix): | ||
| continue | ||
| if value != "INT8_DYNAMIC": | ||
| continue | ||
| layer_name = key[: -len(suffix)].rstrip(".") | ||
| if not layer_name: | ||
| continue | ||
| layer_names.add(layer_name) | ||
| layer_ids.update(cls._extract_layer_ids(layer_name)) | ||
| return layer_ids, layer_names | ||
|
|
||
| @staticmethod | ||
| def _extract_layer_ids(layer_name: str) -> set[int]: | ||
| return {int(match) for match in re.findall(r"(?:^|\.)(\d+)(?:\.|$)", layer_name)} | ||
|
||
|
|
||
| def is_sparse_c8_layer(self, layer_name: str | None) -> bool: | ||
| if not self.enable_sparse_c8: | ||
| return False | ||
| if not self._sparse_c8_layer_filter_enabled: | ||
| return True | ||
| if layer_name is None: | ||
| return False | ||
|
|
||
| normalized_layer_name = layer_name.rstrip(".") | ||
| if any( | ||
| normalized_layer_name == candidate or normalized_layer_name.startswith(f"{candidate}.") | ||
| for candidate in self._sparse_c8_layer_names | ||
| ): | ||
| return True | ||
|
|
||
| layer_ids = self._extract_layer_ids(normalized_layer_name) | ||
| return any(layer_id in self._sparse_c8_layer_ids for layer_id in layer_ids) | ||
|
|
||
|
|
||
|
|
||
| @staticmethod | ||
| def _get_compile_ranges(compilation_config): | ||
| return compilation_config.compile_ranges_endpoints or [] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -377,10 +377,14 @@ def get_quant_type_for_layer( | |
| if packed_modules_mapping is None: | ||
| packed_modules_mapping = dict() | ||
| # Attention | ||
| if layer_type == "attention" and "fa_quant_type" in quant_description: | ||
| return quant_description["fa_quant_type"] | ||
| if layer_type == "attention" and "indexer_quant_type" in quant_description: | ||
| return quant_description["indexer_quant_type"] | ||
| if layer_type == "attention": | ||
| layer_indexer_quant_type = quant_description.get(f"{prefix}.indexer.quant_type") | ||
| if layer_indexer_quant_type is not None: | ||
| return layer_indexer_quant_type | ||
|
Comment on lines
+381
to
+383
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| if "fa_quant_type" in quant_description: | ||
| return quant_description["fa_quant_type"] | ||
| if "indexer_quant_type" in quant_description: | ||
| return quant_description["indexer_quant_type"] | ||
| # Linear / MoE | ||
| return get_linear_quant_type(quant_description, prefix, packed_modules_mapping) | ||
|
|
||
|
|
@@ -646,7 +650,7 @@ def is_indexer_quant_layer(self, prefix): | |
| if layer_id_str.isdigit() and int(layer_id_str) in self.indexer_quant_layers: | ||
| return True | ||
| return False | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove this |
||
| def enabling_fa_quant(self, vllm_config, layer_name) -> bool: | ||
| is_decode_instance = ( | ||
| vllm_config.kv_transfer_config is not None | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2673,6 +2673,13 @@ def _get_layer_kv_cache_specs(self, kv_cache_config: KVCacheConfig) -> dict[str, | |
| layer_kv_cache_spec[layer_name] = group_spec | ||
| return layer_kv_cache_spec | ||
|
|
||
| def _is_sparse_c8_layer(self, layer_name: str) -> bool: | ||
| return bool(self.use_sparse and self.ascend_config.is_sparse_c8_layer(layer_name)) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just |
||
|
|
||
| @staticmethod | ||
| def _kv_cache_spec_uses_sparse_c8(kv_cache_spec: KVCacheSpec) -> bool: | ||
| return isinstance(kv_cache_spec, MLAAttentionSpec) and bool(getattr(kv_cache_spec, "cache_sparse_c8", False)) | ||
|
|
||
| def _get_attention_kv_cache_dims(self, layer_name: str, kv_cache_spec: AttentionSpec) -> tuple[int, int]: | ||
| if isinstance(kv_cache_spec, MLAAttentionSpec): | ||
| attn_layers = get_layers_from_vllm_config( | ||
|
|
@@ -2751,11 +2758,12 @@ def _allocate_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str | |
| if self.use_sparse: | ||
| # for deepseek v3.2, we split the kv cache according to the corresponding ratio | ||
| kv_cache_spec = layer_kv_cache_spec[layer_name] | ||
| current_sparse_c8 = self._kv_cache_spec_uses_sparse_c8(kv_cache_spec) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Introducing |
||
| sparse_kv_cache_ratio = kv_cache_spec.sparse_kv_cache_ratio | ||
| k_tensor_split_factor = sparse_kv_cache_ratio[0] | ||
| v_tensor_split_factor = sparse_kv_cache_ratio[1] | ||
| dsa_k_tensor_split_factor = sparse_kv_cache_ratio[2] | ||
| dsa_k_scale_tensor_split_factor = sparse_kv_cache_ratio[3] | ||
| dsa_k_scale_tensor_split_factor = sparse_kv_cache_ratio[3] if current_sparse_c8 else None | ||
| else: | ||
| k_dim, v_dim = self._get_attention_kv_cache_dims(layer_name, current_kv_cache_spec) | ||
| assert k_dim > 0 and v_dim > 0 | ||
|
|
@@ -2777,7 +2785,7 @@ def _allocate_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str | |
| #### for deepseek sparse attention | ||
| if self.use_sparse: | ||
| dsa_k_tensor_size = int(kv_cache_tensor.size // dsa_k_tensor_split_factor) | ||
| if self.use_sparse_c8_indexer: | ||
| if self.use_sparse and current_sparse_c8: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The condition for calculating |
||
| dsa_k_scale_tensor_size = int(kv_cache_tensor.size // dsa_k_scale_tensor_split_factor) | ||
|
|
||
| # for other attentions, e.g., self_attn, sliding window attn | ||
|
|
@@ -2814,7 +2822,7 @@ def _allocate_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str | |
| # shared the attn kvcache for all shared layers | ||
| if "attn" in layer_name_inner and "linear_attn" not in layer_name_inner: | ||
| if self.use_sparse: | ||
| if self.use_sparse_c8_indexer: | ||
| if current_sparse_c8: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Updating the condition to |
||
| kv_cache_raw_tensors[layer_name_inner] = ( | ||
| k_tensor, v_tensor, dsa_k_tensor, dsa_k_scale_tensor | ||
| ) | ||
|
|
@@ -2862,7 +2870,8 @@ def _reshape_kv_cache_tensors( | |
| # encounter OOM issue | ||
| if isinstance(current_kv_cache_spec, AttentionSpec): | ||
| if self.use_sparse: | ||
| if self.use_sparse_c8_indexer: | ||
| current_sparse_c8 = self._kv_cache_spec_uses_sparse_c8(current_kv_cache_spec) | ||
| if current_sparse_c8: | ||
|
Comment on lines
+2873
to
+2874
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The introduction of |
||
| raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor, raw_dsa_k_scale_tensor = kv_cache_raw_tensors[ # type: ignore | ||
| layer_name] | ||
| assert raw_dsa_k_tensor is not None | ||
|
|
@@ -2964,7 +2973,7 @@ def _reshape_kv_cache_tensors( | |
| current_kv_cache_spec.num_kv_heads, | ||
| self.model_config.hf_text_config.index_head_dim, | ||
| ) | ||
| if self.use_sparse_c8_indexer: | ||
| if current_sparse_c8: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Changing the condition to |
||
| # dsa_k | ||
| dsa_k_cache = raw_dsa_k_tensor.view(self.c8_k_cache_dtype).view(dsa_k_cache_shape) | ||
| # dsa_k_scale | ||
|
|
@@ -3267,7 +3276,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: | |
| sparse_head_dim=self.sparse_head_dim, | ||
| dtype=self.kv_cache_dtype, | ||
| cache_dtype_str=self.vllm_config.cache_config.cache_dtype, | ||
| cache_sparse_c8=self.use_sparse_c8_indexer, | ||
| cache_sparse_c8=self._is_sparse_c8_layer(layer_name), | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| ) | ||
| elif spec := attn_module.get_kv_cache_spec(self.vllm_config): | ||
| assert isinstance(spec, MLAAttentionSpec) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add an e2e test in tests/e2e/multicard/2-cards/test_offline_inference_distributed.py