diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml index a551051bfd..50527d1945 100644 --- a/.github/workflows/vllm_ascend_test.yaml +++ b/.github/workflows/vllm_ascend_test.yaml @@ -121,7 +121,14 @@ jobs: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/devlib pytest -sv --cov --cov-report=xml:unittests-coverage.xml tests/ut \ --ignore=tests/ut/test_platform.py \ - --ignore=tests/ut/patch/worker/patch_common/test_patch_minicpm.py + --ignore=tests/ut/patch/worker/patch_common/test_patch_minicpm.py \ + --ignore=tests/ut/core/test_scheduler.py \ + --ignore=tests/ut/kv_connector/test_llmdatadist_connector.py \ + --ignore=tests/ut/kv_connector/test_mooncake_connector.py \ + --ignore=tests/ut/kv_connector/test_remote_decode_lifecycle.py \ + --ignore=tests/ut/kv_connector/test_remote_prefill_lifecycle.py \ + --ignore=tests/ut/torchair/models/test_torchair_deepseek_v2.py \ + --ignore=tests/ut/torchair/test_utils.py - name: Upload coverage to Codecov # only upload coverage when commits merged diff --git a/vllm_ascend/__init__.py b/vllm_ascend/__init__.py index 7588e70ed9..90aede78d1 100644 --- a/vllm_ascend/__init__.py +++ b/vllm_ascend/__init__.py @@ -23,5 +23,7 @@ def register(): def register_model(): + import vllm_ascend.patch.worker.patch_common.patch_attention_selector # noqa + from .models import register_model register_model() diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 301e64242d..65ea3ea0d2 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -34,6 +34,8 @@ class AscendConfig: def __init__(self, vllm_config): additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {} + self.is_deepseek_sfa = vllm_config.model_config is not None and vllm_config.model_config.is_deepseek_mla and vllm_config.model_config.hf_text_config.model_type == "deepseek_v32" + self.use_sfa = self.is_deepseek_sfa torchair_graph_config = additional_config.get("torchair_graph_config", {}) diff --git a/vllm_ascend/attention/attention_mask.py b/vllm_ascend/attention/attention_mask.py index cf92affd38..225d4b903a 100644 --- a/vllm_ascend/attention/attention_mask.py +++ b/vllm_ascend/attention/attention_mask.py @@ -73,7 +73,7 @@ def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype, device: torch.device): self._update_attn_cache(max_seq_len, dtype) return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous( - ).to(device) + ).to(device, non_blocking=True) def get_splitfuse_attn_mask( self, diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py new file mode 100644 index 0000000000..55282c8443 --- /dev/null +++ b/vllm_ascend/attention/sfa_v1.py @@ -0,0 +1,986 @@ +from dataclasses import dataclass +from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type, + TypeVar) + +import torch +import torch_npu +from torch import nn +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata, + MLAAttentionImpl) +from vllm.config import VllmConfig, get_current_vllm_config +from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) +from vllm.utils import cdiv, round_down +from vllm.v1.attention.backends.utils import AttentionCGSupport + +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.attention.attention_v1 import AscendAttentionState 
+from vllm_ascend.attention.mla_v1 import AscendMLAMetadata +from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, + split_decodes_and_prefills) +from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig +from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn +from vllm_ascend.worker.npu_input_batch import InputBatch + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + + +class AscendSFABackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_name() -> str: + return "ASCEND_SFA" + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + return AscendSFAMetadata + + @staticmethod + def get_builder_cls(): + return AscendSFAMetadataBuilder + + @staticmethod + def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, + head_size: int) -> tuple[int, ...]: + return (num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def get_impl_cls() -> Type["AscendSFAImpl"]: + return AscendSFAImpl + + +@dataclass +class AscendSFAPrefillMetadata: + """ Prefill Specific Metadata for Ascend""" + + @dataclass + class ChunkedContextMetadata: + # New for MLA (compared to FlashAttention) + # For handling chunked prefill + cu_seq_lens: torch.Tensor + starts: torch.Tensor + seq_tot: list[int] + max_seq_lens: list[int] + workspace: torch.Tensor + chunk_seq_lens: torch.Tensor + + attn_mask: torch.Tensor + query_lens: list[int] + seq_lens: list[int] + + context_lens: torch.Tensor + input_positions: torch.Tensor + query_start_loc: torch.Tensor + block_table: torch.Tensor + max_query_len: int + max_seq_lens: int + sin: torch.Tensor + cos: torch.Tensor + chunked_context: Optional[ChunkedContextMetadata] = None + + +@dataclass +class AscendSFADecodeMetadata: + # Input positions for rotrary embeddings since for MLA the rotary + # position embeddings are applied inside the attention backend + input_positions: torch.Tensor + block_table: torch.Tensor + seq_lens: torch.Tensor + max_seq_lens: int + seq_lens_list: list[int] + actual_seq_lengths_q: torch.Tensor + sin: torch.Tensor + cos: torch.Tensor + attn_mask: Optional[torch.Tensor] = None + + +@dataclass +class AscendSFAMetadata: + """Metadata for MLACommon. + + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + num_actual_tokens: int # Number of tokens excluding padding. + slot_mapping: torch.Tensor + query_start_loc: torch.Tensor + seq_lens: torch.Tensor + block_tables: torch.Tensor + + # New for MLA (compared to FlashAttention) + # For handling prefill decode split + num_decodes: int + num_decode_tokens: int + num_prefills: int + + # For logging. + num_input_tokens: int = 0 # Number of tokens including padding. 
+ + query_lens: Optional[list[int]] = None + # The dimension of the attention heads + head_dim: Optional[int] = None + attn_mask: torch.Tensor = None + # chunked prefill by default if no attn_states passed + attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill + + decode: Optional[AscendSFADecodeMetadata] = None + prefill: Optional[AscendSFAPrefillMetadata] = None + enable_dbo_across_dp: bool = False + + def __post_init__(self): + pass + # supported_head_sizes = AscendMLABackend.get_supported_head_sizes() + # if self.head_dim is not None and self.head_dim \ + # not in supported_head_sizes: + # raise ValueError( + # f"Only {supported_head_sizes} are supported for head_dim,", + # f"received {self.head_dim}.") + + def split_metadata_for_multistream( + self, + ms_split_config: MSAttentionMetadataSplitConfig, + ) -> list["AscendSFAMetadata"]: + """Split metadata for multi-stream with AscendSFAMetadata""" + return model_input_split_v1_mla_attn( + ms_split_config=ms_split_config, + attn_metadata=self, + _metadata_cls=AscendMLAMetadata, + ) + + +M = TypeVar("M", bound=AscendSFAMetadata) + + +class AscendSFAMetadataBuilder: + # Does this backend/builder support ACL Graphs for attention (default: no). + aclgraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.NEVER + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + # _attn_mask_builder = None + def __init__(self, + kv_cache_spec, + layer_names, + vllm_config: VllmConfig, + device: torch.device, + metadata_cls: Optional[AscendSFAMetadata] = None): + self.metadata_cls: Optional[AscendSFAMetadata] = metadata_cls \ + if metadata_cls is not None else AscendSFAMetadata # type: ignore + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.device = device + scheduler_config = vllm_config.scheduler_config + self.block_size = vllm_config.cache_config.block_size + self.max_blocks = (vllm_config.model_config.max_model_len + + self.block_size - 1) // self.block_size + self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled + + self.speculative_config = vllm_config.speculative_config + self.decode_threshold = 1 + if self.speculative_config: + spec_token_num = self.speculative_config.num_speculative_tokens + self.decode_threshold += spec_token_num + assert self.decode_threshold <= 16, f"decode_threshold exceeded \ + npu_fused_infer_attention_score TND layout's limit of 16, \ + got {self.decode_threshold}" + + if self.chunked_prefill_enabled: + self.chunked_prefill_workspace_size = min( + # Max sure there is enough for 8 full length request or at least + # 4 pages of cache per request + max(8 * self.model_config.max_model_len, + 4 * scheduler_config.max_num_seqs * self.block_size), + # For long-context models try not to over-allocate limiting + # kv-cache space, limiting it to 64k tokens, + # which would result in the workspace being: + # 2*(576)*(64*1024) = 144mb + # (assuming 576 MLA head dim, and fp16) + # which would result in up-projected context being + # 2*(192*128)*(64*1024) = 3gb + # (assuming 192 QK head dim, 128 heads, and fp16) + 128 * 1024) + assert self.chunked_prefill_workspace_size >= \ + scheduler_config.max_num_seqs * self.block_size + self.chunked_prefill_workspace = torch.empty( + (self.chunked_prefill_workspace_size, + self.model_config.get_head_size()), + dtype=self.model_config.dtype, + device=device, + ) + self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim + self.cos_cache = None + 
self.sin_cache = None + + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput") -> bool: + # We now want to reorder the batch so that the "decode" requests are at + # the front and the "prefill" requests are at the using the least amount + # swaps possible. (NOTE for now we loosely use "decode" to mean requests + # where attention is likely memory-bound and "prefill" to mean requests + # where attention is likely compute-bound, TODO(lucas): figure out a + # better naming here) + decodes = [] + prefills = [] + + for i, req_id in enumerate(input_batch.req_ids): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + if num_tokens <= self.decode_threshold: + decodes.append(i) + else: + prefills.append(i) + + # We hope that this is fairly minimal since decodes + # should be around for a number of iterations so hopefully they are + # relatively stationary (and new request are generally appended to the + # persistent batch so already should be at the back) + # To achieve this we loop over the decodes in descending order and + # the prefills in ascending order. We swap decodes from the "back" + # i.e. past where the last decode should be in the reodorered with + # prefills from the front of the batch. + # `decodes` and `prefills` are already in ascending order just based on + # the above loop + num_decodes = len(decodes) + num_prefills = len(prefills) + first_prefill = 0 + modified_batch = False + + for i in range(1, min(num_decodes, num_prefills) + 1): + # If the decode is at the "back" of the batch, i, we can swap it + # with the prefill closest to the front of the batch + if decodes[num_decodes - i] >= num_decodes: + input_batch.swap_states(prefills[first_prefill], + decodes[num_decodes - i]) + first_prefill += 1 + modified_batch = True + else: + break + + # Save for next `build` call + # TODO(lucas): this is a bit of a hack, we should probably have a + # better way of doing this + return modified_batch + + def build( + self, + common_prefix_len: int, + common_attn_metadata: AscendCommonAttentionMetadata, + model: nn.Module, + ) -> AscendSFAMetadata: + num_reqs = common_attn_metadata.num_reqs + num_actual_tokens = common_attn_metadata.num_actual_tokens + query_start_loc = common_attn_metadata.query_start_loc + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ + split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold) + assert num_decodes + num_prefills == num_reqs + assert num_decode_tokens + num_prefill_tokens == num_actual_tokens + + # Note(simon): be careful about the CPU <> GPU memory movement in this + # function. We should avoid GPU -> CPU sync as much as possible because + # it blocks on all previous kernels. 
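+        # After reorder_batch(), decode requests occupy indices [0, num_decodes)
+        # and prefill requests occupy [num_decodes, num_reqs), so the slices
+        # below split the two phases purely by position.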
+ device = self.device + + block_table = (common_attn_metadata.block_table_tensor[:num_reqs]) + slot_mapping = common_attn_metadata.slot_mapping[: + num_actual_tokens].to( + device, + non_blocking=True) + input_positions = common_attn_metadata.positions[: + num_actual_tokens].long( + ) + + if self.cos_cache is None: + self.cos_cache = model.model.layers[ + 0].self_attn.rotary_emb.cos_cached + self.sin_cache = model.model.layers[ + 0].self_attn.rotary_emb.sin_cached + if self.cos_cache.dtype != self.model_config.dtype: # type: ignore + self.cos_cache = self.cos_cache.to( # type: ignore + self.model_config.dtype) # type: ignore + self.sin_cache = self.sin_cache.to( # type: ignore + self.model_config.dtype) # type: ignore + + query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + query_lens = query_seq_lens_cpu[:num_reqs] + seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] + num_computed_tokens_cpu = (seq_lens - query_lens) + + prefill_metadata = None + chunked_context_metadata = None + if num_prefills > 0: + reqs_start = num_decodes # prefill_start + tokens_start = num_decode_tokens + max_query_len = query_lens[reqs_start:].max().item() + max_seq_lens = seq_lens[reqs_start:].max().item() + prefill_query_start_loc = query_start_loc[ + reqs_start:] - query_start_loc[reqs_start] + + context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] + max_context_len_cpu = context_lens_cpu.max().item() + num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() + if self.chunked_prefill_enabled and max_context_len_cpu > 0: + max_context_chunk = (self.chunked_prefill_workspace_size // + num_prefills_with_context_cpu) + max_context_chunk = round_down(max_context_chunk, + self.block_size) + + assert max_context_chunk > 0 + num_chunks = cdiv(max_context_len_cpu, max_context_chunk) + chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \ + .unsqueeze(1).expand(-1, num_prefills) * max_context_chunk + chunk_ends = torch.min(context_lens_cpu.unsqueeze(0), + chunk_starts + max_context_chunk) + chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) + cu_seq_lens_cpu = torch.zeros(num_chunks, + num_prefills + 1, + dtype=torch.int32, + pin_memory=True) + torch.cumsum(chunk_seq_lens, + dim=1, + out=cu_seq_lens_cpu[:, 1:], + dtype=torch.int32) + chunked_context_metadata = \ + AscendSFAPrefillMetadata.ChunkedContextMetadata( + cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), + starts=chunk_starts.to(device, non_blocking=True), + seq_tot=chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + chunk_seq_lens=chunk_seq_lens, + workspace=self.chunked_prefill_workspace, + ) + prefill_input_positions = input_positions[tokens_start:] + cos = self.cos_cache[ + prefill_input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + sin = self.sin_cache[ + prefill_input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + actual_query_lens = torch.tensor(query_lens[reqs_start:], + dtype=torch.int32).npu() + query_lens_prefill_sfa = torch.cumsum(actual_query_lens, + dim=0).to(torch.int32) + seq_lens_prefill_sfa = seq_lens[reqs_start:].to(torch.int32).npu() + prefill_metadata = AscendSFAPrefillMetadata( + attn_mask=common_attn_metadata.attn_mask, + query_lens=query_lens_prefill_sfa, + seq_lens=seq_lens_prefill_sfa, + context_lens=seq_lens[reqs_start:], + input_positions=prefill_input_positions, + block_table=block_table[reqs_start:, ...], + max_query_len=max_query_len, + max_seq_lens=max_seq_lens, + 
query_start_loc=prefill_query_start_loc, + chunked_context=chunked_context_metadata, + sin=sin, + cos=cos, + ) + + decode_metadata = None + if num_decodes > 0: + # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario + actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].to( + torch.int32).npu() + max_seq_lens = seq_lens[:num_decodes].max().item() + seq_lens = seq_lens[:num_decodes].to(torch.int32).npu() + input_positions = input_positions[:num_decode_tokens] + block_table = block_table[:num_decodes, ...] + seq_lens_list = seq_lens.tolist() + + cos = self.cos_cache[input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + sin = self.sin_cache[input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + + decode_metadata = AscendSFADecodeMetadata( + input_positions=input_positions, + block_table=block_table, + seq_lens=seq_lens, + seq_lens_list=seq_lens_list, + max_seq_lens=max_seq_lens, + attn_mask=common_attn_metadata.spec_attn_mask, + actual_seq_lengths_q=actual_seq_lengths_q, + sin=sin, + cos=cos) + + return self.metadata_cls( # type: ignore + num_actual_tokens=num_actual_tokens, + query_lens=query_lens.tolist(), + slot_mapping=slot_mapping, + head_dim=self.model_config.get_head_size(), + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, + attn_mask=common_attn_metadata.attn_mask, + attn_state=common_attn_metadata.attn_state, + prefill=prefill_metadata, + decode=decode_metadata, + query_start_loc=query_start_loc, + block_tables=block_table, + seq_lens=seq_lens, + enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp, + ) + + +class PrefillSFAPreprocessResult(NamedTuple): + q_nope: Optional[torch.Tensor] = None + q_pe: Optional[torch.Tensor] = None + k_nope: Optional[torch.Tensor] = None + k_pe: Optional[torch.Tensor] = None + topk_indices: Optional[torch.Tensor] = None + query_states: Optional[torch.Tensor] = None + key_states: Optional[torch.Tensor] = None + + +class DecodeSFAPreprocessResult(NamedTuple): + q_nope: Optional[torch.Tensor] = None + q_pe: Optional[torch.Tensor] = None + # nope_cache: Optional[torch.Tensor] = None + # rope_cache: Optional[torch.Tensor] = None + topk_indices: Optional[torch.Tensor] = None + query_states: Optional[torch.Tensor] = None + key_states: Optional[torch.Tensor] = None + bsz: Optional[int] = None + + +class AscendSFAImpl(MLAAttentionImpl): + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + **kwargs, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + + # MLA Args + self.q_lora_rank = kwargs['q_lora_rank'] + self.kv_lora_rank = kwargs['kv_lora_rank'] + self.qk_nope_head_dim = kwargs['qk_nope_head_dim'] + self.qk_rope_head_dim = kwargs['qk_rope_head_dim'] + self.qk_head_dim = kwargs['qk_head_dim'] + self.v_head_dim = kwargs['v_head_dim'] + self.rotary_emb = kwargs['rotary_emb'] + self.q_proj = kwargs['q_proj'] + self.kv_b_proj = kwargs['kv_b_proj'] + self.o_proj = kwargs['o_proj'] + self.indexer = kwargs['indexer'] + self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None) + self.kv_a_layernorm = 
kwargs.get('kv_a_layernorm', None) + self.q_a_proj = kwargs.get('q_a_proj', None) + self.q_a_layernorm = kwargs.get('q_a_layernorm', None) + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_rank = self.num_heads // self.tp_size + if self.q_a_proj is not None: + self.q_b_proj = self.q_proj + else: + self.q_b_proj = None + + ascend_config = get_ascend_config() + self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp + self.enable_prefetch = ascend_config.enable_prefetch + self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz + + vllm_config = get_current_vllm_config() + self.ring_mla_mask_size = 512 + self.prefill_mask = None + + # indexer param + self.dim = self.indexer.dim + self.n_heads: int = self.indexer.n_heads # 64 + self.head_dim: int = self.indexer.head_dim # 128 + self.index_topk: int = self.indexer.index_topk # 2048 + self.wq_b = self.indexer.wq_b + self.wk = self.indexer.wk + self.weights_proj = self.indexer.weights_proj + self.k_norm = self.indexer.k_norm + self.softmax_scale = self.indexer.softmax_scale + + # Adapt torch air graph mode with spec decoding. + speculative_config = vllm_config.speculative_config + if speculative_config is not None: + self.spec_token_num = speculative_config.num_speculative_tokens + assert self.spec_token_num > 0 + + self.cp_size = 1 + + def process_weights_after_loading(self, act_dtype: torch.dtype): + + def get_layer_weight(layer): + WEIGHT_NAMES = ("weight", "qweight", "weight_packed") + for attr in WEIGHT_NAMES: + if hasattr(layer, attr): + return getattr(layer, attr) + raise AttributeError( + f"Layer '{layer}' has no recognized weight attribute:" + f" {WEIGHT_NAMES}.") + + def get_and_maybe_dequant_weights(layer: LinearBase): + if not isinstance(layer.quant_method, UnquantizedLinearMethod): + # NOTE: This should only be used offline, since it's O(N^3) + eye = torch.eye(layer.input_size_per_partition, + dtype=act_dtype, + device=get_layer_weight(layer).device) + dequant_weights = layer.quant_method.apply(layer, + eye, + bias=None) + del eye + # standardize to (output, input) + return dequant_weights.T + return layer.weight + + # we currently do not have quantized bmm's which are needed for + # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform + # the bmm's in 16-bit, the extra memory overhead of this is fairly low + kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T + assert kv_b_proj_weight.shape == ( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}") + kv_b_proj_weight = kv_b_proj_weight.view( + self.kv_lora_rank, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ) + + self.kv_b_proj_w_k, self.kv_b_proj_w_v = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + # Convert from (L, N, V) to (N, L, V) + self.kv_b_proj_w_v = self.kv_b_proj_w_v.transpose(0, 1).contiguous() + # Convert from (L, N, P) to (N, P, L) + self.kv_b_proj_w_k = self.kv_b_proj_w_k.permute(1, 2, 0).contiguous() + + # Waiting for BMM NZ support + # self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29) + # self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29) + + def _sfa_preprocess(self, hidden_states, kv_cache, attn_metadata, + need_gather_q_kv): + # SFA Preprocess: + # 1. 
Perform q_a_proj and q_a_layernorm to obtain q_c + # 2. Perform kv_a_proj_with_mqa to obtain kv_no_split + # 3. If need_gather_q_kv, perform all_gather. + # 4. Preprocess decode tokens, write kv cache and get: + # decode_ql_nope, decode_q_pe, decode_k_pe, decode_k_nope + # 5. Preprocess prefill tokens, write kv cache and get: + # prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value + has_decode = attn_metadata.num_decodes > 0 + has_prefill = attn_metadata.num_prefills > 0 + + num_decode_tokens = attn_metadata.num_decode_tokens + num_actual_tokens = attn_metadata.num_actual_tokens + if need_gather_q_kv: + # q_c = get_tp_group().all_gather(q_c, 0) + # kv_no_split = get_tp_group().all_gather(kv_no_split, 0) + hidden_states = get_tp_group().all_gather(hidden_states, 0) + # hidden_states_decode = hidden_states[:num_decode_tokens] + # if self.q_a_proj is not None: + # npu_prefetch(self.q_a_proj.weight, + # hidden_states, + # enabled=self.enable_prefetch) + # ckq = self.q_a_proj(hidden_states) # q down + # q_c = self.q_a_layernorm(ckq) # q down layernorm + # else: + # q_c = hidden_states + + # kv_no_split = self.kv_a_proj_with_mqa(hidden_states) # c_kv + # Process for shared_expert_dp + + decode_preprocess_res = None + prefill_preprocess_res = None + # Preprocess for decode tokens + if has_decode: + q_len = 1 + hidden_states_decode = hidden_states[:num_decode_tokens] + decode_kq = self.q_a_proj(hidden_states_decode) # q down + decode_q_c = self.q_a_layernorm(decode_kq) # q down layernorm + decode_kv_no_split = self.kv_a_proj_with_mqa( + hidden_states_decode) # c_kv + + # decode_q_c = q_c[:num_decode_tokens] + decode_slot_mapping = attn_metadata.slot_mapping[: + num_decode_tokens] + # decode_kv_no_split = decode_kv_no_split[:num_decode_tokens] + + decode_q = self.q_b_proj(decode_q_c) + bsz, _ = decode_q.shape + decode_q = decode_q.view(bsz, self.num_heads, 1, self.qk_head_dim) + decode_q_nope, decode_q_pe = torch.split( + decode_q, [self.qk_nope_head_dim, self.qk_rope_head_dim], + dim=-1) + decode_q_nope = decode_q_nope.view( + -1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1) + decode_q_nope = (torch.matmul(decode_q_nope, + self.kv_b_proj_w_k).transpose( + 1, + 0).view(bsz, q_len, + self.num_heads, + self.kv_lora_rank)) + + # stream2 kv + key_cache = kv_cache[0] + value_cache = kv_cache[1] + cos = attn_metadata.decode.cos + sin = attn_metadata.decode.sin + cos_q, sin_q = cos, sin + cos = cos.view(-1, 1, 1, self.qk_rope_head_dim) + sin = sin.view(-1, 1, 1, self.qk_rope_head_dim) + + decode_kv_no_split = decode_kv_no_split.unsqueeze(1).unsqueeze(1) + decode_k_rope, decode_k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( + decode_kv_no_split, + self.kv_a_layernorm.weight, + cos, + sin, + decode_slot_mapping.to(torch.int64), + value_cache, + key_cache, + c_kv_scale=None, + epsilon=self.kv_a_layernorm.variance_epsilon, + cache_mode='PA') # adapter NZ + # nz_block_size = 16 + # KVCACHE_NZ_DIM = 16 + # decode_k_nope = decode_k_nope.view(block_num, 1, self.kv_lora_rank // nz_block_size, block_size, nz_block_size) + # decode_k_rope = decode_k_rope.view(block_num, 1, self.qk_rope_head_dim // KVCACHE_NZ_DIM, block_size, KVCACHE_NZ_DIM) + + decode_q_pe = torch_npu.npu_interleave_rope(decode_q_pe, cos, + sin) # BNSD + + decode_q_nope = decode_q_nope.view(bsz, self.num_heads, + self.kv_lora_rank) + decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1) + + topk_indices = self.indexer_select(hidden_states_decode, + decode_q_c, + attn_metadata=attn_metadata, + kv_cache=kv_cache) 
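+            # topk_indices selects, per decode token, which cached KV positions
+            # the sparse attention kernel will read in forward().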
+ + query_states = (decode_q_nope, decode_q_pe) + key_states = (decode_k_nope, decode_k_rope) + decode_preprocess_res = DecodeSFAPreprocessResult( + q_nope=decode_q_nope, + q_pe=decode_q_pe, + # nope_cache = nope_cache, + # rope_cache = rope_cache, + topk_indices=topk_indices, + query_states=query_states, + key_states=key_states, + bsz=bsz, + ) + + # Preprocess for prefill tokens + if has_prefill: + bsz = 1 + + hidden_states_prefill = hidden_states[ + num_decode_tokens:num_actual_tokens] + prefill_kq = self.q_a_proj(hidden_states_prefill) # q down + prefill_q_c = self.q_a_layernorm(prefill_kq) # q down layernorm + prefill_kv_no_split = self.kv_a_proj_with_mqa( + hidden_states_prefill) # c_kv + + # prefill_q_c = q_c[ + # num_decode_tokens:num_actual_tokens] + prefill_slot_mapping = attn_metadata.slot_mapping[ + num_decode_tokens:num_actual_tokens] + # decode_kv_no_split = decode_kv_no_split[:num_decode_tokens] + + prefill_slot_mapping = attn_metadata.slot_mapping[ + num_decode_tokens:num_actual_tokens] + # prefill_kv_no_split = kv_no_split[ + # num_decode_tokens:num_actual_tokens] + # prefill_qr = prefill_q_c[num_decode_tokens:num_actual_tokens] + prefill_qr = prefill_q_c + prefill_q = self.q_b_proj(prefill_qr) + prefill_q = prefill_q.view(-1, self.num_heads, self.qk_head_dim) + prefill_q_nope, prefill_q_pe = torch.split( + prefill_q, [self.qk_nope_head_dim, self.qk_rope_head_dim], + dim=-1) + prefill_q_nope = prefill_q_nope.view( + -1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1) + prefill_q_nope = (torch.matmul(prefill_q_nope, + self.kv_b_proj_w_k).transpose( + 1, + 0).view(-1, self.num_heads, + self.kv_lora_rank)) + prefill_q_pe = prefill_q_pe.unsqueeze(2) + + # stream2 kv + + nope_cache = kv_cache[0] + rope_cache = kv_cache[1] + cos = attn_metadata.prefill.cos + sin = attn_metadata.prefill.sin + cos_q, sin_q = cos, sin + + # cos = cos.view(-1, 1, 1, self.qk_rope_head_dim) + # sin = sin.view(-1, 1, 1, self.qk_rope_head_dim) + + prefill_q_pe = torch_npu.npu_interleave_rope( + prefill_q_pe, cos_q, sin_q) # BNSD + prefill_q_pe = prefill_q_pe.squeeze(2) #BSH + # q[..., self.qk_nope_head_dim:] = prefill_q_pe # TODO:???? 
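+            # Write the prefill latent KV into the paged nope/rope caches via the
+            # fused RMSNorm + RoPE cache-update op below, mirroring the decode path.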
+ + prefill_latent_cache = prefill_kv_no_split # (B,S,N,D) + prefill_k_pe, prefill_k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( + prefill_latent_cache.view( + -1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim), + self.kv_a_layernorm.weight, + cos.view(-1, 1, 1, self.qk_rope_head_dim), + sin.view(-1, 1, 1, self.qk_rope_head_dim), + prefill_slot_mapping.to(torch.int64), + rope_cache, + nope_cache, + k_rope_scale=None, + c_kv_scale=None, + k_rope_offset=None, + c_kv_offset=None, + epsilon=self.kv_a_layernorm.variance_epsilon, + cache_mode="PA") + + topk_indices = self.indexer_select(x=hidden_states_prefill, + qr=prefill_qr, + kv_cache=kv_cache, + attn_metadata=attn_metadata) + query_states = (prefill_q_nope, prefill_q_pe) + key_states = (prefill_k_nope, prefill_k_pe) + prefill_preprocess_res = PrefillSFAPreprocessResult( + q_nope=prefill_q_nope, + q_pe=prefill_q_pe, + topk_indices=topk_indices, + k_nope=prefill_k_nope, + k_pe=prefill_k_pe, + query_states=query_states, + key_states=key_states, + ) + + return decode_preprocess_res, prefill_preprocess_res + + def forward( + self, + hidden_states: torch.Tensor, # query in unified attn + kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + attn_metadata: M, + need_gather_q_kv: bool = False, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert output is not None, "Output tensor must be provided." + if attn_metadata is None: + # Profiling run. + return output + num_actual_tokens = attn_metadata.num_actual_tokens + assert attn_metadata.num_decodes is not None and \ + attn_metadata.num_prefills is not None and \ + attn_metadata.num_decode_tokens is not None + num_decode_tokens = attn_metadata.num_decode_tokens + # Inputs and outputs may be padded for CUDA graphs + output = output[:num_actual_tokens, ...] + o_proj_input_shape = (num_actual_tokens, + self.num_heads * self.v_head_dim) + o_proj_input = torch.empty(o_proj_input_shape, + dtype=hidden_states.dtype, + device=hidden_states.device) + + # SFA Preprocess + decode_preprocess_res, prefill_preprocess_res = self._sfa_preprocess( + hidden_states, kv_cache, attn_metadata, need_gather_q_kv) + + if decode_preprocess_res is not None: + # bsz, q_len, _, _ = query_states[0].shape + decode_attn_output = self.apply_attention_fusion( + query_states=decode_preprocess_res.query_states, + key_states=decode_preprocess_res.key_states, + attn_metadata=attn_metadata, + topk_indices=decode_preprocess_res.topk_indices) + o_proj_input[:num_decode_tokens] = decode_attn_output + + if prefill_preprocess_res is not None: + prefill_attn_output = self.apply_attention_fusion( + query_states=prefill_preprocess_res.query_states, + key_states=prefill_preprocess_res.key_states, + attn_metadata=attn_metadata, + topk_indices=prefill_preprocess_res.topk_indices) + o_proj_input[num_decode_tokens:] = prefill_attn_output + + output[...] 
= self.mla_epilog(o_proj_input, absorb=True) + return output + + def apply_attention_fusion(self, query_states, key_states, topk_indices, + attn_metadata: M): + # repeat k/v heads if n_kv_heads < n_heads + q_nope, q_pe = query_states + k_nope, k_rope = key_states + + if attn_metadata.prefill is not None: + + prefill_metadata = attn_metadata.prefill + + slc_fa_fusion = torch.ops.custom.npu_sparse_flash_attention( + query=q_nope, + key=k_nope, + value=k_nope, + sparse_indices=topk_indices, + scale_value=self.scale, + sparse_block_size=1, + block_table=prefill_metadata.block_table, + actual_seq_lengths_query=prefill_metadata.query_lens, + actual_seq_lengths_kv=prefill_metadata.seq_lens, + query_rope=q_pe, + key_rope=k_rope, + layout_query="TND", + layout_kv="PA_BSND", + sparse_mode=3, + ) + + elif attn_metadata.decode is not None: + decode_metadata = attn_metadata.decode + + slc_fa_fusion = torch.ops.custom.npu_sparse_flash_attention( + query=q_nope, + key=k_nope, + value=k_nope, + sparse_indices=topk_indices, + scale_value=self.scale, + sparse_block_size=1, + block_table=attn_metadata.decode.block_table, + actual_seq_lengths_query=decode_metadata.actual_seq_lengths_q, + actual_seq_lengths_kv=decode_metadata.seq_lens, + query_rope=q_pe, + key_rope=k_rope, + layout_query="TND", + layout_kv="PA_BSND", + sparse_mode=3, + ) + slc_fa_fusion = slc_fa_fusion.squeeze(1) + + slc_fa_fusion = slc_fa_fusion.transpose(0, 1) + + # input shape [N//attn_tp_size, T(bs*q_len), D] + # output shape [T(bs*q_len), N//attn_tp_size, D] + attn_output = torch.matmul(slc_fa_fusion, + self.kv_b_proj_w_v).transpose(1, 0).reshape( + -1, self.num_heads * self.v_head_dim) + # Note: Considering the fusion rules of TBMM, attn_output shape requires a 3-dim shape, and + # with appropriate tensor stride for the later 'view' operation if oproj_tp_size > 1. 
+ # after reshape: [T(bs*q_len), 1, N//attn_tp_size*D] + # attn_output = attn_output.reshape(-1, self.num_heads * self.v_head_dim) + + return attn_output + + def mla_epilog(self, + attn_output: torch.Tensor = None, + absorb: bool = False): + # TODO: need to check + attn_output = self.o_proj(attn_output.reshape(attn_output.shape[0], + -1), + is_prefill=True, + is_force_scatter=False) + + return attn_output + + def indexer_select( + self, + x: torch.Tensor, + qr: torch.Tensor, + kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + attn_metadata: M, + ): + if attn_metadata.prefill is not None: + cos = attn_metadata.prefill.cos + sin = attn_metadata.prefill.sin + actual_seq_lengths_query = attn_metadata.prefill.query_lens + actual_seq_lengths_key = attn_metadata.prefill.seq_lens + block_table = attn_metadata.prefill.block_table + elif attn_metadata.decode is not None: + cos = attn_metadata.decode.cos + sin = attn_metadata.decode.sin + actual_seq_lengths_query = attn_metadata.decode.actual_seq_lengths_q + actual_seq_lengths_key = attn_metadata.decode.seq_lens + block_table = attn_metadata.decode.block_table + + cos_q, sin_q = cos, sin + cos = cos.view(-1, 1, 1, self.qk_rope_head_dim) + sin = sin.view(-1, 1, 1, self.qk_rope_head_dim) + + # q process in new stream + q = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128] + q = q.view(-1, self.n_heads, self.head_dim) # [b,s,64,128] + q_pe, q_nope = torch.split( + q, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], + dim=-1) # [b,s,64,64+64] + + q_pe = q_pe.unsqueeze(2) + q_pe = torch_npu.npu_interleave_rope(q_pe, cos_q, sin_q) + q_pe = q_pe.squeeze(2) + q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128] + + k_proj = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128] + k = self.k_norm(k_proj).unsqueeze(1) + k_pe, k_nope = torch.split( + k, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], + dim=-1) # [b,s,64+64] + + k_pe = k_pe.unsqueeze(2) + k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin) + k_pe = k_pe.squeeze(2) + + k = torch.cat([k_pe, k_nope], dim=-1) # [b*s,128] + + if kv_cache is not None: + torch_npu.npu_scatter_nd_update_(kv_cache[2].view(-1, k.shape[-1]), + attn_metadata.slot_mapping.view( + -1, 1), + k.view(-1, + k.shape[-1])) # b, s, n, d + + weights = self.weights_proj(x) + + topk_indices = torch.ops.custom.npu_lightning_indexer( + query=q, + key=kv_cache[2], + weights=weights, + actual_seq_lengths_query=actual_seq_lengths_query, + actual_seq_lengths_key=actual_seq_lengths_key, + block_table=block_table, + layout_query="TND", + layout_key="PA_BSND", + sparse_count=2048, + sparse_mode=3) + return topk_indices diff --git a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py index 3782bc7d1c..6169328a7e 100644 --- a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py +++ b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py @@ -493,8 +493,11 @@ def register_kv_caches(self, kv_caches: dict[str, Tuple[torch.Tensor]]): assert self.local_agent_metadata is not None kv_cache_dtype = first_kv_cache.dtype self.use_mla: bool = first_kv_cache_tuple[0].size( - -1) != first_kv_cache_tuple[1].size(-1) + -1) != first_kv_cache_tuple[1].size(-1) and len( + first_kv_cache_tuple) == 2 + self.use_sfa: bool = len(first_kv_cache_tuple) == 3 # MLA case. [2 (k_normed, k_pe), num_blocks, ...] + # SFA case. [3 (k_normed, k_pe, k_idx), num_blocks, ...] # MHA case. [2 (k and v), num_blocks, ...] 
self.num_blocks = first_kv_cache.shape[0] block_rank = 3 # [block_size, latent_dim] @@ -540,6 +543,58 @@ def register_kv_caches(self, kv_caches: dict[str, Tuple[torch.Tensor]]): raise RuntimeError( f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]" ) + elif self.use_sfa: + cache_k_normed_addr_list = [] + cache_k_pe_addr_list = [] + cache_k_idx_addr_list = [] + k_normed = None + k_pe = None + k_idx = None + for cache_or_caches in kv_caches.values(): + assert len(cache_or_caches) > 1 + k_normed, k_pe, k_idx = cache_or_caches[0], cache_or_caches[ + 1], cache_or_caches[2] + cache_k_normed_addr_list.append(k_normed.data_ptr()) + cache_k_pe_addr_list.append(k_pe.data_ptr()) + cache_k_idx_addr_list.append(k_idx.data_ptr()) + self.cache_addr = (cache_k_normed_addr_list, cache_k_pe_addr_list, + cache_k_idx_addr_list) + + cache_desc_k_normed = CacheDesc( + len(self.cache_addr[0]), [*k_normed.shape], + TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype]) + cache_desc_k_pe = CacheDesc( + len(self.cache_addr[1]), [*k_pe.shape], + TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype]) + cache_desc_k_idx = CacheDesc( + len(self.cache_addr[2]), [*k_idx.shape], + TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype]) + cache_key_k_normed = BlocksCacheKey(cluster_id=int( + self.local_agent_metadata.cluster_id), + model_id=0) + cache_key_k_pe = BlocksCacheKey(cluster_id=int( + self.local_agent_metadata.cluster_id), + model_id=1) + cache_key_k_idx = BlocksCacheKey(cluster_id=int( + self.local_agent_metadata.cluster_id), + model_id=2) + self.cache_desc = (cache_desc_k_normed, cache_desc_k_pe, + cache_desc_k_idx) + self.cache_key = (cache_key_k_normed, cache_key_k_pe, + cache_key_k_idx) + try: + cache_k_normed = self.cache_manager.register_blocks_cache( + self.cache_desc[0], self.cache_addr[0], self.cache_key[0]) + cache_k_pe = self.cache_manager.register_blocks_cache( + self.cache_desc[1], self.cache_addr[1], self.cache_key[1]) + cache_k_idx = self.cache_manager.register_blocks_cache( + self.cache_desc[2], self.cache_addr[2], self.cache_key[2]) + self.cache = (cache_k_normed, cache_k_pe, cache_k_idx) + logger.info("LLMDataDistWorker: End of register Paged Cache.") + except (TypeError, ValueError): + raise RuntimeError( + f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]" + ) else: for cache_or_caches in kv_caches.values(): for cache in cache_or_caches: @@ -826,6 +881,38 @@ def _read_blocks( raise RuntimeError( "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status" ) + elif self.use_sfa: + remote_cache_key_k_normed = BlocksCacheKey( + cluster_id=remote_cluster_id, model_id=0) + remote_cache_key_k_pe = BlocksCacheKey( + cluster_id=remote_cluster_id, model_id=1) + remote_cache_key_k_idx = BlocksCacheKey( + cluster_id=remote_cluster_id, model_id=2) + logger.info("Try pull blocks from remote server") + try: + self.cache_manager.pull_blocks( + remote_cache_key_k_normed, + self.cache[0], # type: ignore[has-type] + remote_block_ids, + local_block_ids) + self.cache_manager.pull_blocks( + remote_cache_key_k_pe, + self.cache[1], # type: ignore[has-type] + remote_block_ids, + local_block_ids) + self.cache_manager.pull_blocks( + remote_cache_key_k_idx, + self.cache[2], # type: ignore[has-type] + 
remote_block_ids, + local_block_ids) + except (TypeError, ValueError): + raise RuntimeError( + f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe} {remote_cache_key_k_idx}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type] + ) + except LLMException: + raise RuntimeError( + "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status" + ) else: remote_cache_key = BlocksCacheKey(cluster_id=remote_cluster_id) logger.info("Try pull blocks from remote server") diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py index 7faf1be2cf..c0fc1a65d9 100644 --- a/vllm_ascend/distributed/mooncake_connector.py +++ b/vllm_ascend/distributed/mooncake_connector.py @@ -30,6 +30,7 @@ from vllm.v1.request import RequestStatus import vllm_ascend.envs as envs_ascend +from vllm_ascend.ascend_config import get_ascend_config if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata @@ -238,6 +239,7 @@ def __init__(self, tp_rank: int, tp_size: int, engine: TransferEngine, self.block_len = block_len # TODO(jianzs): find a better way to detect MLA. self.use_mla = len(block_len) == 2 + self.use_sfa = len(block_len) == 3 self.request_queue: queue.Queue[Any] = queue.Queue() # TODO(jianzs): make this configurable @@ -349,8 +351,12 @@ def _transfer_kv_cache(self, req_meta: dict[str, Any]): src_list, dst_list, length_list = [], [], [] for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate( zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs)): - block_len = (self.block_len[k % 2] - if self.use_mla else self.block_len[0]) + if self.use_mla: + block_len = (self.block_len[k % 2]) + elif self.use_sfa: + block_len = (self.block_len[k % 3]) + else: + block_len = (self.block_len[0]) for i, remote_block_id in enumerate(grouped_remote_block_ids): local_block_ids = grouped_local_block_ids[i] src = src_layer_base_addr + local_block_ids[0] * block_len @@ -567,6 +573,7 @@ class MooncakeConnectorScheduler: def __init__(self, vllm_config: VllmConfig, engine_id: str): self.vllm_config = vllm_config + self.ascend_config = get_ascend_config() self.block_size = vllm_config.cache_config.block_size self.engine_id = engine_id logger.info("Initializing Mooncake Scheduler %s", engine_id) @@ -726,7 +733,7 @@ def get_finished_count(self) -> Optional[int]: assert "tp_size" in decode_parallel_config.keys() self._decode_tp_size = decode_parallel_config["tp_size"] - if self.vllm_config.model_config.use_mla: + if self.vllm_config.model_config.use_mla or self.ascend_config.use_sfa: return self._decode_tp_size else: # TODO support mha and gqa @@ -847,7 +854,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # TODO(tms): Find a more robust way to detect and handle MLA self.use_mla = first_kv_cache_tuple[0].size( - -1) != first_kv_cache_tuple[1].size(-1) + -1) != first_kv_cache_tuple[1].size(-1) and len( + first_kv_cache_tuple) == 2 + self.use_sfa = len(first_kv_cache_tuple) == 3 if self.use_mla: # MLA case.[num_block, block_size, 1, hidden_dim] self.num_blocks = first_kv_cache.shape[0] @@ -861,6 +870,21 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): logger.info( "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s", self.num_blocks, block_shape_norm, block_shape_pe) + elif 
self.use_sfa: + self.num_blocks = first_kv_cache.shape[0] + block_rank = 3 # [block_size, latent_dim] + block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:] + block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:] + block_shape_k = first_kv_cache_tuple[2].shape[-block_rank:] + self.block_len = [ + first_kv_cache[0].element_size() * math.prod(block_shape_norm), + first_kv_cache[1].element_size() * math.prod(block_shape_pe), + first_kv_cache[2].element_size() * math.prod(block_shape_k) + ] + logger.info( + "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s, block_shape_k: %s", + self.num_blocks, block_shape_norm, block_shape_pe, + block_shape_k) else: # [num_block, block_size, num_head, hidden_dim] self.num_blocks = first_kv_cache.shape[0] @@ -871,8 +895,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, block_shape) - logger.info("Registering KV_Caches. use_mla: %s, shape %s", - self.use_mla, first_kv_cache.shape) + logger.info( + "Registering KV_Caches. use_mla: %s, use_sfa: %s, shape %s", + self.use_mla, self.use_sfa, first_kv_cache.shape) self.kv_caches = kv_caches kv_caches_base_addr = [] @@ -884,9 +909,16 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): region_len = self.num_blocks * self.block_len[i % 2] kv_caches_base_addr.append(base_addr) self._register(base_addr, region_len) + elif self.use_sfa: + for i, cache in enumerate(cache_or_caches, 0): + base_addr = cache.data_ptr() + region_len = self.num_blocks * self.block_len[i % 3] + kv_caches_base_addr.append(base_addr) + self._register(base_addr, region_len) else: - cache_list = [cache_or_caches - ] if self.use_mla else cache_or_caches + cache_list = [ + cache_or_caches + ] if self.use_mla or self.use_sfa else cache_or_caches for cache in cache_list: base_addr = cache.data_ptr() region_len = self.num_blocks * self.block_len[0] diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index dec0a1238a..2db4515436 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -162,6 +162,13 @@ # Whether to enable msMonitor tool to monitor the performance of vllm-ascend. "MSMONITOR_USE_DAEMON": lambda: bool(int(os.getenv("MSMONITOR_USE_DAEMON", '0'))), + # Timeout (in seconds) for delayed KVCache block release. In the prefill + # node, if a request is marked for delayed KV block release and the blocks + # are not freed within this timeout, they will be forcibly released. 
+ "VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT": + lambda: int(os.getenv("VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT", 250)), + "VLLM_ASCEND_ENABLE_MLAPO": + lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MLAPO", '0'))), } # end-env-vars-definition diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index 21ce47bd87..8577abe122 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -37,6 +37,10 @@ def register_model(): "DeepseekV3ForCausalLM", "vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM") + ModelRegistry.register_model( + "DeepseekV32ForCausalLM", + "vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM") + ModelRegistry.register_model( "DeepSeekMTPModel", "vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP") diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index 988de33460..2333c3814c 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -60,6 +60,8 @@ from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.models.layers.mla import AscendMLAModules +from vllm_ascend.models.layers.sfa import (AscendSFAModules, + AscendSparseFlashAttention, Indexer) from vllm_ascend.ops.fused_moe import AscendFusedMoE @@ -244,6 +246,180 @@ def forward( return self.mla_attn(positions, hidden_states, kv_cache, attn_metadata) +class CustomDeepseekV2SFAAttention(DeepseekV2MLAAttention): + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: Optional[int], + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + nn.Module.__init__(self) + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + + self.num_heads = num_heads + self.tp_size = get_tensor_model_parallel_world_size() + assert num_heads % self.tp_size == 0 + self.num_local_heads = num_heads // self.tp_size + self.layers = config.num_hidden_layers + self.first_k_dense_replace = config.first_k_dense_replace + + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.prefix = prefix + self.debug_layer_idx = int(self.prefix.split(".")[-2]) + + ascend_config = get_ascend_config() + self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp + + if self.q_lora_rank is not None: + self.q_a_proj = ReplicatedLinear( + self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj", + return_bias=False, + ) + self.q_a_layernorm = RMSNorm(self.q_lora_rank, + eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear( + q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj", + return_bias=False, + ) + else: + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj", + return_bias=False, + ) + + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + 
self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa", + return_bias=False, + ) + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, + eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj", + return_bias=False, + ) + self.o_proj = CustomDeepseekV2RowParallelLinear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + return_bias=False, + ) + + if rope_scaling: + rope_scaling["rope_type"] = 'deepseek_yarn' + self.rotary_emb = get_rope(qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False) + if rope_scaling: + mscale_all_dim = rope_scaling.get("mscale_all_dim", False) + scaling_factor = rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + + self.dim: int = config.hidden_size # 7168 + # TODO(zzzzwwjj): wait transformers add these params + self.n_heads: int = 64 # 64 + self.head_dim: int = 128 # 128 + self.index_topk: int = 2048 # 2048 + self.indexer = Indexer( + config, + quant_config=quant_config, + dim=self.dim, + n_heads=self.n_heads, + head_dim=self.head_dim, + index_topk=self.index_topk, + prefix=f"{prefix}.indexer", + ) + + sfa_modules = AscendSFAModules( + q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None, + q_a_layernorm=self.q_a_layernorm + if self.q_lora_rank is not None else None, + q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj, + kv_a_proj_with_mqa=self.kv_a_proj_with_mqa, + kv_a_layernorm=self.kv_a_layernorm, + kv_b_proj=self.kv_b_proj, + o_proj=self.o_proj, + rotary_emb=self.rotary_emb, + indexer=self.indexer) + + self.sfa_attn = AscendSparseFlashAttention( + self.hidden_size, + self.enable_shared_expert_dp, + self.debug_layer_idx, + self.first_k_dense_replace, + self.tp_size, + sfa_modules, + self.num_local_heads, + self.scaling, + self.layers, + self.kv_lora_rank, + self.qk_rope_head_dim, + self.q_lora_rank, + self.qk_nope_head_dim, + self.qk_head_dim, + self.v_head_dim, + cache_config, + quant_config, + prefix, + ) + self.prefix = prefix + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + return self.sfa_attn(positions, hidden_states, kv_cache, attn_metadata) + + class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): def __init__(self, vllm_config: VllmConfig, prefix: str) -> None: @@ -253,6 +429,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str) -> None: cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config parallel_config = vllm_config.parallel_config + ascend_config = get_ascend_config() self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) @@ -268,7 +445,10 @@ def __init__(self, vllm_config: VllmConfig, prefix: str) -> None: self.tp_rank = get_tp_group().rank_in_group # TODO: enable mla in vllm-ascend if model_config.use_mla: - attn_cls = CustomDeepseekV2MLAAttention + if ascend_config.use_sfa: + attn_cls = CustomDeepseekV2SFAAttention + else: + attn_cls = CustomDeepseekV2MLAAttention else: attn_cls = DeepseekV2Attention 
self.self_attn = attn_cls( diff --git a/vllm_ascend/models/layers/sfa.py b/vllm_ascend/models/layers/sfa.py new file mode 100644 index 0000000000..f68281cbc0 --- /dev/null +++ b/vllm_ascend/models/layers/sfa.py @@ -0,0 +1,233 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional + +import torch +from torch import nn +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, get_current_vllm_config +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.mla import MultiHeadLatentAttention +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.utils import direct_register_custom_op + + +@dataclass +class AscendSFAModules: + q_a_proj: Optional[torch.nn.Module] + q_a_layernorm: Optional[torch.nn.Module] + q_proj: Optional[torch.nn.Module] + kv_a_proj_with_mqa: torch.nn.Module + kv_a_layernorm: torch.nn.Module + kv_b_proj: torch.nn.Module + o_proj: torch.nn.Module + rotary_emb: torch.nn.Module + indexer: torch.nn.Module + + +class AscendSparseFlashAttention(MultiHeadLatentAttention): + + def __init__( + self, + hidden_size: int, + enable_shared_expert_dp: bool, + debug_layer_idx: int, + first_k_dense_replace: int, + tp_size: int, + sfa_modules: AscendSFAModules, + num_local_heads: int, + scaling: float, + layers: int, + kv_lora_rank: int, + qk_rope_head_dim: int, + q_lora_rank: Optional[int], + qk_nope_head_dim: int, + qk_head_dim: int, + v_head_dim: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + nn.Module.__init__(self) + self.hidden_size = hidden_size + self.enable_shared_expert_dp = enable_shared_expert_dp + self.debug_layer_idx = debug_layer_idx + self.first_k_dense_replace = first_k_dense_replace + self.tp_size = tp_size + self.num_local_heads = num_local_heads + self.layers = layers + self.kv_lora_rank = kv_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.q_lora_rank = q_lora_rank + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_head_dim = qk_head_dim + self.v_head_dim = v_head_dim + self.prefix = prefix + + self.sfa_attn = Attention( + num_heads=self.num_local_heads, + head_size=self.kv_lora_rank + self.qk_rope_head_dim, + scale=scaling, + num_kv_heads=1, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + 
use_mla=True, + use_sfa=True, + # SFA Args + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, + qk_head_dim=self.qk_head_dim, + v_head_dim=self.v_head_dim, + rotary_emb=sfa_modules.rotary_emb, + q_a_proj=sfa_modules.q_a_proj, + q_a_layernorm=sfa_modules.q_a_layernorm, + q_proj=sfa_modules.q_proj, + kv_a_proj_with_mqa=sfa_modules.kv_a_proj_with_mqa, + kv_a_layernorm=sfa_modules.kv_a_layernorm, + kv_b_proj=sfa_modules.kv_b_proj, + o_proj=sfa_modules.o_proj, + indexer=sfa_modules.indexer) + + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + num_tokens = hidden_states.shape[0] + need_gather_q_kv = False + if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers: + # Simulate all gather to calculate output shape + num_tokens = num_tokens * self.tp_size + need_gather_q_kv = True + if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace: + output_shape = hidden_states.shape + else: + rows = num_tokens // self.tp_size + if num_tokens % self.tp_size: + rows += 1 + output_shape = (rows, hidden_states.shape[1]) + # FIXME: This does not seem right, should make sure the buffer is fixed + output = torch.empty(output_shape, + dtype=hidden_states.dtype, + device=hidden_states.device) + torch.ops.vllm.sfa_forward(hidden_states, need_gather_q_kv, output, + self.prefix) + output = output.view(-1, output_shape[-1]) + return output + + +def sfa_forward( + hidden_states: torch.Tensor, + need_gather_q_kv: bool, + output: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + if forward_context.attn_metadata: + attn_metadata = forward_context.attn_metadata[self.sfa_attn.layer_name] + else: + attn_metadata = forward_context.attn_metadata + kv_cache = self.sfa_attn.kv_cache[forward_context.virtual_engine] + self.sfa_attn.impl.forward(hidden_states, kv_cache, attn_metadata, + need_gather_q_kv, output) + return + + +class Indexer(nn.Module): + + def __init__(self, + config, + dim: int = 7168, + n_heads: int = 64, + head_dim: int = 128, + index_topk: int = 2048, + q_lora_rank: int = 1536, + rope_head_dim: int = 64, + quant_config: Optional[QuantizationConfig] = None, + prefix: Optional[str] = ""): + super().__init__() + + self.dim: int = dim # 7168 + self.n_heads: int = n_heads # 64 + self.head_dim: int = head_dim # 128 + self.rope_head_dim: int = rope_head_dim # 64 + self.index_topk: int = index_topk # 2048 + self.q_lora_rank: int = q_lora_rank # 1536 + self.wq_b = ReplicatedLinear( + self.q_lora_rank, + self.n_heads * self.head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.wq_b", + return_bias=False, + ) + self.wk = ReplicatedLinear( + self.dim, + self.head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.wk", + return_bias=False, + ) + self.weights_proj = ReplicatedLinear( + self.dim, + self.n_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.weights_proj", + return_bias=False, 
+ ) + self.k_norm = nn.LayerNorm(self.head_dim) + self.softmax_scale = self.head_dim**-0.5 + + def forward(self): + return + + +def sfa_forward_fake( + hidden_states: torch.Tensor, + need_gather_q_kv: bool, + output: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="sfa_forward", + op_func=sfa_forward, + mutates_args=["output"], + fake_impl=sfa_forward_fake, + dispatch_key="PrivateUse1", +) diff --git a/vllm_ascend/patch/platform/patch_common/__init__.py b/vllm_ascend/patch/platform/patch_common/__init__.py index 77077c964f..89c74e7ecb 100644 --- a/vllm_ascend/patch/platform/patch_common/__init__.py +++ b/vllm_ascend/patch/platform/patch_common/__init__.py @@ -15,6 +15,10 @@ # limitations under the License. # +import vllm_ascend.patch.platform.patch_common.patch_config # noqa import vllm_ascend.patch.platform.patch_common.patch_distributed # noqa import vllm_ascend.patch.platform.patch_common.patch_mamba_config # noqa import vllm_ascend.patch.platform.patch_common.patch_multimodal_merge # noqa +import vllm_ascend.patch.platform.patch_common.patch_transformers_utils # noqa +import vllm_ascend.patch.worker.patch_common.patch_attention_selector # noqa +import vllm_ascend.patch.worker.patch_common.patch_attentionspec # noqa diff --git a/vllm_ascend/patch/platform/patch_common/patch_config.py b/vllm_ascend/patch/platform/patch_common/patch_config.py new file mode 100644 index 0000000000..9b6f5c22f2 --- /dev/null +++ b/vllm_ascend/patch/platform/patch_common/patch_config.py @@ -0,0 +1,313 @@ +import ast + +import vllm.envs as envs +from transformers import PretrainedConfig +from vllm.config import ModelConfig +from vllm.config.speculative import SpeculativeConfig +from vllm.logger import logger + + +# mypy: ignore-errors +@property +def is_deepseek_mla(self: ModelConfig): + if not hasattr(self.hf_text_config, "model_type"): + return False + elif self.hf_text_config.model_type in \ + ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp', + 'kimi_k2', 'longcat_flash', 'deepseek_v32'): + return self.hf_text_config.kv_lora_rank is not None + elif self.hf_text_config.model_type == 'eagle': + # if the model is an EAGLE module, check for the + # underlying architecture + return self.hf_text_config.model.model_type in \ + ('deepseek_v2', 'deepseek_v3', 'deepseek_v32') \ + and self.hf_text_config.kv_lora_rank is not None + return False + + +@staticmethod +def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: + if hf_config.model_type in ("deepseek_v3", "deepseek_v32"): + hf_config.model_type = "deepseek_mtp" + if hf_config.model_type == "deepseek_mtp": + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update({ + "n_predict": n_predict, + "architectures": ["DeepSeekMTPModel"] + }) + + if hf_config.architectures[0] == "MiMoForCausalLM": + hf_config.model_type = "mimo_mtp" + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update({ + "num_hidden_layers": 0, + "n_predict": n_predict, + "architectures": ["MiMoMTPModel"] + }) + + if hf_config.architectures[0] == "Glm4MoeForCausalLM": + hf_config.model_type = "glm4_moe_mtp" + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update({ + "num_hidden_layers": 0, + "n_predict": n_predict, + "architectures": ["Glm4MoeMTPModel"] + }) + + if hf_config.model_type == "ernie4_5_moe": + hf_config.model_type = "ernie_mtp" + if hf_config.model_type == "ernie_mtp": + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + 
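+        # Same pattern as the DeepSeek/MiMo/GLM branches above: expose
+        # n_predict and rewrite the architecture so the draft weights load as
+        # an MTP module.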
hf_config.update({ + "n_predict": n_predict, + "architectures": ["ErnieMTPModel"] + }) + + if hf_config.model_type == "qwen3_next": + hf_config.model_type = "qwen3_next_mtp" + if hf_config.model_type == "qwen3_next_mtp": + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update({ + "n_predict": n_predict, + "architectures": ["Qwen3NextMTP"] + }) + if hf_config.model_type == "longcat_flash": + hf_config.model_type = "longcat_flash_mtp" + n_predict = getattr(hf_config, "num_nextn_predict_layers", 1) + hf_config.update({ + "n_predict": n_predict, + "architectures": ["LongCatFlashMTPModel"] + }) + + return hf_config + + +def __post_init__(self): + + # Note: "method" is a new parameter that helps to extend the + # configuration of non-model-based proposers, and the "model" parameter + # will be used to set the draft model, eagle head, or additional weight + # when needed. If users do not specify "method", the speculative method + # will be detected automatically if possible. If the speculative method + # can not be detected, it will be considered as the "draft_model" by + # default. + + if self.model is None and self.num_speculative_tokens is not None: + # TODO(Shangming): Refactor mtp configuration logic when supporting + if (self.target_model_config + and self.target_model_config.hf_text_config.model_type + in ("deepseek_v3", "deepseek_v32", "mimo", "ernie4_5_moe", + "qwen3_next")): + # use the draft model from the same model: + self.model = self.target_model_config.model + # Align the quantization of draft model for cases such as + # --quantization fp8 with a bf16 checkpoint. + if not self.quantization: + self.quantization = self.target_model_config.quantization + elif self.method in ("ngram", "[ngram]"): + self.model = "ngram" + else: + raise ValueError("num_speculative_tokens was provided but without " + "speculative model.") + + # Automatically configure the method for ngram when "model" is used + # instead of "method" + if self.method is None and (self.model is not None + and self.model in ("ngram", "[ngram]")): + self.method = "ngram" + + if self.method in ("ngram", "[ngram]"): + # Unified to "ngram" internally + self.method = "ngram" + # Set default values if not provided + if (self.prompt_lookup_min is None and self.prompt_lookup_max is None): + # TODO(woosuk): Tune these values. They are arbitrarily chosen. + self.prompt_lookup_min = 5 + self.prompt_lookup_max = 5 + elif self.prompt_lookup_min is None: + assert self.prompt_lookup_max is not None + self.prompt_lookup_min = self.prompt_lookup_max + elif self.prompt_lookup_max is None: + assert self.prompt_lookup_min is not None + self.prompt_lookup_max = self.prompt_lookup_min + + # Validate values + if self.prompt_lookup_min < 1: + raise ValueError( + f"prompt_lookup_min={self.prompt_lookup_min} must be > 0") + if self.prompt_lookup_max < 1: + raise ValueError( + f"prompt_lookup_max={self.prompt_lookup_max} must be > 0") + if self.prompt_lookup_min > self.prompt_lookup_max: + raise ValueError( + f"prompt_lookup_min={self.prompt_lookup_min} must " + f"be <= prompt_lookup_max={self.prompt_lookup_max}") + + # TODO: current we still need extract vocab_size from target model + # config, in future, we may try refactor it out, and set + # draft related config as None here. 
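+        # ngram proposes tokens via prompt lookup and has no separate draft
+        # network, so the target model's configs are reused as-is.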
+ self.draft_model_config = self.target_model_config + self.draft_parallel_config = self.target_parallel_config + else: + self.prompt_lookup_max = 0 + self.prompt_lookup_min = 0 + + if self.model is not None: + # TODO: Move this import to the top once `ModelConfig` + # lives in `vllm.config.model`. + from vllm.config import ModelConfig + self.draft_model_config = ModelConfig( + model=self.model, + runner="draft", + tokenizer=self.target_model_config.tokenizer, + tokenizer_mode=self.target_model_config.tokenizer_mode, + trust_remote_code=self.target_model_config.trust_remote_code, + allowed_local_media_path=self.target_model_config. + allowed_local_media_path, + allowed_media_domains=self.target_model_config. + allowed_media_domains, + dtype=self.target_model_config.dtype, + seed=self.target_model_config.seed, + revision=self.revision, + code_revision=self.code_revision, + tokenizer_revision=self.target_model_config.tokenizer_revision, + spec_target_max_model_len=self.target_model_config. + max_model_len, + quantization=self.quantization, + enforce_eager=self.target_model_config.enforce_eager, + max_logprobs=self.target_model_config.max_logprobs, + hf_overrides=SpeculativeConfig.hf_config_override, + ) + + # Automatically detect the method + if self.method in ('eagle', 'eagle3'): + pass + # examples: + # yuhuili/EAGLE-LLaMA3-Instruct-8B + # yuhuili/EAGLE3-LLaMA3.1-Instruct-8B + # AngelSlim/Qwen3-8B_eagle3 + elif "eagle-" in self.draft_model_config.model.lower(): + self.method = "eagle" + elif "eagle3" in self.draft_model_config.model.lower(): + self.method = "eagle3" + elif self.draft_model_config.hf_config.model_type == "medusa": + self.method = "medusa" + elif (self.draft_model_config.hf_config.model_type == + "mlp_speculator"): + self.method = "mlp_speculator" + elif (self.draft_model_config.hf_config.model_type + in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")): + self.method = "deepseek_mtp" + if self.num_speculative_tokens > 1: + logger.warning( + "All Deepseek MTP models only have " \ + "one layer. Might need some code changes " \ + "to support multiple layers." + ) + elif (self.draft_model_config.hf_config.model_type == "ernie_mtp"): + self.method = "ernie_mtp" + if self.num_speculative_tokens > 1: + logger.warning( + "All Ernie MTP models only have " \ + "one layer. Might need some code changes " \ + "to support multiple layers." + ) + elif (self.draft_model_config.hf_config.model_type == + "qwen3_next_mtp"): + self.method = "qwen3_next_mtp" + if self.num_speculative_tokens > 1: + logger.warning( + "All Qwen3Next MTP models only have " \ + "one layer. Might need some code changes " \ + "to support multiple layers." + ) + elif (self.draft_model_config.hf_config.model_type + in ("longcat_flash_mtp")): + self.method = "longcat_flash_mtp" + if self.num_speculative_tokens > 1: + logger.warning( + "LongCat MTP models only have " \ + "one layer. Might need some code changes " \ + "to support multiple layers." + ) + else: + self.method = "draft_model" + raise NotImplementedError( + "Speculative decoding with draft model is not " + "supported yet. 
Please consider using other " + "speculative decoding methods such as ngram, medusa, " + "eagle, or deepseek_mtp.") + + # Replace hf_config for EAGLE draft_model + if self.method in ("eagle", "eagle3"): + if self.enable_chunked_prefill and not envs.VLLM_USE_V1: + raise ValueError( + "Chunked prefill and EAGLE are not compatible " + "when using V0.") + + from vllm.transformers_utils.configs import SpeculatorsConfig + from vllm.transformers_utils.configs.eagle import EAGLEConfig + + if isinstance(self.draft_model_config.hf_config, + (EAGLEConfig, SpeculatorsConfig)): + pass + else: + eagle_config = EAGLEConfig( + self.draft_model_config.hf_config, + method=self.method, + model_type="eagle") + self.draft_model_config.hf_config = eagle_config + + if (self.num_speculative_tokens is not None + and hasattr(self.draft_model_config.hf_config, + "num_lookahead_tokens")): + self.draft_model_config.hf_config.num_lookahead_tokens = \ + self.num_speculative_tokens + + n_predict = getattr(self.draft_model_config.hf_config, "n_predict", + None) + if n_predict is not None: + if self.num_speculative_tokens is None: + # Default to max value defined in draft model config. + self.num_speculative_tokens = n_predict + elif self.num_speculative_tokens > n_predict and \ + self.num_speculative_tokens % n_predict != 0: + # Ensure divisibility for MTP module reuse. + raise ValueError( + f"num_speculative_tokens:{self.num_speculative_tokens}" + f" must be divisible by {n_predict=}") + + if self.speculative_token_tree is None: + # Generate chain of tokens. + self.speculative_token_tree = str([ + (i + 1) * (0, ) for i in range(self.num_speculative_tokens) + ]) + else: + # Sort the token tree breadth-first. + tree_choices = ast.literal_eval(self.speculative_token_tree) + self.speculative_token_tree = str( + sorted(tree_choices, key=lambda t: (len(t), t))) + + self.draft_tensor_parallel_size = \ + SpeculativeConfig._verify_and_get_draft_tp( + self.target_parallel_config, + self.draft_tensor_parallel_size, + self.draft_model_config.hf_config + ) + + self.draft_model_config.max_model_len = ( + SpeculativeConfig._maybe_override_draft_max_model_len( + self.max_model_len, + self.draft_model_config.max_model_len, + self.target_model_config.max_model_len, + )) + + self.draft_parallel_config = ( + SpeculativeConfig.create_draft_parallel_config( + self.target_parallel_config, + self.draft_tensor_parallel_size)) + + +ModelConfig.is_deepseek_mla = is_deepseek_mla +SpeculativeConfig.__post_init__ = __post_init__ +SpeculativeConfig.hf_config_override = hf_config_override diff --git a/vllm_ascend/patch/platform/patch_common/patch_mamba_config.py b/vllm_ascend/patch/platform/patch_common/patch_mamba_config.py index d9ca8ff312..c90ec8e900 100644 --- a/vllm_ascend/patch/platform/patch_common/patch_mamba_config.py +++ b/vllm_ascend/patch/platform/patch_common/patch_mamba_config.py @@ -6,6 +6,8 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec +from vllm_ascend.ascend_config import get_ascend_config + @classmethod def verify_and_update_config(cls, vllm_config) -> None: @@ -22,6 +24,7 @@ def verify_and_update_config(cls, vllm_config) -> None: logger = init_logger(__name__) # Enable FULL_AND_PIECEWISE by default MambaModelConfig.verify_and_update_config(vllm_config) + ascend_config = get_ascend_config() cache_config = vllm_config.cache_config model_config = vllm_config.model_config @@ -38,7 +41,7 @@ def verify_and_update_config(cls, vllm_config) -> None: 
            num_kv_heads=model_config.get_num_kv_heads(parallel_config),
+            head_size=model_config.get_head_size(),
+            dtype=kv_cache_dtype,
-            use_mla=model_config.use_mla).page_size_bytes
+            use_mla=model_config.use_mla or ascend_config.use_sfa).page_size_bytes

    model_cls, _ = ModelRegistry.resolve_model_cls(
        model_config.architecture,
diff --git a/vllm_ascend/patch/platform/patch_common/patch_transformers_utils.py b/vllm_ascend/patch/platform/patch_common/patch_transformers_utils.py
new file mode 100644
index 0000000000..55db19020e
--- /dev/null
+++ b/vllm_ascend/patch/platform/patch_common/patch_transformers_utils.py
@@ -0,0 +1,200 @@
+import vllm.transformers_utils.configs
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+from vllm.transformers_utils import config
+
+
+logger = logging.get_logger(__name__)
+
+
+class DeepseekV3Config(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate a DeepSeek
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of DeepSeek-V3.
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+    Args:
+        vocab_size (`int`, *optional*, defaults to 129280):
+            Vocabulary size of the DeepSeek model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`DeepseekV3Model`].
+        hidden_size (`int`, *optional*, defaults to 4096):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 11008):
+            Dimension of the MLP representations.
+        moe_intermediate_size (`int`, *optional*, defaults to 1407):
+            Dimension of the MoE representations.
+        num_hidden_layers (`int`, *optional*, defaults to 32):
+            Number of hidden layers in the Transformer decoder.
+        num_nextn_predict_layers (`int`, *optional*, defaults to 1):
+            Number of nextn predict layers in the DeepSeekV3 Model.
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        n_shared_experts (`int`, *optional*, defaults to None):
+            Number of shared experts, None means dense model.
+        n_routed_experts (`int`, *optional*, defaults to None):
+            Number of routed experts, None means dense model.
+        routed_scaling_factor (`float`, *optional*, defaults to 1.0):
+            Scaling factor for routed experts.
+        topk_method (`str`, *optional*, defaults to `greedy`):
+            Top-k method used in the routing gate.
+        n_group (`int`, *optional*, defaults to None):
+            Number of groups for routed experts.
+        topk_group (`int`, *optional*, defaults to None):
+            Number of selected groups for each token (ensuring the selected experts are only within `topk_group` groups).
+        num_experts_per_tok (`int`, *optional*, defaults to None):
+            Number of selected experts, None means dense model.
+        moe_layer_freq (`int`, *optional*, defaults to 1):
+            The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
+        first_k_dense_replace (`int`, *optional*, defaults to 0):
+            Number of dense layers in shallow layers (embed->dense->dense->...->dense->moe->moe...->lm_head).
+                                                      \--k dense layers--/
+        norm_topk_prob (`bool`, *optional*, defaults to False):
+            Whether to normalize the weights of the routed experts.
+        scoring_func (`str`, *optional*, defaults to 'softmax'):
+            Method of computing expert weights.
+        aux_loss_alpha (`float`, *optional*, defaults to 0.001):
+            Auxiliary loss weight coefficient.
+        seq_aux (`bool`, *optional*, defaults to True):
+            Whether to compute the auxiliary loss for each individual sample.
+        num_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details check out [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it will default to
+            `num_attention_heads`.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 2048):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        pad_token_id (`int`, *optional*):
+            Padding token id.
+        bos_token_id (`int`, *optional*, defaults to 1):
+            Beginning of stream token id.
+        eos_token_id (`int`, *optional*, defaults to 2):
+            End of stream token id.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether to tie the word embeddings.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
+            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
+            `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
+            `max_position_embeddings` to the expected new maximum.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+ ```python + >>> from transformers import DeepseekV3Model, DeepseekV3Config + >>> # Initializing a Deepseek-V3 style configuration + >>> configuration = DeepseekV3Config() + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "deepseek_v3" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=129280, + hidden_size=7168, + intermediate_size=18432, + moe_intermediate_size=2048, + num_hidden_layers=61, + num_nextn_predict_layers=1, + num_attention_heads=128, + num_key_value_heads=128, + n_shared_experts=1, + n_routed_experts=256, + ep_size=1, + routed_scaling_factor=2.5, + kv_lora_rank=512, + q_lora_rank=1536, + qk_rope_head_dim=64, + v_head_dim=128, + qk_nope_head_dim=128, + topk_method='noaux_tc', + n_group=8, + topk_group=4, + num_experts_per_tok=8, + moe_layer_freq=1, + first_k_dense_replace=3, + norm_topk_prob=True, + scoring_func='sigmoid', + hidden_act="silu", + max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=0, + eos_token_id=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_nextn_predict_layers = num_nextn_predict_layers + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.ep_size = ep_size + self.routed_scaling_factor = routed_scaling_factor + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.topk_method = topk_method + self.n_group = n_group + self.topk_group = topk_group + self.num_experts_per_tok = num_experts_per_tok + self.moe_layer_freq = moe_layer_freq + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.scoring_func = scoring_func + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +vllm.transformers_utils.configs.__all__.append("DeepseekV3Config") +vllm.transformers_utils.configs.DeepseekV3Config = DeepseekV3Config +config._CONFIG_REGISTRY["deepseek_v32"] = "DeepseekV3Config" diff --git a/vllm_ascend/patch/worker/patch_common/__init__.py b/vllm_ascend/patch/worker/patch_common/__init__.py index baf53212ca..3d233c4798 100644 --- a/vllm_ascend/patch/worker/patch_common/__init__.py +++ b/vllm_ascend/patch/worker/patch_common/__init__.py @@ -20,6 +20,10 @@ if HAS_TRITON: import vllm_ascend.patch.worker.patch_common.patch_triton +# isort: off +import vllm_ascend.patch.worker.patch_common.patch_attention_selector # noqa 
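+# isort is disabled above to keep a fixed import order: the selector override
+# is installed before patch_attention_layer below, which binds the patched
+# get_attn_backend at import time.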
+import vllm_ascend.patch.worker.patch_common.patch_attentionspec # noqa +import vllm_ascend.patch.worker.patch_common.patch_attention_layer # noqa import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa import vllm_ascend.patch.worker.patch_common.patch_logits # noqa import vllm_ascend.patch.worker.patch_common.patch_weight_loader # noqa diff --git a/vllm_ascend/patch/worker/patch_common/patch_attention_layer.py b/vllm_ascend/patch/worker/patch_common/patch_attention_layer.py new file mode 100644 index 0000000000..6f4ad36427 --- /dev/null +++ b/vllm_ascend/patch/worker/patch_common/patch_attention_layer.py @@ -0,0 +1,202 @@ +from typing import List, Optional + +import torch +import vllm +import vllm.envs as envs +from torch import nn +from vllm.attention import Attention, AttentionType, get_attn_backend +from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.selector import backend_name_to_enum +from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target +from vllm.config import CacheConfig, get_current_vllm_config +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.model_executor.layers.linear import UnquantizedLinearMethod +from vllm.model_executor.layers.quantization.base_config import \ + QuantizationConfig +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.platforms import current_platform + +from vllm_ascend.utils import vllm_version_is + + +class AscendAttention(Attention, nn.Module, AttentionLayerBase): + """Attention layer. + + This class takes query, key, and value tensors as input. The input tensors + can either contain prompt tokens or generation tokens. + The class does the following: + + 1. Store the input key and value tensors in the KV cache. + 2. Perform (multi-head/multi-query/grouped-query) attention. + 3. Return the output tensor. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + logits_soft_cap: Optional[float] = None, + per_layer_sliding_window: Optional[int] = None, + use_mla: bool = False, + use_sfa: bool = False, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + attn_backend: Optional[type[AttentionBackend]] = None, + **extra_impl_args, + ) -> None: + """ + The KV cache is stored inside this class and is accessed via + `self.kv_cache`. + """ + nn.Module.__init__(self) + AttentionLayerBase.__init__(self) + + if per_layer_sliding_window is not None: + # per-layer sliding window + sliding_window = per_layer_sliding_window + elif cache_config is not None: + # model-level sliding window + sliding_window = cache_config.sliding_window + else: + sliding_window = None + + if cache_config is not None: + kv_cache_dtype = cache_config.cache_dtype + block_size = cache_config.block_size + is_attention_free = cache_config.is_attention_free + calculate_kv_scales = cache_config.calculate_kv_scales + else: + kv_cache_dtype = "auto" + block_size = 16 + is_attention_free = False + calculate_kv_scales = False + if num_kv_heads is None: + num_kv_heads = num_heads + assert num_heads % num_kv_heads == 0, \ + f"num_heads ({num_heads}) is not " \ + f"divisible by num_kv_heads ({num_kv_heads})" + + # The default k/v_scale is set to 1.0. 
This is ignored + # when kv-cache is not fp8, and should be used with + # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we + # expect the pre-quantized k/v_scale to be loaded along + # with the model weights. + self.kv_cache_dtype = kv_cache_dtype + self.calculate_kv_scales = calculate_kv_scales + self._k_scale = torch.tensor(1.0, dtype=torch.float32) + self._v_scale = torch.tensor(1.0, dtype=torch.float32) + # FlashAttn doesn't support quantizing the kv-cache only + # but requires q to be quantized as well. + self._q_scale = torch.tensor(1.0, dtype=torch.float32) + self._prob_scale = torch.tensor(1.0, dtype=torch.float32) + + # We also keep q/k/v_scale on host (cpu) memory for attention + # backends that require the scales to be on host instead of on device. + # e.g. Flashinfer + self._q_scale_float = 1.0 + self._k_scale_float = 1.0 + self._v_scale_float = 1.0 + + # The output scale on host memory. This should be the input scale of + # the quant op after this attention layer. + self._o_scale_float: Optional[float] = None + + self.use_mla = use_mla + self.num_heads = num_heads + self.head_size = head_size + self.num_kv_heads = num_kv_heads + self.sliding_window = sliding_window + self.has_sink = extra_impl_args.get("sinks") is not None + + quant_method = quant_config.get_quant_method( + self, prefix=prefix) if quant_config else None + if quant_method is not None and not isinstance( + quant_method, UnquantizedLinearMethod): + assert isinstance(quant_method, BaseKVCacheMethod) + # TODO (mgoin): kv cache dtype should be specified in the FP8 + # checkpoint config and become the "auto" behavior + if self.kv_cache_dtype == "fp8_e5m2": + raise ValueError("fp8_e5m2 kv-cache is not supported with " + "fp8 checkpoints.") + # If quantization is enabled, we make "k_scale" and "v_scale" + # parameters so that it can be loaded from the model checkpoint. + # The k/v_scale will then be converted back to native float32 + # values after weight loading. + self.quant_method = quant_method + self.quant_method.create_weights(self) + + # During model initialization, the default dtype is set as the model + # weight and activation dtype. + dtype = torch.get_default_dtype() + if attn_backend is None: + if vllm_version_is("0.10.2"): + self.attn_backend = get_attn_backend(head_size, + dtype, + kv_cache_dtype, + block_size, + is_attention_free, + use_mla=use_mla, + use_sfa=use_sfa, + has_sink=self.has_sink) + else: + self.attn_backend = get_attn_backend(head_size, + dtype, + kv_cache_dtype, + block_size, + use_mla=use_mla, + use_sfa=use_sfa, + has_sink=self.has_sink) + else: + self.attn_backend = attn_backend + + impl_cls = self.attn_backend.get_impl_cls() + self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + logits_soft_cap, attn_type, + kv_sharing_target_layer_name, **extra_impl_args) + self.backend = backend_name_to_enum(self.attn_backend.get_name()) + self.dtype = dtype + + # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how + # torch.compile works by registering the attention as one giant + # opaque custom op. For other platforms, we directly call them + # and let torch.compile handle them. 
+ self.use_direct_call = not current_platform.opaque_attention_op() + + self.use_output = self.attn_backend.accept_output_buffer + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + self.layer_name = prefix + self.attn_type = attn_type + + if kv_sharing_target_layer_name is not None: + validate_kv_sharing_target( + prefix, + kv_sharing_target_layer_name, + compilation_config.static_forward_context, + ) + self.kv_sharing_target_layer_name = kv_sharing_target_layer_name + + # use a placeholder kv cache tensor during init, which will be replaced + # by bind_kv_cache + # this variable will not be accessed if use_direct_call is True + self.kv_cache = [ + torch.tensor([]) for _ in range(get_current_vllm_config( + ).parallel_config.pipeline_parallel_size) + ] + + self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32) + self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32) + self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32) + self.query_quant = None + + +vllm.attention.Attention = AscendAttention diff --git a/vllm_ascend/patch/worker/patch_common/patch_attention_selector.py b/vllm_ascend/patch/worker/patch_common/patch_attention_selector.py new file mode 100644 index 0000000000..793fef1859 --- /dev/null +++ b/vllm_ascend/patch/worker/patch_common/patch_attention_selector.py @@ -0,0 +1,181 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# mypy: ignore-errors +from functools import cache +from typing import Optional + +import torch +import vllm +import vllm.envs as envs +from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.selector import (backend_name_to_enum, + get_global_forced_attn_backend) +from vllm.platforms import _Backend, current_platform +from vllm.utils import resolve_obj_by_qualname + +from vllm_ascend.utils import vllm_version_is + +if vllm_version_is("0.10.2"): + + def get_attn_backend( + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, + is_attention_free: bool = False, + use_mla: bool = False, + use_sfa: bool = False, + has_sink: bool = False, + ) -> type[AttentionBackend]: + """Selects which attention backend to use and lazily imports it.""" + # Accessing envs.* behind an @lru_cache decorator can cause the wrong + # value to be returned from the cache if the value changes between calls. + # To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the + # private function. 
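+        # use_sfa is part of the memoization key of the cached helper below,
+        # so MLA and SFA configurations resolve to different backend classes.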
+ return _cached_get_attn_backend( + head_size=head_size, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + block_size=block_size, + is_attention_free=is_attention_free, + use_v1=envs.VLLM_USE_V1, + use_mla=use_mla, + use_sfa=use_sfa, + has_sink=has_sink, + ) + + @cache + def _cached_get_attn_backend( + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, + is_attention_free: bool, + use_v1: bool = False, + use_mla: bool = False, + use_sfa: bool = False, + has_sink: bool = False, + ) -> type[AttentionBackend]: + # If there are no attention layers (e.g. we are running Mamba), + # use the placeholder NO_ATTENTION + if is_attention_free: + from vllm.attention.backends.placeholder_attn import \ + PlaceholderAttentionBackend + return PlaceholderAttentionBackend + + # Check whether a particular choice of backend was + # previously forced. + # + # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND + # ENVIRONMENT VARIABLE. + selected_backend = None + backend_by_global_setting: Optional[_Backend] = ( + get_global_forced_attn_backend()) + if backend_by_global_setting is not None: + selected_backend = backend_by_global_setting + else: + # Check the environment variable and override if specified + backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND + if backend_by_env_var is not None: + selected_backend = backend_name_to_enum(backend_by_env_var) + if selected_backend is None: + raise ValueError( + f"Invalid attention backend: '{backend_by_env_var}'. " + f"Valid backends are: {list(_Backend.__members__.keys())}" + ) + + # get device-specific attn_backend + attention_cls = current_platform.get_attn_backend_cls( + selected_backend, head_size, dtype, kv_cache_dtype, block_size, + use_v1, use_mla, use_sfa, has_sink) + if not attention_cls: + raise ValueError( + f"Invalid attention backend for {current_platform.device_name}" + ) + return resolve_obj_by_qualname(attention_cls) +else: + + def get_attn_backend( # type: ignore[misc] + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, + use_mla: bool = False, + use_sfa: bool = False, + has_sink: bool = False, + ) -> type[AttentionBackend]: + """Selects which attention backend to use and lazily imports it.""" + # Accessing envs.* behind an @lru_cache decorator can cause the wrong + # value to be returned from the cache if the value changes between calls. + # To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the + # private function. + return _cached_get_attn_backend( + head_size=head_size, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + block_size=block_size, + use_v1=envs.VLLM_USE_V1, + use_mla=use_mla, + use_sfa=use_sfa, + has_sink=has_sink, + ) + + @cache + def _cached_get_attn_backend( + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, + use_v1: bool = False, + use_mla: bool = False, + use_sfa: bool = False, + has_sink: bool = False, + ) -> type[AttentionBackend]: + # Check whether a particular choice of backend was + # previously forced. + # + # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND + # ENVIRONMENT VARIABLE. 
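+        # This mirrors upstream vLLM's selector, with use_sfa added to the
+        # signature and forwarded to the platform hook below.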
+ selected_backend = None + backend_by_global_setting: Optional[_Backend] = ( + get_global_forced_attn_backend()) + if backend_by_global_setting is not None: + selected_backend = backend_by_global_setting + else: + # Check the environment variable and override if specified + backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND + if backend_by_env_var is not None: + selected_backend = backend_name_to_enum(backend_by_env_var) + if selected_backend is None: + raise ValueError( + f"Invalid attention backend: '{backend_by_env_var}'. " + f"Valid backends are: {list(_Backend.__members__.keys())}" + ) + + # get device-specific attn_backend + attention_cls = current_platform.get_attn_backend_cls( + selected_backend, head_size, dtype, kv_cache_dtype, block_size, + use_v1, use_mla, use_sfa, has_sink) + if not attention_cls: + raise ValueError( + f"Invalid attention backend for {current_platform.device_name}" + ) + return resolve_obj_by_qualname(attention_cls) + + +vllm.attention.get_attn_backend = get_attn_backend +vllm.attention.selector._cached_get_attn_backend = _cached_get_attn_backend diff --git a/vllm_ascend/patch/worker/patch_common/patch_attentionspec.py b/vllm_ascend/patch/worker/patch_common/patch_attentionspec.py new file mode 100644 index 0000000000..e1a5ac57d9 --- /dev/null +++ b/vllm_ascend/patch/worker/patch_common/patch_attentionspec.py @@ -0,0 +1,110 @@ +from dataclasses import dataclass, fields +from typing import Optional + +import torch +import vllm +from typing_extensions import Self +from vllm.config import VllmConfig +from vllm.utils import cdiv, get_dtype_size +from vllm.v1.core.single_type_kv_cache_manager import (FullAttentionManager, + spec_manager_map) +from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec + + +@dataclass(frozen=True) +class AttentionSpec(KVCacheSpec): + num_kv_heads: int + head_size: int + dtype: torch.dtype + use_mla: bool + use_sfa: bool + + @property + def page_size_bytes(self) -> int: + # For MLA we only store a single latent vector + coef = 1 if self.use_mla else 2 + sfa_bytes = 128 * self.block_size * get_dtype_size( + self.dtype) if self.use_sfa else 0 + + return coef * self.block_size * self.num_kv_heads * self.head_size \ + * get_dtype_size(self.dtype) + sfa_bytes + + +vllm.v1.kv_cache_interface.AttentionSpec = AttentionSpec + + +@dataclass(frozen=True) +class AscendFullAttentionSpec(FullAttentionSpec, AttentionSpec): + sliding_window: Optional[int] = None + attention_chunk_size: Optional[int] = None + """ + When hybrid allocator is disabled and the model contains both full + attention layers and sliding window attention layers, sliding + window attention are regarded as full attention in KV cache manager + (blocks are allocated for all tokens), while computed as sliding window + attention in model runner. + In this case, we use FullAttentionSpec and record the sliding window size. + Default to None for not using sliding window attention. + """ + + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: + max_model_len = vllm_config.model_config.max_model_len + dcp_world_size = \ + vllm_config.parallel_config.decode_context_parallel_size + # Note(hc): each dcp rank only need save + # (max_model_len//dcp_world_size) tokens locally. 
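+        # e.g. max_model_len=131072 with dcp_world_size=4 budgets
+        # cdiv(131072, 4) = 32768 tokens of KV per rank.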
+ if dcp_world_size > 1: + max_model_len = cdiv(max_model_len, dcp_world_size) + return cdiv(max_model_len, self.block_size) * self.page_size_bytes + + @classmethod + def merge_window_sizes(cls, window_sizes: set[int]) -> Optional[int]: + if len(window_sizes) == 0: + return None + elif len(window_sizes) == 1: + return window_sizes.pop() + else: + raise ValueError( + "All attention layers in the same KV cache group must have the " + "same window size.") + + @classmethod + def merge(cls, specs: list[Self]) -> Self: + """ + Merge a list of FullAttentionSpec objects into a single + FullAttentionSpec object. + """ + assert all(isinstance(spec, FullAttentionSpec) for spec in specs), ( + "All attention layers in the same KV cache group must be " + "FullAttentionSpec.") + + sliding_window = set(spec.sliding_window for spec in specs + if spec.sliding_window is not None) + attention_chunk_size = set(spec.attention_chunk_size for spec in specs + if spec.attention_chunk_size is not None) + merged_spec = cls( + block_size=specs[0].block_size, + num_kv_heads=specs[0].num_kv_heads, + head_size=specs[0].head_size, + dtype=specs[0].dtype, + use_mla=specs[0].use_mla, + use_sfa=specs[0].use_sfa, + sliding_window=cls.merge_window_sizes(sliding_window), + attention_chunk_size=cls.merge_window_sizes(attention_chunk_size), + ) + for spec in specs: + for f in fields(AttentionSpec): + assert getattr(spec, f.name) == getattr(merged_spec, f.name), ( + "All attention layers in the same KV cache group must have " + "the same attention spec.") + assert ( + (merged_spec.sliding_window is not None) + + (merged_spec.attention_chunk_size is not None) <= 1 + ), ("Model with both sliding window layers and chunked local attention " + "layers is not supported.") + return merged_spec + + +spec_manager_map.update({AscendFullAttentionSpec: FullAttentionManager}) + +vllm.v1.kv_cache_interface.FullAttentionSpec = AscendFullAttentionSpec diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 1f12c59eae..c1bf20308f 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -300,6 +300,7 @@ def get_attn_backend_cls(cls, block_size, use_v1, use_mla, + use_sfa, has_sink=False): if not use_v1: raise ValueError("vLLM Ascend does not support V0 engine.") @@ -307,21 +308,28 @@ def get_attn_backend_cls(cls, ascend_config = get_ascend_config() if use_mla and ascend_config.enable_shared_expert_dp: - return "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend" + if use_mla and not use_sfa: + return "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend" + if use_mla and use_sfa: + return "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend" use_torchair = ascend_config.torchair_graph_config.enabled # choose attention backend based on use_mla and use_torchair backend_map = { - (True, True): + (True, False, True): "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend", - (True, False): + (True, False, False): "vllm_ascend.attention.mla_v1.AscendMLABackend", - (False, True): + (False, False, True): "vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend", - (False, False): - "vllm_ascend.attention.attention_v1.AscendAttentionBackend" + (False, False, False): + "vllm_ascend.attention.attention_v1.AscendAttentionBackend", + (True, True, False): + "vllm_ascend.attention.sfa_v1.AscendSFABackend", + (True, True, True): + "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend", } - return backend_map[(use_mla, use_torchair)] + return backend_map[(use_mla, use_sfa, use_torchair)] 
@classmethod def get_punica_wrapper(cls) -> str: diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index ed4e8870cf..d0a0d507fa 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -603,10 +603,11 @@ def _get_torchair_lazy_compiled_model(self, batch_size: int): torch.npu.set_compile_mode(jit_compile=False) if not self.runner.use_cached_npu_graph: npu_backend = torchair.get_npu_backend(compiler_config=config) - self.torchair_compiled_model = torch.compile(self.model, - dynamic=True, - fullgraph=True, - backend=npu_backend) + self.torchair_compiled_model = torch.compile( + self.model, + dynamic=not get_ascend_config().use_sfa, + fullgraph=True, + backend=npu_backend) return self.torchair_compiled_model else: # Generate a new forward proxy code object to prevent the invalidation of @@ -627,7 +628,7 @@ def _get_torchair_lazy_compiled_model(self, batch_size: int): self.torchair_compiled_models[ batch_size] = torchair.inference.cache_compile( self.model.__dict__[forward_proxy_name], - dynamic=True, + dynamic=not get_ascend_config().use_sfa, fullgraph=True, cache_dir=TORCHAIR_CACHE_DIR, config=config, diff --git a/vllm_ascend/torchair/models/torchair_deepseek_v2.py b/vllm_ascend/torchair/models/torchair_deepseek_v2.py index b1d8a8ac0b..8cf6e242bd 100644 --- a/vllm_ascend/torchair/models/torchair_deepseek_v2.py +++ b/vllm_ascend/torchair/models/torchair_deepseek_v2.py @@ -67,7 +67,9 @@ make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) from vllm.sequence import IntermediateTensors +from vllm_ascend import envs from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.models.layers.sfa import Indexer from vllm_ascend.quantization.quant_config import AscendLinearMethod from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \ @@ -435,6 +437,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + decoder_layer=None, ) -> None: nn.Module.__init__(self) self.hidden_size = hidden_size @@ -630,6 +633,225 @@ def forward( output_shape=output_shape) +class TorchairDeepseekV2SFAAttention(DeepseekV2MLAAttention): + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: Optional[int], + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + decoder_layer=None, + ) -> None: + nn.Module.__init__(self) + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + + self.num_heads = num_heads + self.tp_size = get_tensor_model_parallel_world_size() + assert num_heads % self.tp_size == 0 + self.num_local_heads = num_heads // self.tp_size + self.layers = config.num_hidden_layers + self.first_k_dense_replace = config.first_k_dense_replace + + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.prefix = prefix + 
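+        # prefix has the form "...layers.<idx>.self_attn", so the second-to-last
+        # component is this decoder layer's index.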
self.debug_layer_idx = int(self.prefix.split(".")[-2]) + + ascend_config = get_ascend_config() + self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + + if self.q_lora_rank is not None: + self.q_a_proj = ReplicatedLinear( + self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj", + return_bias=False, + ) + self.q_a_layernorm = RMSNorm(self.q_lora_rank, + eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear( + q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj", + return_bias=False, + ) + else: + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj", + return_bias=False, + ) + + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa", + return_bias=False, + ) + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, + eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj", + return_bias=False, + ) + if (config.n_routed_experts is not None + and self.debug_layer_idx >= config.first_k_dense_replace + and self.debug_layer_idx % config.moe_layer_freq == 0 + and (ascend_config.multistream_overlap_shared_expert + or self.enable_shared_expert_dp)): + self.o_proj = TorchairDeepseekV2RowParallelLinearReplaceAllreduce( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + return_bias=False, + ) + else: + self.o_proj = TorchairDeepseekV2RowParallelLinear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + return_bias=False, + ) + + if rope_scaling: + rope_scaling["rope_type"] = 'deepseek_yarn' + self.rotary_emb = get_rope(qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False) + if rope_scaling: + mscale_all_dim = rope_scaling.get("mscale_all_dim", False) + scaling_factor = rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + + self.dim: int = config.hidden_size # 7168 + # TODO(zzzzwwjj): wait transformers add these params + self.n_heads: int = 64 # 64 + self.head_dim: int = 128 # 128 + self.index_topk: int = 2048 # 2048 + self.indexer = Indexer( + config, + quant_config=quant_config, + dim=self.dim, + n_heads=self.n_heads, + head_dim=self.head_dim, + index_topk=self.index_topk, + prefix=f"{prefix}.indexer", + ) + + self.sfa_attn = Attention( + num_heads=self.num_local_heads, + head_size=self.kv_lora_rank + self.qk_rope_head_dim, + scale=self.scaling, + num_kv_heads=1, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_mla=True, + use_sfa=True, + # SFA Args + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, + qk_head_dim=self.qk_head_dim, + v_head_dim=self.v_head_dim, + rotary_emb=self.rotary_emb, + 
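+            # The projection and indexer modules are handed to the SFA backend
+            # implementation, which runs the whole attention path on the raw
+            # hidden states in its forward.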
q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None, + q_a_layernorm=self.q_a_layernorm + if self.q_lora_rank is not None else None, + q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj, + kv_a_proj_with_mqa=self.kv_a_proj_with_mqa, + kv_a_layernorm=self.kv_a_layernorm, + kv_b_proj=self.kv_b_proj, + o_proj=self.o_proj, + indexer=self.indexer, + decoder_layer=decoder_layer, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + forward_context = get_forward_context() + if not self.torchair_graph_enabled: + if forward_context.attn_metadata is not None and isinstance( + forward_context.attn_metadata, dict): + attn_metadata = next( + iter(forward_context.attn_metadata.values()), None) + else: + attn_metadata = forward_context.attn_metadata + if kv_cache is None: + kv_cache = self.sfa_attn.kv_cache[ + forward_context.virtual_engine] + + num_tokens = hidden_states.shape[0] + need_gather_q_kv = False + # if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers: + # # Simulate all gather to calculate output shape + # num_tokens = num_tokens * self.tp_size + # need_gather_q_kv = True + if not self.enable_shared_expert_dp or self.debug_layer_idx != self.first_k_dense_replace: + output_shape = hidden_states.shape + if self.enable_shared_expert_dp and ( + self.debug_layer_idx == self.first_k_dense_replace + or self.debug_layer_idx == self.layers): + rows = num_tokens // self.tp_size + if num_tokens % self.tp_size: + rows += 1 + output_shape = (rows, hidden_states.shape[1]) + output = torch.empty(output_shape, + dtype=hidden_states.dtype, + device=hidden_states.device) + self.sfa_attn.impl.forward(hidden_states, kv_cache, attn_metadata, + need_gather_q_kv, output) + output = output.view(-1, output_shape[-1]) + return output + + class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): def __init__( @@ -654,9 +876,16 @@ def __init__( self.tp_size = get_tensor_model_parallel_world_size() self.tp_rank = get_tp_group().rank_in_group ascend_config = get_ascend_config() + self.use_mla = False + self.use_sfa = False # TODO: enable mla in vllm-ascend if model_config.use_mla: - attn_cls = TorchairDeepseekV2MLAAttention + if ascend_config.use_sfa: + attn_cls = TorchairDeepseekV2SFAAttention + self.use_sfa = True + else: + attn_cls = TorchairDeepseekV2MLAAttention # type: ignore[assignment] + self.use_mla = True else: attn_cls = DeepseekV2Attention self.self_attn = attn_cls( @@ -675,6 +904,7 @@ def __init__( cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", + decoder_layer=self, ) if (config.n_routed_experts is not None @@ -715,21 +945,34 @@ def forward( replace_allreduce: bool = False, ) -> torch.Tensor: # Self Attention - if attn_metadata is not None and attn_metadata.num_decodes > 0: - mla_moe_communication = self.mla_moe_communication and replace_allreduce + if attn_metadata is not None: + decoding_condition_met = ( + not attn_metadata.is_prefill if self.use_sfa else + attn_metadata.num_decodes > 0 if self.use_mla else False) + mla_moe_communication = decoding_condition_met and self.mla_moe_communication and replace_allreduce else: mla_moe_communication = False - if residual is None: + + forward_context = get_forward_context() + if (envs.VLLM_ASCEND_ENABLE_MLAPO + and isinstance(self.self_attn, TorchairDeepseekV2SFAAttention) + and 
attn_metadata is not None + and not forward_context.with_prefill): + if residual is not None: + hidden_states = hidden_states + residual residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) else: - previous_hidden_states, previous_residual = hidden_states, residual - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - # Dispose hidden_states and residual from the previous layer - # to save npu memory because they're no longer used. - dispose_tensor(previous_hidden_states) - dispose_tensor(previous_residual) + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + previous_hidden_states, previous_residual = hidden_states, residual + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + # Dispose hidden_states and residual from the previous layer + # to save npu memory because they're no longer used. + dispose_tensor(previous_hidden_states) + dispose_tensor(previous_residual) if mla_moe_communication and self.layer_idx > self.first_k_dense_replace: hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0) diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py index ebf61dff40..daf6b5d1e5 100644 --- a/vllm_ascend/torchair/torchair_model_runner.py +++ b/vllm_ascend/torchair/torchair_model_runner.py @@ -48,8 +48,8 @@ class NPUTorchairModelRunner(NPUModelRunner): def __init__(self, vllm_config: VllmConfig, device: torch.device): - ascend_config = get_ascend_config() - self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp + self.ascend_config = get_ascend_config() + self.enable_shared_expert_dp = self.ascend_config.enable_shared_expert_dp super().__init__(vllm_config, device) if self.speculative_config: self.actual_seq_lengths_q = list( @@ -66,10 +66,10 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.new_kv_cache_bytes = -1 self.torchair_compiled_model = None # type: ignore self.torchair_compiled_models = {} # type: ignore - self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph - self.use_cached_kv_cache_bytes = ascend_config.torchair_graph_config.use_cached_kv_cache_bytes - self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes - if ascend_config.torchair_graph_config.graph_batch_sizes_init: + self.use_cached_npu_graph = self.ascend_config.torchair_graph_config.use_cached_graph + self.use_cached_kv_cache_bytes = self.ascend_config.torchair_graph_config.use_cached_kv_cache_bytes + self.torchair_graph_batch_sizes = self.ascend_config.torchair_graph_config.graph_batch_sizes + if self.ascend_config.torchair_graph_config.graph_batch_sizes_init: self.init_torchair_graph_batch_sizes() self.update_torchair_graph_batch_sizes() @@ -362,22 +362,23 @@ def _get_torchair_lazy_compiled_model(self, batch_size: int): communication_adaptation_310p() config = torchair.CompilerConfig() - if get_ascend_config().torchair_graph_config.mode: - config.mode = get_ascend_config().torchair_graph_config.mode + if self.ascend_config.torchair_graph_config.mode: + config.mode = self.ascend_config.torchair_graph_config.mode config.experimental_config.frozen_parameter = \ - get_ascend_config().torchair_graph_config.enable_frozen_parameter + self.ascend_config.torchair_graph_config.enable_frozen_parameter # enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to # disable it on 300I Duo platform now. 
config.experimental_config.tiling_schedule_optimize = not is_310p() config.experimental_config.enable_view_optimize = \ - get_ascend_config().torchair_graph_config.enable_view_optimize + self.ascend_config.torchair_graph_config.enable_view_optimize torch.npu.set_compile_mode(jit_compile=False) if not self.use_cached_npu_graph: npu_backend = torchair.get_npu_backend(compiler_config=config) - self.torchair_compiled_model = torch.compile(self.model, - dynamic=True, - fullgraph=True, - backend=npu_backend) + self.torchair_compiled_model = torch.compile( + self.model, + dynamic=not self.ascend_config.use_sfa, + fullgraph=True, + backend=npu_backend) return self.torchair_compiled_model else: # Generate a new forward proxy code object to prevent the invalidation of @@ -398,7 +399,7 @@ def _get_torchair_lazy_compiled_model(self, batch_size: int): self.torchair_compiled_models[ batch_size] = torchair.inference.cache_compile( self.model.__dict__[forward_proxy_name], - dynamic=True, + dynamic=not self.ascend_config.use_sfa, fullgraph=True, cache_dir=TORCHAIR_CACHE_DIR, config=config, diff --git a/vllm_ascend/torchair/torchair_sfa.py b/vllm_ascend/torchair/torchair_sfa.py new file mode 100644 index 0000000000..8dc6a68862 --- /dev/null +++ b/vllm_ascend/torchair/torchair_sfa.py @@ -0,0 +1,1330 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Type, TypeVar + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch_npu +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata, + MLAAttentionImpl) +from vllm.attention.backends.utils import PAD_SLOT_ID +from vllm.config import VllmConfig, get_current_vllm_config +from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) +from vllm.utils import cdiv, round_down + +import vllm_ascend.envs as envs_ascend +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, + split_decodes_and_prefills) +from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig +from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn +from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata +from vllm_ascend.worker.npu_input_batch import InputBatch + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + + +class AscendSFATorchairBackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_name() -> str: + return "ASCEND_SFA_TORCHAIR" + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + return AscendSFATorchairMetadata + + @staticmethod + def get_builder_cls(): + return AscendSFATorchairMetadataBuilder + + #NOTE: is that ok? 
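A rough sketch (not part of the patch) of how the runner later consumes this cache shape for SFA: head_size here appears to be kv_lora_rank + qk_rope_head_dim (576 for DeepSeek-V3.2-style configs, matching the "576 SFA head dim" note further down in this file), and initialize_kv_cache_tensors_deepseek_sfa splits it into a nope cache, a rope cache, and a separate 128-wide key cache for the lightning indexer. The 512/64/128 dims below are assumed illustrative values, not read from any config object.

# Illustrative only; mirrors the shape split done in the model runner.
def split_sfa_cache_shape(num_blocks: int, block_size: int,
                          kv_lora_rank: int = 512, rope_dim: int = 64,
                          index_head_dim: int = 128):
    head_size = kv_lora_rank + rope_dim                     # 576 stored per cached token
    nope = (num_blocks, block_size, 1, kv_lora_rank)        # compressed latent ("ctkv")
    rope = (num_blocks, block_size, 1, rope_dim)            # decoupled rope keys
    index_k = (num_blocks, block_size, 1, index_head_dim)   # lightning-indexer keys
    return head_size, nope, rope, index_k

print(split_sfa_cache_shape(4, 128))
# (576, (4, 128, 1, 512), (4, 128, 1, 64), (4, 128, 1, 128))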
+ @staticmethod + def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, + head_size: int) -> tuple[int, ...]: + return (num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def get_impl_cls() -> Type["MLAAttentionImpl"]: + return AscendSFATorchairImpl + + +@dataclass +class AscendSFATorchairPrefillMetadata: + """ Prefill Specific Metadata for Ascend""" + + @dataclass + class TorchairChunkedContextMetadata: + # New for SFA (compared to FlashAttention) + # For handling chunked prefill + cu_seq_lens: torch.Tensor + starts: torch.Tensor + seq_tot: list[int] + max_seq_lens: list[int] + workspace: torch.Tensor + chunk_seq_lens: torch.Tensor + + attn_mask: torch.Tensor + query_lens: list[int] # Check!! + seq_lens: list[int] # Check!! + context_lens: torch.Tensor + input_positions: torch.Tensor + query_start_loc: torch.Tensor + block_table: torch.Tensor + max_query_len: int + max_seq_lens: int + sin: torch.Tensor + cos: torch.Tensor + chunked_context: Optional[TorchairChunkedContextMetadata] = None + + +@dataclass +class AscendSFATorchairDecodeMetadata: + # Input positions for rotrary embeddings since for SFA the rotary + # position embeddings are applied inside the attention backend + input_positions: torch.Tensor + block_table: torch.Tensor + seq_lens: torch.Tensor + max_seq_lens: int + seq_lens_list: list[int] + actual_seq_lengths_q: torch.Tensor + sin: torch.Tensor + cos: torch.Tensor + attn_mask: Optional[torch.Tensor] = None + + +@dataclass +class AscendSFATorchairMetadata: + """Metadata for SFACommon. + + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + num_actual_tokens: int # Number of tokens excluding padding. + slot_mapping: torch.Tensor + query_start_loc: torch.Tensor + seq_lens: torch.Tensor + block_tables: torch.Tensor + + # New for SFA (compared to FlashAttention) + # For handling prefill decode split + num_decodes: int + num_decode_tokens: int + num_prefills: int + + # For logging. + num_input_tokens: int = 0 # Number of tokens including padding. 
+ + query_lens: Optional[list[int]] = None + # The dimension of the attention heads + head_dim: Optional[int] = None + attn_mask: torch.Tensor = None + # chunked prefill by default if no attn_states passed + attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill + + decode: Optional[AscendSFATorchairDecodeMetadata] = None + prefill: Optional[AscendSFATorchairPrefillMetadata] = None + enable_dbo_across_dp: bool = False + is_prefill: bool = False + is_decode: bool = False + + def __post_init__(self): + pass + # supported_head_sizes = AscendSFABackend.get_supported_head_sizes() + # if self.head_dim is not None and self.head_dim \ + # not in supported_head_sizes: + # raise ValueError( + # f"Only {supported_head_sizes} are supported for head_dim,", + # f"received {self.head_dim}.") + + def split_metadata_for_multistream( + self, + ms_split_config: MSAttentionMetadataSplitConfig, + ) -> list["AscendSFATorchairMetadata"]: + """Split metadata for multi-stream with AscendSFATorchairMetadata""" + return model_input_split_v1_mla_attn( + ms_split_config=ms_split_config, + attn_metadata=self, + _metadata_cls=AscendSFATorchairMetadata, + ) + + +M = TypeVar("M", bound=AscendSFATorchairMetadata) + + +class AscendSFATorchairMetadataBuilder: + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + # _attn_mask_builder = None + def __init__(self, + kv_cache_spec, + layer_names, + vllm_config: VllmConfig, + device: torch.device, + metadata_cls: Optional[AscendSFATorchairMetadata] = None): + self.metadata_cls: Optional[AscendSFATorchairMetadata] = metadata_cls \ + if metadata_cls is not None else AscendSFATorchairMetadata # type: ignore + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.device = device + scheduler_config = vllm_config.scheduler_config + self.block_size = vllm_config.cache_config.block_size + self.max_blocks = (vllm_config.model_config.max_model_len + + self.block_size - 1) // self.block_size + self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled + if self.chunked_prefill_enabled: + self.chunked_prefill_workspace_size = min( + # Max sure there is enough for 8 full length request or at least + # 4 pages of cache per request + max(8 * self.model_config.max_model_len, + 4 * scheduler_config.max_num_seqs * self.block_size), + # For long-context models try not to over-allocate limiting + # kv-cache space, limiting it to 64k tokens, + # which would result in the workspace being: + # 2*(576)*(64*1024) = 144mb + # (assuming 576 SFA head dim, and fp16) + # which would result in up-projected context being + # 2*(192*128)*(64*1024) = 3gb + # (assuming 192 QK head dim, 128 heads, and fp16) + 128 * 1024) + assert self.chunked_prefill_workspace_size >= \ + scheduler_config.max_num_seqs * self.block_size + self.chunked_prefill_workspace = torch.empty( + (self.chunked_prefill_workspace_size, + self.model_config.get_head_size()), + dtype=self.model_config.dtype, + device=device, + ) + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim + self.cos_cache = None + self.sin_cache = None + + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput") -> bool: + # We now want to reorder the batch so that the "decode" requests are at + # the front and the "prefill" requests are at the using the least amount + # swaps possible. 
(NOTE for now we loosely use "decode" to mean requests + # where attention is likely memory-bound and "prefill" to mean requests + # where attention is likely compute-bound, TODO(lucas): figure out a + # better naming here) + decodes = [] + prefills = [] + + for i, req_id in enumerate(input_batch.req_ids): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_spec_tokens = len( + scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) + # For torch air graph mode we treat spec decoding as decode. + if self.torchair_graph_enabled: + if num_tokens - num_spec_tokens == 1: + decodes.append(i) + else: + prefills.append(i) + # For eager mode we treat spec decoding as chunked prefill. + else: + if num_tokens == 1: + decodes.append(i) + else: + prefills.append(i) + + # We hope that this is fairly minimal since decodes + # should be around for a number of iterations so hopefully they are + # relatively stationary (and new request are generally appended to the + # persistent batch so already should be at the back) + # To achieve this we loop over the decodes in descending order and + # the prefills in ascending order. We swap decodes from the "back" + # i.e. past where the last decode should be in the reodorered with + # prefills from the front of the batch. + # `decodes` and `prefills` are already in ascending order just based on + # the above loop + num_decodes = len(decodes) + num_prefills = len(prefills) + first_prefill = 0 + modified_batch = False + + for i in range(1, min(num_decodes, num_prefills) + 1): + # If the decode is at the "back" of the batch, i, we can swap it + # with the prefill closest to the front of the batch + if decodes[num_decodes - i] >= num_decodes: + input_batch.swap_states(prefills[first_prefill], + decodes[num_decodes - i]) + first_prefill += 1 + modified_batch = True + else: + break + + # Save for next `build` call + # TODO(lucas): this is a bit of a hack, we should probably have a + # better way of doing this + return modified_batch + + def _get_graph_runner_block_tables( + self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor: + max_blocks = self.max_blocks + + graph_block_tables = torch.zeros((num_seqs, max_blocks), + dtype=block_tables.dtype, + device=block_tables.device) + + num_blocks = block_tables.size(1) + if num_blocks <= max_blocks: + graph_block_tables[:num_seqs, : + num_blocks] = block_tables[:num_seqs, : + num_blocks] + else: + graph_block_tables[:num_seqs, : + max_blocks] = block_tables[:num_seqs, : + max_blocks] + + return graph_block_tables[:, :max_blocks] + + def build_torchair_graph_dummy( + self, + common_attn_metadata: TorchairCommonAttentionMetadata, + ) -> AscendSFATorchairMetadata: + device = self.device + num_reqs = common_attn_metadata.num_reqs + block_table = torch.zeros((num_reqs, self.max_blocks), + dtype=torch.int32, + device=device) + block_table = self._get_graph_runner_block_tables( + num_reqs, block_table) + num_tokens = num_reqs * common_attn_metadata.decode_token_per_req + seq_lens = torch.zeros(num_reqs, dtype=torch.int32, device=device) + seq_lens_list = [0] * num_reqs + input_positions = torch.zeros(num_tokens, + dtype=torch.int32, + device=device).long() + slot_mapping = torch.full((num_tokens, ), + PAD_SLOT_ID, + dtype=torch.int32, + device=device) + query_start_loc = torch.full((num_reqs, ), + -1, + dtype=torch.int32, + device=device) + sin = torch.ones(num_tokens, + 1, + 1, + self.rope_dim, + dtype=self.model_config.dtype, + device=device) + cos = torch.ones(num_tokens, + 1, + 1, + self.rope_dim, 
+ dtype=self.model_config.dtype, + device=device) + + if self.vllm_config.speculative_config is not None and\ + self.vllm_config.speculative_config.method == 'deepseek_mtp': + attn_state = AscendAttentionState.SpecDecoding + num_decode_tokens = 2 + else: + attn_state = AscendAttentionState.DecodeOnly + num_decode_tokens = 1 + # cumsum here. + # actual_seq_lengths_q = torch.Tensor(common_attn_metadata.actual_seq_lengths_q[:num_tokens]).to(torch.int32).npu() + # actual_seq_lengths_q = torch.cumsum(actual_seq_lengths_q, dim=0).to(torch.int32).npu() + actual_seq_lengths_q = torch.arange(1, num_reqs + 1).to( + torch.int32).npu( + ) * common_attn_metadata.decode_token_per_req ############## + decode_metadata = AscendSFATorchairDecodeMetadata( + input_positions=input_positions, + block_table=block_table, + seq_lens=seq_lens, + seq_lens_list=seq_lens_list, + max_seq_lens=1, + attn_mask=common_attn_metadata.spec_attn_mask, + # actual_seq_lengths_q=torch.Tensor(common_attn_metadata.actual_seq_lengths_q[:num_reqs]).to(torch.int32).npu(), + actual_seq_lengths_q=actual_seq_lengths_q, + # actual_seq_lengths_q=torch.Tensor([1]).to(torch.int32).npu(), + sin=sin, + cos=cos, + ) + return self.metadata_cls( # type: ignore + num_input_tokens=common_attn_metadata.num_actual_tokens, + num_actual_tokens=common_attn_metadata.num_actual_tokens, + slot_mapping=slot_mapping, + head_dim=self.model_config.get_head_size(), + num_decodes=num_tokens, + num_decode_tokens=num_decode_tokens, + num_prefills=0, + attn_mask=common_attn_metadata.attn_mask, + attn_state=attn_state, + prefill=None, + decode=decode_metadata, + query_start_loc=query_start_loc, + seq_lens=seq_lens, + block_tables=block_table, + is_prefill=False, + is_decode=True) + + def build( + self, + common_prefix_len: int, + common_attn_metadata: AscendCommonAttentionMetadata, + model: nn.Module, + ) -> AscendSFATorchairMetadata: + num_reqs = common_attn_metadata.num_reqs + num_actual_tokens = common_attn_metadata.num_actual_tokens + query_start_loc = common_attn_metadata.query_start_loc + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + if self.torchair_graph_enabled and common_attn_metadata.attn_state in [ + AscendAttentionState.DecodeOnly, + AscendAttentionState.SpecDecoding + ]: + decode_threshold = common_attn_metadata.decode_token_per_req + else: + # TODO(xyx): remove the if condition after mla supports torch mode speculative decoding + decode_threshold = 1 + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ + split_decodes_and_prefills(common_attn_metadata, decode_threshold=decode_threshold) + assert num_decodes + num_prefills == num_reqs + assert num_decode_tokens + num_prefill_tokens == num_actual_tokens + + # Note(simon): be careful about the CPU <> GPU memory movement in this + # function. We should avoid GPU -> CPU sync as much as possible because + # it blocks on all previous kernels. 
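A small sketch (not part of the patch) of what the dummy actual_seq_lengths_q built above encodes: cumulative end offsets of each request's query tokens in the TND layout expected by the fused attention ops. The values below assume 4 requests and decode_token_per_req=2 (the MTP case); the patch additionally moves the tensor to NPU.

import torch

num_reqs, decode_token_per_req = 4, 2   # illustrative values
actual_seq_lengths_q = torch.arange(
    1, num_reqs + 1, dtype=torch.int32) * decode_token_per_req
print(actual_seq_lengths_q.tolist())    # [2, 4, 6, 8]: request i owns query tokens [2*i, 2*(i+1))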
+ device = self.device + + block_table = (common_attn_metadata.block_table_tensor[:num_reqs]) + slot_mapping = common_attn_metadata.slot_mapping[: + num_actual_tokens].to( + device, + non_blocking=True) + input_positions = common_attn_metadata.positions[: + num_actual_tokens].long( + ) + + if self.cos_cache is None: + self.cos_cache = model.model.layers[ + 0].self_attn.rotary_emb.cos_cached + self.sin_cache = model.model.layers[ + 0].self_attn.rotary_emb.sin_cached + if self.cos_cache.dtype != self.model_config.dtype: # type: ignore + self.cos_cache = self.cos_cache.to( # type: ignore + self.model_config.dtype) # type: ignore + self.sin_cache = self.sin_cache.to( # type: ignore + self.model_config.dtype) # type: ignore + + # check CPU operation here + query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + query_lens = query_seq_lens_cpu[:num_reqs] + seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] + num_computed_tokens_cpu = (seq_lens - query_lens) + + prefill_metadata = None + chunked_context_metadata = None + is_prefill = False + is_decode = False + if num_prefills > 0: + reqs_start = num_decodes # prefill_start + tokens_start = num_decode_tokens + max_query_len = query_lens[tokens_start:].max().item() + max_seq_lens = seq_lens[tokens_start:].max().item() + prefill_query_start_loc = query_start_loc[ + reqs_start:] - query_start_loc[reqs_start] + + context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] + max_context_len_cpu = context_lens_cpu.max().item() + num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() + if self.chunked_prefill_enabled and max_context_len_cpu > 0: + max_context_chunk = (self.chunked_prefill_workspace_size // + num_prefills_with_context_cpu) + max_context_chunk = round_down(max_context_chunk, + self.block_size) + + assert max_context_chunk > 0 + num_chunks = cdiv(max_context_len_cpu, max_context_chunk) + chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \ + .unsqueeze(1).expand(-1, num_prefills) * max_context_chunk + chunk_ends = torch.min(context_lens_cpu.unsqueeze(0), + chunk_starts + max_context_chunk) + chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) + cu_seq_lens_cpu = torch.zeros(num_chunks, + num_prefills + 1, + dtype=torch.int32, + pin_memory=True) + torch.cumsum(chunk_seq_lens, + dim=1, + out=cu_seq_lens_cpu[:, 1:], + dtype=torch.int32) + chunked_context_metadata = \ + AscendSFATorchairPrefillMetadata.TorchairChunkedContextMetadata( + cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), + starts=chunk_starts.to(device, non_blocking=True), + seq_tot=chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + chunk_seq_lens=chunk_seq_lens, + workspace=self.chunked_prefill_workspace, + ) + prefill_input_positions = input_positions[tokens_start:] + cos = self.cos_cache[ + prefill_input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + sin = self.sin_cache[ + prefill_input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + actual_query_lens = torch.tensor( + query_lens[tokens_start:], + dtype=torch.int32).npu() # int64->int32 + query_lens_prefill_sfa = torch.cumsum(actual_query_lens, + dim=0).to(torch.int32).npu() + seq_lens_prefill_sfa = torch.tensor(seq_lens, + dtype=torch.int32).npu() + prefill_metadata = AscendSFATorchairPrefillMetadata( + attn_mask=common_attn_metadata.attn_mask, + query_lens=query_lens_prefill_sfa, + seq_lens=seq_lens_prefill_sfa, + context_lens=seq_lens[tokens_start:], + input_positions=prefill_input_positions, + 
block_table=block_table[reqs_start:, ...], + max_query_len=max_query_len, + max_seq_lens=max_seq_lens, + query_start_loc=prefill_query_start_loc, + chunked_context=chunked_context_metadata, + sin=sin, + cos=cos, + ) + is_prefill = True + + decode_metadata = None + graph_pad_size = common_attn_metadata.graph_pad_size + use_torchair_graph = graph_pad_size != -1 + if num_decodes > 0: + # Check here!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario + actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].to( + torch.int32).npu() + max_seq_lens = seq_lens[:num_decodes].max().item() + seq_lens = seq_lens[:num_decodes].to(torch.int32).npu() + # input_positions = input_positions[:num_decode_tokens] + block_table = block_table[:num_decodes, ...] + num_token_pad_size = 0 + if use_torchair_graph and common_attn_metadata.attn_state in [ + AscendAttentionState.DecodeOnly, + AscendAttentionState.SpecDecoding + ]: + num_reqs_pad_size = 0 + if graph_pad_size != 0: + pad_value = 0 + num_token_pad_size = graph_pad_size - num_decode_tokens + num_reqs_pad_size = ( + graph_pad_size // + common_attn_metadata.decode_token_per_req - num_reqs) + padded_seq_lens = seq_lens.tolist( + ) + [pad_value] * num_reqs_pad_size + else: + padded_seq_lens = seq_lens.tolist() + + seq_lens = torch.from_numpy( + np.array(padded_seq_lens).astype(np.int32)).npu() + seq_lens_list = padded_seq_lens + slot_padding = torch.full((num_token_pad_size, ), + PAD_SLOT_ID, + dtype=slot_mapping.dtype, + device=slot_mapping.device) + slot_mapping = torch.cat([slot_mapping, slot_padding]) + block_table_padding = torch.zeros( + (num_reqs_pad_size, ) + block_table.shape[1:], + dtype=block_table.dtype, + device=block_table.device) + block_table = torch.cat([block_table, block_table_padding], + dim=0) + block_table = self._get_graph_runner_block_tables( + num_reqs + num_reqs_pad_size, block_table) + position_padding = torch.zeros(num_token_pad_size, + dtype=input_positions.dtype, + device=input_positions.device) + input_positions = torch.cat( + [input_positions, position_padding]) + + # actual_seq_lengths_q = torch.cumsum(actual_seq_lengths_q, dim=0).npu() + # actual_seq_lengths_q=torch.Tensor([1]).to(torch.int32).npu() + actual_seq_lengths_q = torch.arange(1, num_reqs + 1).to( + torch.int32).npu( + ) * common_attn_metadata.decode_token_per_req + # MTP ignored + # actual_seq_lengths_q = self.pad_actual_seq_len_q( + # num_reqs_pad_size, num_reqs, actual_seq_lengths_q, + # common_attn_metadata) + else: + seq_lens_list = seq_lens.tolist() + # mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to batch_size(num_tokens) + batch_size = num_decode_tokens + num_token_pad_size + if actual_seq_lengths_q[-1] != batch_size \ + and common_attn_metadata.attn_state == AscendAttentionState.SpecDecoding: + actual_seq_lengths_q[-1] = batch_size + + cos = self.cos_cache[input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + sin = self.sin_cache[input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + padded_token_num = input_positions.shape[0] + actual_seq_lengths_q = torch.arange( + 1, + (padded_token_num // common_attn_metadata.decode_token_per_req) + + 1).to(torch.int32).npu( + ) * common_attn_metadata.decode_token_per_req + decode_metadata = AscendSFATorchairDecodeMetadata( + input_positions=input_positions, + block_table=block_table, + seq_lens=seq_lens, + seq_lens_list=seq_lens_list, + max_seq_lens=max_seq_lens, + 
                attn_mask=common_attn_metadata.spec_attn_mask,
+                actual_seq_lengths_q=actual_seq_lengths_q,
+                sin=sin,
+                cos=cos)
+            is_decode = True
+
+        return self.metadata_cls(  # type: ignore
+            num_actual_tokens=num_actual_tokens,
+            query_lens=query_lens.tolist(),
+            slot_mapping=slot_mapping,
+            head_dim=self.model_config.get_head_size(),
+            num_decodes=num_decodes,
+            num_decode_tokens=num_decode_tokens,
+            num_prefills=num_prefills,
+            attn_mask=common_attn_metadata.attn_mask,
+            attn_state=common_attn_metadata.attn_state,
+            prefill=prefill_metadata,
+            decode=decode_metadata,
+            query_start_loc=query_start_loc,
+            block_tables=block_table,
+            seq_lens=seq_lens,
+            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
+            is_prefill=is_prefill,
+            is_decode=is_decode)
+
+    def pad_actual_seq_len_q(self, num_reqs_pad_size, num_reqs,
+                             actual_seq_lengths_q, common_attn_metadata):
+        """
+        Pads actual_seq_lengths_q evenly, so that no request exceeds 16 tokens,
+        in order to meet the requirements of npu_fused_infer_attention_score.
+
+        In the torchair scenario the query lengths must be padded to the same length,
+        and npu_fused_infer_attention_score requires the last element to equal
+        batch_size (num_tokens).
+
+        For example:
+        batch_size=36, num_reqs_pad_size=2, num_reqs=16
+        By default each request decodes 2 tokens, so actual_seq_lengths_q should be
+        [2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36].
+
+        However, in the MTP torchair + PD scenario actual_seq_lengths_q may be
+        [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] before padding, since the first
+        decode request only has 1 token. To meet the requirement of
+        npu_fused_infer_attention_score, we pad actual_seq_lengths_q evenly so that
+        no request exceeds 16 tokens; after padding it should be similar to
+        [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,32,36].
+        """
+        FIA_SEQ_LEN_LIMIT = 16
+        need_padding = num_reqs_pad_size != 0 and \
+            len(common_attn_metadata.actual_seq_lengths_q) > num_reqs and \
+            common_attn_metadata.actual_seq_lengths_q[num_reqs] - actual_seq_lengths_q[-1] > FIA_SEQ_LEN_LIMIT
+        if need_padding:
+            padding_seq_len_q = common_attn_metadata.actual_seq_lengths_q[
+                num_reqs:num_reqs + num_reqs_pad_size]
+            start_val = actual_seq_lengths_q[-1]
+            end_val = padding_seq_len_q[-1]
+
+            num_step = len(padding_seq_len_q)
+            interpolated = np.round(
+                np.linspace(start_val, end_val,
+                            num_step + 1)[1:]).astype(int).tolist()
+            assert interpolated[-1] == end_val
+            assert len(interpolated) == len(padding_seq_len_q)
+            actual_seq_lengths_q = actual_seq_lengths_q + interpolated
+        else:
+            actual_seq_lengths_q = actual_seq_lengths_q + common_attn_metadata.actual_seq_lengths_q[
+                num_reqs:num_reqs + num_reqs_pad_size]
+
+        # return actual_seq_lengths_q
+        return torch.Tensor(actual_seq_lengths_q).to(torch.int32).npu()
+
+
+class PrefillSFAPreprocessResult(NamedTuple):
+    q_nope: Optional[torch.Tensor] = None
+    q_pe: Optional[torch.Tensor] = None
+    k_nope: Optional[torch.Tensor] = None
+    k_pe: Optional[torch.Tensor] = None
+    topk_indices: Optional[torch.Tensor] = None
+    query_states: Optional[torch.Tensor] = None
+    key_states: Optional[torch.Tensor] = None
+
+
+class DecodeSFAPreprocessResult(NamedTuple):
+    q_nope: Optional[torch.Tensor] = None
+    q_pe: Optional[torch.Tensor] = None
+    # nope_cache: Optional[torch.Tensor] = None
+    # rope_cache: Optional[torch.Tensor] = None
+    topk_indices: Optional[torch.Tensor] = None
+    query_states: Optional[torch.Tensor] = None
+    key_states: Optional[torch.Tensor] =
None + bsz: Optional[int] = None + + +class AscendSFATorchairImpl(MLAAttentionImpl): + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + **kwargs, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + + # MLA Args + self.q_lora_rank = kwargs['q_lora_rank'] + self.kv_lora_rank = kwargs['kv_lora_rank'] + self.qk_nope_head_dim = kwargs['qk_nope_head_dim'] + self.qk_rope_head_dim = kwargs['qk_rope_head_dim'] + self.qk_head_dim = kwargs['qk_head_dim'] + self.v_head_dim = kwargs['v_head_dim'] + self.rotary_emb = kwargs['rotary_emb'] + self.q_proj = kwargs['q_proj'] + self.kv_b_proj = kwargs['kv_b_proj'] + self.o_proj = kwargs['o_proj'] + self.indexer = kwargs['indexer'] + self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None) + self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None) + self.q_a_proj = kwargs.get('q_a_proj', None) + self.q_a_layernorm = kwargs.get('q_a_layernorm', None) + self.decoder_layer = kwargs.get('decoder_layer', None) + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_rank = self.num_heads // self.tp_size + if self.q_a_proj is not None: + self.q_b_proj = self.q_proj + else: + self.q_b_proj = None + + ascend_config = get_ascend_config() + self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp + self.enable_prefetch = ascend_config.enable_prefetch + self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz + if ascend_config.torchair_graph_config.enabled: + self.graph_batch_size = ascend_config.torchair_graph_config.graph_batch_sizes[ + 0] + self.actual_seq_length = torch.arange(1, self.graph_batch_size + + 1).to(torch.int32).npu() + vllm_config = get_current_vllm_config() + self.ring_mla_mask_size = 512 + self.prefill_mask = None + + # indexer param + self.dim = self.indexer.dim + self.n_heads: int = self.indexer.n_heads # 64 + self.head_dim: int = self.indexer.head_dim # 128 + self.index_topk: int = self.indexer.index_topk # 2048 + self.wq_b = self.indexer.wq_b + self.wk = self.indexer.wk + self.weights_proj = self.indexer.weights_proj + self.k_norm = self.indexer.k_norm + self.softmax_scale = self.indexer.softmax_scale + + # Adapt torch air graph mode with spec decoding. 
+ speculative_config = vllm_config.speculative_config + if speculative_config is not None: + self.spec_token_num = speculative_config.num_speculative_tokens + assert self.spec_token_num > 0 + + self.cp_size = 1 + + if self.q_a_proj is not None: + self.prefix = self.q_a_proj.prefix + else: + self.prefix = 0 + self.debug_layer_idx = int(self.prefix.split(".")[2]) + self.layers = vllm_config.model_config.hf_config.num_hidden_layers + self.first_k_dense_replace = vllm_config.model_config.hf_config.first_k_dense_replace + + def process_weights_after_loading(self, act_dtype: torch.dtype): + + def get_layer_weight(layer): + WEIGHT_NAMES = ("weight", "qweight", "weight_packed") + for attr in WEIGHT_NAMES: + if hasattr(layer, attr): + return getattr(layer, attr) + raise AttributeError( + f"Layer '{layer}' has no recognized weight attribute:" + f" {WEIGHT_NAMES}.") + + def get_and_maybe_dequant_weights(layer: LinearBase): + if not isinstance(layer.quant_method, UnquantizedLinearMethod): + # NOTE: This should only be used offline, since it's O(N^3) + eye = torch.eye(layer.input_size_per_partition, + dtype=act_dtype, + device=get_layer_weight(layer).device) + dequant_weights = layer.quant_method.apply(layer, + eye, + bias=None) + del eye + # standardize to (output, input) + return dequant_weights.T + return layer.weight + + # we currently do not have quantized bmm's which are needed for + # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform + # the bmm's in 16-bit, the extra memory overhead of this is fairly low + kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T + assert kv_b_proj_weight.shape == ( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}") + kv_b_proj_weight = kv_b_proj_weight.view( + self.kv_lora_rank, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ) + + self.kv_b_proj_w_k, self.kv_b_proj_w_v = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + # Convert from (L, N, V) to (N, L, V) + self.kv_b_proj_w_v = self.kv_b_proj_w_v.transpose(0, 1).contiguous() + # Convert from (L, N, P) to (N, P, L) + self.kv_b_proj_w_k = self.kv_b_proj_w_k.permute(1, 2, 0).contiguous() + # Waiting for BMM NZ support + # self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29) + # self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29) + if envs_ascend.VLLM_ASCEND_ENABLE_MLAPO: + self._process_weights_for_fused_mlapo(act_dtype) + + def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype): + kv_a_proj_wt = self.kv_a_proj_with_mqa.weight.data.clone() + kv_a_proj_wt = kv_a_proj_wt.t().contiguous() + kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim) + kv_a_proj_wt = kv_a_proj_wt.t().contiguous() + wd_qkv = torch.cat((kv_a_proj_wt, self.q_a_proj.weight.data.clone()), + dim=-1) + wd_qkv = wd_qkv.t().contiguous() + wd_qkv = transdata(wd_qkv, + block_size=(16, 32)).unsqueeze(0).contiguous() + self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29) + + kv_a_proj_deq_scl = self.kv_a_proj_with_mqa.deq_scale.clone() + kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape( + self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous() + kv_a_proj_deq_scl = trans_rope_weight(kv_a_proj_deq_scl, + self.qk_rope_head_dim) + kv_a_proj_deq_scl = kv_a_proj_deq_scl.view( + self.kv_lora_rank + self.qk_rope_head_dim).contiguous() + self.deq_scale_qkv = 
torch.cat( + (kv_a_proj_deq_scl, self.q_a_proj.deq_scale.clone()), + dim=-1).contiguous() + + kv_a_proj_qt_bias = self.kv_a_proj_with_mqa.quant_bias.clone() + kv_a_proj_qt_bias = kv_a_proj_qt_bias.reshape( + self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous() + kv_a_proj_qt_bias = trans_rope_weight(kv_a_proj_qt_bias, + self.qk_rope_head_dim) + kv_a_proj_qt_bias = kv_a_proj_qt_bias.view( + self.kv_lora_rank + self.qk_rope_head_dim).contiguous() + self.quant_bias_qkv = torch.cat( + (kv_a_proj_qt_bias, self.q_a_proj.quant_bias.clone()), + dim=-1).contiguous() + + wu_q = self.q_proj.weight.data.clone() + wu_q = wu_q.t().reshape(self.num_heads, + self.qk_nope_head_dim + self.qk_rope_head_dim, + -1) + wu_q = trans_rope_weight(wu_q, self.qk_rope_head_dim) + wu_q = wu_q.reshape( + self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim), + -1) + wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous() + self.wu_q = torch_npu.npu_format_cast(wu_q, 29) + + qb_deq_scl = self.q_proj.deq_scale.data.clone() + qb_deq_scl = qb_deq_scl.reshape( + self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1) + qb_deq_scl = trans_rope_weight(qb_deq_scl, self.qk_rope_head_dim) + self.qb_deq_scl = qb_deq_scl.reshape( + self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)) + + qb_qt_bias = self.q_proj.quant_bias.data.clone() + qb_qt_bias = qb_qt_bias.reshape( + self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1) + qb_qt_bias = trans_rope_weight(qb_qt_bias, self.qk_rope_head_dim) + self.qb_qt_bias = qb_qt_bias.reshape( + self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)) + + self.gamma0 = self.decoder_layer.input_layernorm.weight.data + self.beta0 = self.decoder_layer.input_layernorm.bias.data + self.gamma1 = self.q_a_layernorm.weight.data + self.beta1 = self.q_a_layernorm.bias.data + self.gamma2 = self.kv_a_layernorm.weight.data + self.quant_scale0 = self.q_a_proj.input_scale.data + self.quant_offset0 = self.q_a_proj.input_offset.data + self.quant_scale1 = self.q_proj.input_scale.data + self.quant_offset1 = self.q_proj.input_offset.data + + def _sfa_decode_preprocess(self, hidden_states, kv_cache, attn_metadata, + need_gather_q_kv): + bsz = hidden_states.shape[0] + cos_shape = attn_metadata.decode.cos.shape + cos = attn_metadata.decode.cos.view(cos_shape[0], cos_shape[-1]) + sin = attn_metadata.decode.sin.view(cos_shape[0], cos_shape[-1]) + ctkv_scale = torch.tensor([1], + dtype=hidden_states.dtype, + device=hidden_states.device) + q_nope_scale = torch.tensor([1], + dtype=hidden_states.dtype, + device=hidden_states.device) + + decode_q_nope, _, decode_q_pe, _ = torch_npu.npu_mla_process( + hidden_states, + self.gamma0, + self.beta0, + self.wd_qkv, + self.deq_scale_qkv, + self.gamma1, + self.beta1, + self.wu_q, + self.qb_deq_scl, + self.gamma2, + cos, + sin, + self.kv_b_proj_w_k, + kv_cache[0], + kv_cache[1], + attn_metadata.slot_mapping.flatten(), + quant_scale0=self.quant_scale0, + quant_offset0=self.quant_offset0, + bias0=self.quant_bias_qkv, + quant_scale1=self.quant_scale1, + quant_offset1=self.quant_offset1, + bias1=self.qb_qt_bias, + ctkv_scale=ctkv_scale, + q_nope_scale=q_nope_scale, + cache_mode_opt="krope_ctkv", + quant_mode_opt="per_tensor_quant_asymm", + ) + decode_k_nope = kv_cache[0] + decode_k_pe = kv_cache[1] + decode_q_nope = decode_q_nope.view(bsz, self.num_heads, + self.kv_lora_rank) + decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1) + + hidden_states = self.decoder_layer.input_layernorm(hidden_states) + 
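As a side note on the fused-MLAPO weight preparation above: a minimal CPU sketch (not part of the patch) of what trans_rope_weight, defined near the end of this file, does to the rope rows of a weight. The interleaved rope dims [r0, r1, r2, r3, ...] are regrouped into the halves [r0, r2, ...] and [r1, r3, ...], presumably so the fused kernel can apply rotary embedding on contiguous halves; unlike the original helper, this reference clones rather than updating in place.

import torch

def trans_rope_weight_ref(weight: torch.Tensor, rope_dim: int) -> torch.Tensor:
    even = weight[..., -rope_dim::2, :]       # rope rows 0, 2, 4, ...
    odd = weight[..., -rope_dim + 1::2, :]    # rope rows 1, 3, 5, ...
    out = weight.clone()
    out[..., -rope_dim:, :] = torch.cat([even, odd], dim=-2)
    return out

w = torch.arange(8.0).reshape(8, 1)           # last 4 rows are the rope part
print(trans_rope_weight_ref(w, rope_dim=4).flatten().tolist())
# [0.0, 1.0, 2.0, 3.0, 4.0, 6.0, 5.0, 7.0]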
decode_kq = self.q_a_proj(hidden_states) # q down + decode_q_c = self.q_a_layernorm(decode_kq) # q down layernorm + + topk_indices = self.indexer_select(hidden_states, + decode_q_c, + attn_metadata=attn_metadata, + kv_cache=kv_cache, + is_prefill=False) + query_states = (decode_q_nope, decode_q_pe) + key_states = (decode_k_nope, decode_k_pe) + decode_preprocess_res = DecodeSFAPreprocessResult( + q_nope=decode_q_nope, + q_pe=decode_q_pe, + topk_indices=topk_indices, + query_states=query_states, + key_states=key_states, + bsz=bsz, + ) + return decode_preprocess_res + + def forward( + self, + hidden_states: torch.Tensor, # query in unified attn + kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + attn_metadata: M, + need_gather_q_kv: bool = False, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert output is not None, "Output tensor must be provided." + if attn_metadata is None: + # Profiling run. + return output + + if attn_metadata.prefill is not None: + assert attn_metadata.num_decodes is not None and \ + attn_metadata.num_prefills is not None and \ + attn_metadata.num_decode_tokens is not None + + bsz = 1 + + hidden_states_prefill = hidden_states + prefill_slot_mapping = attn_metadata.slot_mapping + prefill_kq = self.q_a_proj(hidden_states_prefill) # q down + prefill_q_c = self.q_a_layernorm(prefill_kq) # q down layernorm + prefill_kv_no_split = self.kv_a_proj_with_mqa( + hidden_states_prefill) # c_kv + if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers: + prefill_kv_no_split = get_tp_group().all_gather( + prefill_kv_no_split, + 0)[attn_metadata.num_decode_tokens:attn_metadata. + num_actual_tokens] + # prefill_q_c = q_c[ + # num_decode_tokens:num_actual_tokens] + + # decode_kv_no_split = decode_kv_no_split[:num_decode_tokens] + + # prefill_kv_no_split = kv_no_split[ + # num_decode_tokens:num_actual_tokens] + # prefill_qr = prefill_q_c[num_decode_tokens:num_actual_tokens] + prefill_qr = prefill_q_c + if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers: + prefill_qr = get_tp_group().all_gather( + prefill_qr, + 0)[attn_metadata.num_decode_tokens:attn_metadata. + num_actual_tokens] + + prefill_q = self.q_b_proj(prefill_qr) + prefill_q = prefill_q.view(-1, self.num_heads, self.qk_head_dim) + prefill_q_nope, prefill_q_pe = torch.split( + prefill_q, [self.qk_nope_head_dim, self.qk_rope_head_dim], + dim=-1) + prefill_q_nope = prefill_q_nope.view( + -1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1) + prefill_q_nope = (torch.matmul(prefill_q_nope, + self.kv_b_proj_w_k).transpose( + 1, + 0).view(-1, self.num_heads, + self.kv_lora_rank)) + prefill_q_pe = prefill_q_pe.unsqueeze(2) + + # stream2 kv + + nope_cache = kv_cache[0] + rope_cache = kv_cache[1] + cos = attn_metadata.prefill.cos + sin = attn_metadata.prefill.sin + cos_q, sin_q = cos, sin + + prefill_q_pe = torch_npu.npu_interleave_rope( + prefill_q_pe, cos_q, sin_q) # BNSD + prefill_q_pe = prefill_q_pe.squeeze(2) #BSH + # q[..., self.qk_nope_head_dim:] = prefill_q_pe # TODO:???? 
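A shape-only sketch (not part of the patch) of the weight absorption used in the prefill path above and the decode path below: the nope part of the query is folded through kv_b_proj_w_k so the sparse attention runs directly in the compressed kv_lora_rank space, and the attention output is folded back through kv_b_proj_w_v before o_proj. The 128/512/128 head dims are assumed DeepSeek-V3-style values.

import torch

T, N = 4, 8                                   # tokens, attention heads on this rank
qk_nope, kv_lora_rank, v_head = 128, 512, 128
w_k = torch.randn(N, qk_nope, kv_lora_rank)   # kv_b_proj_w_k, (N, P, L) after permute
w_v = torch.randn(N, kv_lora_rank, v_head)    # kv_b_proj_w_v, (N, L, V) after transpose

q_nope = torch.randn(T, N, qk_nope)
q_absorbed = torch.matmul(q_nope.transpose(0, 1), w_k).transpose(1, 0)   # (T, N, L)

attn_out = torch.randn(N, T, kv_lora_rank)    # sparse-attention output in latent space
o_proj_in = torch.matmul(attn_out, w_v).transpose(1, 0).reshape(T, N * v_head)
print(q_absorbed.shape, o_proj_in.shape)      # torch.Size([4, 8, 512]) torch.Size([4, 1024])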
+ + prefill_latent_cache = prefill_kv_no_split # (B,S,N,D) + prefill_k_pe, prefill_k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( + prefill_latent_cache.view( + -1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim), + self.kv_a_layernorm.weight, + cos.view(-1, 1, 1, self.qk_rope_head_dim), + sin.view(-1, 1, 1, self.qk_rope_head_dim), + prefill_slot_mapping.to(torch.int64), + rope_cache, + nope_cache, + k_rope_scale=None, + c_kv_scale=None, + k_rope_offset=None, + c_kv_offset=None, + epsilon=self.kv_a_layernorm.variance_epsilon, + cache_mode="PA") + + topk_indices = self.indexer_select(x=hidden_states_prefill, + qr=prefill_qr, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + is_prefill=True) + query_states = (prefill_q_nope, prefill_q_pe) + key_states = (prefill_k_nope, prefill_k_pe) + q_nope, q_pe = query_states + k_nope, k_rope = key_states + prefill_metadata = attn_metadata.prefill + + slc_fa_fusion = torch.ops.custom.npu_sparse_flash_attention( + query=q_nope, + key=k_nope, + value=k_nope, + sparse_indices=topk_indices, + scale_value=self.scale, + sparse_block_size=1, + block_table=prefill_metadata.block_table, + actual_seq_lengths_query=prefill_metadata.query_lens, + actual_seq_lengths_kv=prefill_metadata.seq_lens, + query_rope=q_pe, + key_rope=k_rope, + layout_query="TND", + layout_kv="PA_BSND", + sparse_mode=3, + ) + slc_fa_fusion = slc_fa_fusion.transpose(0, 1) + + # input shape [N//attn_tp_size, T(bs*q_len), D] + # output shape [T(bs*q_len), N//attn_tp_size, D] + attn_output = torch.matmul( + slc_fa_fusion, self.kv_b_proj_w_v).transpose(1, 0).reshape( + -1, self.num_heads * self.v_head_dim) + # o_proj_input[num_decode_tokens:] = attn_output + output[...] = self.o_proj(attn_output, is_force_scatter=True) + return output + + elif attn_metadata.decode is not None: + if envs_ascend.VLLM_ASCEND_ENABLE_MLAPO: + prep_res = self._sfa_decode_preprocess(hidden_states, kv_cache, + attn_metadata, + need_gather_q_kv) + q_nope, q_pe = prep_res.query_states + k_nope, k_rope = prep_res.key_states + topk_indices = prep_res.topk_indices + else: + q_len = 1 + hidden_states_decode = hidden_states + decode_kq = self.q_a_proj(hidden_states_decode) # q down + decode_q_c = self.q_a_layernorm(decode_kq) # q down layernorm + decode_kv_no_split = self.kv_a_proj_with_mqa( + hidden_states_decode) # c_kv + # self.actual_seq_length = torch.arange(1,self.graph_batch_size+1).to(torch.int32).npu() + + # decode_q_c = q_c[:num_decode_tokens] + decode_slot_mapping = attn_metadata.slot_mapping + + decode_q = self.q_b_proj(decode_q_c) + bsz, _ = decode_q.shape + decode_q = decode_q.view(bsz, self.num_heads, 1, + self.qk_head_dim) # [16, 16, 1, 192] + decode_q_nope, decode_q_pe = torch.split( + decode_q, [self.qk_nope_head_dim, self.qk_rope_head_dim], + dim=-1) # [..., 128/64] + decode_q_nope = decode_q_nope.view( + -1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1) + decode_q_nope = (torch.matmul( + decode_q_nope, self.kv_b_proj_w_k).transpose(1, 0).view( + bsz, q_len, self.num_heads, self.kv_lora_rank)) + + # stream2 kv + key_cache = kv_cache[0] + value_cache = kv_cache[1] + cos = attn_metadata.decode.cos # [16, 1, 1, 64] + sin = attn_metadata.decode.sin + cos_q, sin_q = cos, sin + cos = cos.view(-1, 1, 1, self.qk_rope_head_dim) + sin = sin.view(-1, 1, 1, self.qk_rope_head_dim) + + decode_kv_no_split = decode_kv_no_split.unsqueeze(1).unsqueeze( + 1) + decode_k_rope, decode_k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( + decode_kv_no_split, + self.kv_a_layernorm.weight, + cos, + sin, + 
decode_slot_mapping.to(torch.int64), + value_cache, + key_cache, + c_kv_scale=None, + epsilon=self.kv_a_layernorm.variance_epsilon, + cache_mode='PA') # adapter NZ + # nz_block_size = 16 + # KVCACHE_NZ_DIM = 16 + # decode_k_nope = decode_k_nope.view(block_num, 1, self.kv_lora_rank // nz_block_size, block_size, nz_block_size) + # decode_k_rope = decode_k_rope.view(block_num, 1, self.qk_rope_head_dim // KVCACHE_NZ_DIM, block_size, KVCACHE_NZ_DIM) + decode_q_pe = torch_npu.npu_interleave_rope( + decode_q_pe, cos, sin) # BNSD + + decode_q_nope = decode_q_nope.view(bsz, self.num_heads, + self.kv_lora_rank) + decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1) + + topk_indices = self.indexer_select(hidden_states_decode, + decode_q_c, + attn_metadata=attn_metadata, + kv_cache=kv_cache, + is_prefill=False) + + query_states = (decode_q_nope, decode_q_pe) + key_states = (decode_k_nope, decode_k_rope) + q_nope, q_pe = query_states + k_nope, k_rope = key_states + + decode_metadata = attn_metadata.decode + slc_fa_fusion = torch.ops.custom.npu_sparse_flash_attention( + query=q_nope, + key=k_nope, + value=k_nope, + sparse_indices=topk_indices, + scale_value=self.scale, + sparse_block_size=1, + block_table=attn_metadata.decode.block_table, + actual_seq_lengths_query=decode_metadata.actual_seq_lengths_q, + actual_seq_lengths_kv=decode_metadata.seq_lens, + query_rope=q_pe, + key_rope=k_rope, + layout_query="TND", + layout_kv="PA_BSND", + sparse_mode=3, + ) + slc_fa_fusion = slc_fa_fusion.squeeze(1) + slc_fa_fusion = slc_fa_fusion.transpose(0, 1) + + # input shape [N//attn_tp_size, T(bs*q_len), D] + # output shape [T(bs*q_len), N//attn_tp_size, D] + attn_output = torch.matmul( + slc_fa_fusion, self.kv_b_proj_w_v).transpose(1, 0).reshape( + -1, self.num_heads * self.v_head_dim) + output[...] = self.o_proj(attn_output) + return output + + def mla_epilog(self, + attn_output: torch.Tensor = None, + absorb: bool = False): + # TODO: + attn_output = self.o_proj(attn_output) + return attn_output + + def indexer_select( + self, + x: torch.Tensor, + qr: torch.Tensor, + kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + attn_metadata: M, + is_prefill: bool = True, + ): + if attn_metadata.prefill is not None: + cos = attn_metadata.prefill.cos + sin = attn_metadata.prefill.sin + elif attn_metadata.decode is not None: + cos = attn_metadata.decode.cos + sin = attn_metadata.decode.sin + + cos_q, sin_q = cos, sin + cos = cos.view(-1, 1, 1, self.qk_rope_head_dim) + sin = sin.view(-1, 1, 1, self.qk_rope_head_dim) + + # q process in new stream + q = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128] + q = q.view(-1, self.n_heads, self.head_dim) # [b,s,64,128] + q_pe, q_nope = torch.split( + q, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], + dim=-1) # [b,s,64,64+64] + + q_pe = q_pe.unsqueeze(2) + q_pe = torch_npu.npu_interleave_rope(q_pe, cos_q, sin_q) + q_pe = q_pe.squeeze(2) + q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128] + + k_proj = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128] + if self.enable_shared_expert_dp and is_prefill and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers: + k_proj = get_tp_group().all_gather( + k_proj, 0)[attn_metadata.num_decode_tokens:attn_metadata. 
+ num_actual_tokens] + k = self.k_norm(k_proj).unsqueeze(1) + k_pe, k_nope = torch.split( + k, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], + dim=-1) # [b,s,64+64] + + k_pe = k_pe.unsqueeze(2) + k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin) + k_pe = k_pe.squeeze(2) + + k = torch.cat([k_pe, k_nope], dim=-1) # [b*s,128] + + if kv_cache is not None: + torch_npu.npu_scatter_nd_update_(kv_cache[2].view(-1, k.shape[-1]), + attn_metadata.slot_mapping.view( + -1, 1), + k.view(-1, + k.shape[-1])) # b, s, n, d + + weights = self.weights_proj(x) + if self.enable_shared_expert_dp and is_prefill and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers: + weights = get_tp_group().all_gather( + weights, 0)[attn_metadata.num_decode_tokens:attn_metadata. + num_actual_tokens] + actual_seq_lengths_query = None + actual_seq_lengths_key = None + block_table = None + if attn_metadata.prefill is not None: + actual_seq_lengths_query = attn_metadata.prefill.query_lens + actual_seq_lengths_key = attn_metadata.prefill.seq_lens + + block_table = attn_metadata.prefill.block_table + elif attn_metadata.decode is not None: + actual_seq_lengths_query = attn_metadata.decode.actual_seq_lengths_q + actual_seq_lengths_key = attn_metadata.decode.seq_lens.to( + torch.int32) + + block_table = attn_metadata.decode.block_table + + topk_indices = torch.ops.custom.npu_lightning_indexer( + query=q, + key=kv_cache[2], + weights=weights, + actual_seq_lengths_query=actual_seq_lengths_query, + actual_seq_lengths_key=actual_seq_lengths_key, + block_table=block_table, + layout_query="TND", + layout_key="PA_BSND", + sparse_count=2048, + sparse_mode=3) + return topk_indices + + +def round_up(val: int, align: int) -> int: + if align == 0: + return 0 + return -(val // -align) * align + + +def trans_rope_weight(weight, rope_dim): + weight_1 = weight[..., -rope_dim::2, :].contiguous() + weight_2 = weight[..., -rope_dim + 1::2, :].contiguous() + weight[..., -rope_dim:, :] = torch.cat([weight_1, weight_2], dim=-2) + + return weight.contiguous() + + +def transdata(nd_mat, block_size: tuple = (16, 16)): + r = round_up(nd_mat.shape[0], block_size[0]) + c = round_up(nd_mat.shape[1], block_size[1]) + r_pad = r - nd_mat.shape[0] + c_pad = c - nd_mat.shape[1] + nd_mat = F.pad(nd_mat, ((0, r_pad, 0, c_pad))) + nz_mat = torch.permute( + torch.reshape( + nd_mat, + (r // block_size[0], block_size[0], c // block_size[1], + block_size[1]), + ), + [2, 0, 1, 3], + ) + nz_mat = torch.reshape( + nz_mat, + (nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3])) + return nz_mat diff --git a/vllm_ascend/torchair/utils.py b/vllm_ascend/torchair/utils.py index 56b8c710d1..668a7e7b5f 100644 --- a/vllm_ascend/torchair/utils.py +++ b/vllm_ascend/torchair/utils.py @@ -165,6 +165,11 @@ def register_torchair_model(): "vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM" ) + ModelRegistry.register_model( + "DeepseekV32ForCausalLM", + "vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM" + ) + ModelRegistry.register_model( "Qwen2ForCausalLM", "vllm_ascend.torchair.models.qwen2:CustomQwen2ForCausalLM") diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 0984e2bf63..9281dd70bf 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -285,8 +285,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.intermediate_tensors: 
Optional[IntermediateTensors] = None self.runner_only_attn_layers: set[str] = set() - ascend_config = get_ascend_config() - if ascend_config.ascend_scheduler_config.enabled: + self.ascend_config = get_ascend_config() + if self.ascend_config.ascend_scheduler_config.enabled: self.chunked_prefill_enabled = self.scheduler_config.chunked_prefill_enabled else: self.chunked_prefill_enabled = True @@ -298,6 +298,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.cache_config.cache_dtype] # use_hybrid_blocks: if hybrid blocks is used. self.use_hybrid_blocks: bool = False + self.need_accepted_tokens: bool = False self.is_multimodal_model = self.model_config.is_multimodal_model self.is_pooling_model = self.model_config.pooler_config is not None @@ -315,7 +316,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.block_size, self.model_config.is_attention_free, use_mla=self.model_config.use_mla, - ) + use_sfa=self.ascend_config.use_sfa) else: self.attn_backend = get_attn_backend( 0, @@ -323,7 +324,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): None, self.block_size, use_mla=self.model_config.use_mla, - ) + use_sfa=self.ascend_config.use_sfa) if torch.version.cann.startswith("8.3"): self.attn_mask_builder = AttentionMaskBuilder( self.scheduler_config.max_num_batched_tokens, self.dtype, @@ -457,7 +458,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): dtype=torch.bool, device=self.device, ) - self.dynamic_eplb = ascend_config.dynamic_eplb + self.dynamic_eplb = self.ascend_config.dynamic_eplb if self.dynamic_eplb: self.is_eplb_warmuped = False self.eplb_loader = D2DExpertWeightLoader() @@ -890,15 +891,16 @@ def get_supported_tasks(self) -> "tuple[SupportedTask, ...]": def _make_attention_mask(self, seq_lens, position, attn_state) -> torch.Tensor: # Chunk Prefill situation. - if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla: + if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.ascend_config.use_sfa: if torch.version.cann.startswith("8.3"): return self.attn_mask_builder.get_splitfuse_attn_mask() else: return self.attn_mask_builder.get_splitfuse_attn_mask( seq_lens, position, self.dtype, self.device) + # Prefill without cache situation. elif attn_state == AscendAttentionState.PrefillNoCache: - max_seq_len = max(seq_lens, default=0) + max_seq_len = max(seq_lens.max().item(), 0) return self.attn_mask_builder.get_attn_mask( max_seq_len, self.dtype, self.device) # Prefill with cache hit. 
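Related to the PrefillNoCache change in the hunk above: a small sketch (not part of the patch) of why max(seq_lens, default=0) no longer fits once seq_lens arrives as the seq_lens_cpu tensor rather than a Python list.

import torch

seq_lens = torch.tensor([5, 9, 3], dtype=torch.int32)
old_style = max(seq_lens, default=0)         # tensor(9, dtype=torch.int32), not an int;
                                             # `default` only applies to an empty iterable
new_style = max(seq_lens.max().item(), 0)    # 9 as a plain Python int
print(type(old_style), type(new_style))      # <class 'torch.Tensor'> <class 'int'>
# Note: .max() raises on an empty tensor, so the new form assumes at least one request.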
@@ -1252,7 +1254,7 @@ def _prepare_inputs( req_ids = self.input_batch.req_ids tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] num_scheduled_tokens = np.array(tokens, dtype=np.int32) - max_num_scheduled_tokens = max(tokens) + max_num_scheduled_tokens = num_scheduled_tokens.max() num_valid_tokens = np.array([ num_tokens - len(scheduler_output.scheduled_spec_decode_tokens.get(i, [])) @@ -1376,8 +1378,6 @@ def _prepare_inputs( positions_cpu = self.positions_cpu[:num_input_tokens] positions = self.positions[:num_input_tokens] seq_lens_cpu = self.seq_lens_cpu[:num_reqs] - attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, - num_valid_tokens) self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu, position=positions_cpu, attn_state=attn_state) @@ -1477,7 +1477,7 @@ def _prepare_inputs( num_computed_tokens_cpu = ( self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]) spec_decode_common_attn_metadata = None - if use_spec_decode: + if use_spec_decode and self.need_accepted_tokens: self.num_accepted_tokens.np[:num_reqs] = ( self.input_batch.num_accepted_tokens_cpu[:num_reqs]) self.num_accepted_tokens.np[num_reqs:].fill(1) @@ -1550,7 +1550,7 @@ def _prepare_inputs( model=self.model, **extra_attn_metadata_args) - if self.vllm_config.model_config.use_mla: + if self.vllm_config.model_config.use_mla or self.ascend_config.use_sfa: attn_metadata_i.num_input_tokens = num_input_tokens for layer_name in attn_group.layer_names: attn_metadata[layer_name] = attn_metadata_i @@ -2060,7 +2060,8 @@ def execute_model( sampling_metadata, ) sampler_output.sampled_token_ids = output_token_ids - self._update_states_after_model_execute(output_token_ids) + if self.need_accepted_tokens: + self._update_states_after_model_execute(output_token_ids) discard_sampled_tokens_req_indices: list[int] = [] # TODO(woosuk): The following loop can be slow since it iterates over @@ -2683,10 +2684,26 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: self.kv_cache_config = kv_cache_config self.initialize_attn_backend(kv_cache_config) self.use_hybrid_blocks = (len(self.attn_groups) > 1) + # NOTE: Currently, we determine whether we need `num_accepted_tokens` through `MambaSpec`. 
+        if vllm_version_is("0.10.2"):
+            self.need_accepted_tokens = any([
+                isinstance(
+                    self.kv_cache_config.kv_cache_groups[0].kv_cache_spec,
+                    MambaSpec) for attn_group in self.attn_groups
+            ])
+        else:
+            self.need_accepted_tokens = any([
+                isinstance(attn_group[0].kv_cache_spec, MambaSpec)
+                for attn_group in self.attn_groups
+            ])
+
         self.may_reinitialize_input_batch(kv_cache_config)
-        if self.model_config.is_deepseek_mla:
-            kv_caches = self.initialize_kv_cache_tensors_deepseek(
+        if self.ascend_config.is_deepseek_sfa:
+            kv_caches = self.initialize_kv_cache_tensors_deepseek_sfa(
+                kv_cache_config)
+        elif self.model_config.is_deepseek_mla:
+            kv_caches = self.initialize_kv_cache_tensors_deepseek_mla(
                 kv_cache_config)
         else:
             kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
@@ -2701,7 +2718,116 @@ def _align_memory(self, tensor: torch.Tensor,
         offset = (aligned_addr - data_ptr) // tensor.element_size()
         return tensor[int(offset):]

-    def initialize_kv_cache_tensors_deepseek(
+    def initialize_kv_cache_tensors_deepseek_sfa(
+            self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
+        kv_cache_sizes = {}
+        for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
+            assert len(kv_cache_tensor.shared_by) == 1, (
+                "KV cache tensor shared by multiple layers is not supported in "
+                "NPU.")
+            kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
+
+        kv_caches: Dict[str, torch.Tensor] = {}
+        for group in self._kv_cache_spec_attn_group_iterator_dispatcher():
+            if vllm_version_is("0.10.2"):
+                kv_cache_spec, group = group
+            else:
+                kv_cache_spec = group.kv_cache_spec
+            attn_backend = group.backend
+            for layer_name in group.layer_names:
+                if layer_name in self.runner_only_attn_layers:
+                    continue
+                tensor_size = kv_cache_sizes[layer_name]
+                num_blocks = tensor_size // kv_cache_spec.page_size_bytes
+                if self.vllm_config.additional_config.get(
+                        "kv_cache_dtype", None) == 'int8':
+                    kv_cache_shape = attn_backend.get_bsh_kv_cache_shape(
+                        num_blocks, kv_cache_spec.block_size,
+                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
+                elif hasattr(
+                        attn_backend, "get_supported_block_size"
+                ) and not self.model_config.is_deepseek_mla and not self.ascend_config.is_deepseek_sfa:
+                    block_size = attn_backend.get_supported_block_size()[0]
+                    block_size_chunk = kv_cache_spec.block_size // block_size
+                    kv_cache_shape = attn_backend.get_kv_cache_shape(
+                        num_blocks * block_size_chunk, block_size,
+                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
+                else:
+                    kv_cache_shape = self.attn_backend.get_kv_cache_shape(
+                        num_blocks, kv_cache_spec.block_size,
+                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
+                dtype = kv_cache_spec.dtype
+
+                alignment = 2 * 1024 * 1024
+                num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape
+                rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
+                nope_dim = head_size - rope_dim
+                nope_cache_shape = (num_blocks, block_size, num_kv_heads,
+                                    nope_dim)
+                rope_cache_shape = (num_blocks, block_size, num_kv_heads,
+                                    rope_dim)
+                #### k cache
+                # TODO(zzzzwwjj): wait for transformers to add these params
+                k_cache_shape = (num_blocks, block_size, 1, 128)
+                if self.vllm_config.kv_transfer_config is None:
+                    # For the non-disaggregated PD scenario, allocate the kv cache in the normal way
+                    rope_cache = torch.zeros(rope_cache_shape,
+                                             dtype=dtype,
+                                             device=self.device)
+                    nope_cache = torch.zeros(nope_cache_shape,
+                                             dtype=dtype,
+                                             device=self.device)
+                    rope_cache = self._convert_torch_format(rope_cache)
+                    nope_cache = self._convert_torch_format(nope_cache)
+
+                    #### k cache
+                    k_cache = torch.zeros(k_cache_shape,
+                                          dtype=dtype,
+                                          device=self.device)
+                    k_cache = self._convert_torch_format(k_cache)
+                else:
+
+                    # In order to transfer the kv cache through the register_memory api from llmdatadist, the memory
+                    # address should be aligned by 2M. In most cases, torch_npu can allocate 2M-aligned memory, but
+                    # we found some exceptions during testing, so we manually align that memory here; this part
+                    # of the code may consume 2M * 2 * elem_size memory every layer.
+                    nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim
+                    nope_allocate_shape_alignment = nope_allocate_shape + alignment
+                    rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim
+                    rope_allocate_shape_alignment = rope_allocate_shape + alignment
+
+                    nope_cache = torch.zeros(nope_allocate_shape_alignment,
+                                             dtype=dtype,
+                                             device=self.device)
+                    rope_cache = torch.zeros(rope_allocate_shape_alignment,
+                                             dtype=dtype,
+                                             device=self.device)
+                    #### k cache
+                    # TODO(zzzzwwjj): wait for transformers to add these params
+                    k_allocate_shape = num_blocks * block_size * 1 * 128
+                    k_allocate_shape_alignment = k_allocate_shape + alignment
+                    k_cache = torch.zeros(k_allocate_shape_alignment,
+                                          dtype=dtype,
+                                          device=self.device)
+
+                    nope_cache = self._align_memory(
+                        nope_cache,
+                        alignment)[:nope_allocate_shape].view(nope_cache_shape)
+                    rope_cache = self._align_memory(
+                        rope_cache,
+                        alignment)[:rope_allocate_shape].view(rope_cache_shape)
+                    k_cache = self._align_memory(
+                        k_cache,
+                        alignment)[:k_allocate_shape].view(k_cache_shape)
+
+                kv_caches[layer_name] = (nope_cache, rope_cache, k_cache)
+        bind_kv_cache(kv_caches,
+                      self.compilation_config.static_forward_context,
+                      self.kv_caches)
+
+        return kv_caches
+
+    def initialize_kv_cache_tensors_deepseek_mla(
             self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
         kv_cache_sizes = {}
         for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
@@ -3217,6 +3343,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         block_size = self.vllm_config.cache_config.block_size
         use_mla = self.vllm_config.model_config.use_mla
+        use_sfa = self.ascend_config.use_sfa
         kv_cache_spec: dict[str, KVCacheSpec] = {}
         attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
         for layer_name, attn_module in attn_layers.items():
@@ -3243,7 +3370,8 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
                     num_kv_heads=attn_module.num_kv_heads,
                     head_size=attn_module.head_size,
                     dtype=self.kv_cache_dtype,
-                    use_mla=use_mla)
+                    use_mla=use_mla,
+                    use_sfa=use_sfa)
             elif attn_module.attn_type in (AttentionType.ENCODER,
                                            AttentionType.ENCODER_ONLY):
                 # encoder-only attention does not need KV cache.
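The 2M alignment above relies on a simple over-allocate-then-slice trick: reserve roughly one extra alignment's worth of elements, advance to the first 2M-aligned address, and view the remainder as the requested cache shape. The following self-contained sketch illustrates the same idea outside the runner; the helper name allocate_aligned and the CPU device are illustrative only and are not part of the patch.

import torch

ALIGNMENT = 2 * 1024 * 1024  # llmdatadist expects 2M-aligned buffers


def allocate_aligned(shape, dtype, device="cpu", alignment=ALIGNMENT):
    # Number of payload elements and their size in bytes.
    numel = 1
    for dim in shape:
        numel *= dim
    elem_size = torch.tensor([], dtype=dtype).element_size()
    # Over-allocate by `alignment` bytes so an aligned start always fits.
    flat = torch.zeros(numel + alignment // elem_size,
                       dtype=dtype,
                       device=device)
    data_ptr = flat.data_ptr()
    aligned_ptr = (data_ptr + alignment - 1) // alignment * alignment
    offset = (aligned_ptr - data_ptr) // elem_size
    # A contiguous 1-D slice starting at the aligned element, viewed as `shape`.
    return flat[offset:offset + numel].view(shape)


# Example with a small, made-up rope-cache shape; real shapes come from the
# attention backend.
rope_cache = allocate_aligned((4, 128, 1, 64), torch.float16)
assert rope_cache.data_ptr() % ALIGNMENT == 0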
diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
index c1fc800da8..dc82ece6cc 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker_v1.py
@@ -43,7 +43,7 @@ from vllm.v1.worker.worker_base import WorkerBase
 import vllm_ascend.envs as envs_ascend
-from vllm_ascend.ascend_config import init_ascend_config
+from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
 from vllm_ascend.device_allocator.camem import CaMemAllocator
 from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
 from vllm_ascend.platform import NPUPlatform
@@ -88,6 +88,17 @@ def __init__(
         # init ascend config and soc version
         init_ascend_config(vllm_config)
         init_ascend_soc_version()
+        if get_ascend_config().use_sfa:
+            # Direct import instead of using try_register_lib to ensure proper error handling when
+            # custom_ops is necessary but not available (e.g., in DeepSeek v3.2 deployments)
+            # yapf: disable
+            import custom_ops  # type: ignore # noqa
+
+            # yapf: enable
+            logger.info(
+                "custom_ops module loaded successfully. Custom operators like "
+                "torch.ops.custom.npu_sparse_flash_attention are now available."
+            )
         super().__init__(vllm_config=vllm_config,
                          local_rank=local_rank,
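The guarded import above follows a fail-fast pattern: when SFA is enabled, the custom operator package is mandatory, so the worker imports it eagerly at startup rather than probing for it optionally. A minimal sketch of that pattern, using only the standard library; the wrapper name require_custom_ops and its error message are illustrative and not part of the patch.

import importlib
import logging

logger = logging.getLogger(__name__)


def require_custom_ops() -> None:
    # Import eagerly so a missing install fails at startup with a clear error,
    # instead of surfacing later as an undefined torch.ops.custom operator.
    try:
        importlib.import_module("custom_ops")
    except ImportError as err:
        raise ImportError(
            "custom_ops is required when sparse flash attention (SFA) is "
            "enabled but could not be imported; install the matching "
            "custom-ops build for your environment.") from err
    logger.info("custom_ops module loaded successfully.")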