diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 694adabbab..188ba9f34d 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -24,14 +24,13 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
-from vllm.config import get_current_vllm_config
+from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils import direct_register_custom_op
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.worker.gpu_input_batch import InputBatch
 
-from vllm_ascend.attention.utils import \
-    AscendCommonAttentionMetadata as CommonAttentionMetadata
+from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
 from vllm_ascend.utils import get_graph_params
@@ -156,39 +155,49 @@ def split_metadata_for_multistream(
 
 class AscendAttentionMetadataBuilder:
 
-    def __init__(self, runner):
+    def __init__(self, vllm_config: VllmConfig, device: torch.device, runner):
+        self.vllm_config = vllm_config
+        self.model_config = vllm_config.model_config
+        self.device = device
         self.runner = runner
 
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         return False
 
-    def build(self,
-              num_reqs,
-              num_actual_tokens,
-              max_query_len,
-              common_attn_metadata: CommonAttentionMetadata,
-              enable_dbo_across_dp: bool = False,
-              is_only_prefill: bool = False,
-              *args,
-              **kwargs):
-
-        block_table = self.runner.input_batch.block_table[0].get_device_tensor(
-        )
-        block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
-            block_table[:num_reqs])
-
-        query_start_loc = common_attn_metadata.query_start_loc
-        seq_lens = common_attn_metadata.seq_lens
+    def build(
+        self,
+        common_attn_metadata: AscendCommonAttentionMetadata,
+    ):
+        num_reqs = common_attn_metadata.num_reqs
+        num_actual_tokens = common_attn_metadata.num_actual_tokens
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
+                                                                       num_reqs
+                                                                       + 1]
+
+        block_table = common_attn_metadata.block_table_tensor
+        block_table[:num_reqs, :common_attn_metadata.
+                    max_num_blocks_per_req] = (block_table[:num_reqs])
+
+        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
         # TODO: Refactor these two param to common metadata in runners,
         # preparing for the hybrid KV groups feature
-        query_lens = common_attn_metadata.query_lens or self.runner.query_lens
+        query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
         # Since FIA for GQA is not active now, we temporarily silence it
         seq_lens_list = common_attn_metadata.seq_lens_list
 
-        slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
-        attn_mask = self.runner.attn_mask
-        attn_state = self.runner.attn_state
+        slot_mapping = common_attn_metadata.slot_mapping_cpu[:
+                                                             num_actual_tokens].to(
+                                                                 self.device,
+                                                                 non_blocking=
+                                                                 True)
+        attn_mask = common_attn_metadata.attn_mask
+        attn_state = common_attn_metadata.attn_state
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
+                                                                       num_reqs
+                                                                       + 1]
+        query_start_loc = query_start_loc_cpu.to(self.device,
+                                                 non_blocking=True)
 
         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
@@ -197,34 +206,49 @@ def build(self,
             query_lens=query_lens,
             seq_lens=seq_lens,
             seq_lens_list=seq_lens_list,
-            max_query_len=max_query_len,
+            max_query_len=common_attn_metadata.max_query_len,
             slot_mapping=slot_mapping,
             attn_mask=attn_mask,
             attn_state=attn_state,
-            enable_dbo_across_dp=enable_dbo_across_dp,
-            is_only_prefill=is_only_prefill)
+            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
+            is_only_prefill=common_attn_metadata.is_only_prefill)
         return attn_metadata
 
     def build_dummy_metadata(self, num_actual_tokens, num_reqs,
                              num_scheduled_tokens, attn_state):
         if attn_state == AscendAttentionState.DecodeOnly:
             # NOTE: We only need to pay attention to seq_lens_list and block_table here
-            common_attn_metadata = CommonAttentionMetadata(
-                seq_lens=torch.empty_like(self.runner.seq_lens_cpu).fill_(2))
-
-            block_table = self.runner.input_batch.block_table[0].block_table
             block_table[:num_reqs, 0] = torch.arange(1,
                                                      num_reqs + 1,
                                                      device=block_table.device,
                                                      dtype=block_table.dtype)
+            block_table = self.runner.input_batch.block_table[
+                0].get_device_tensor()
+            block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
+                block_table[:num_reqs])
 
-            attn_metadata = self.build(
-                num_reqs=num_reqs,
+            query_start_loc = None
+            seq_lens = torch.empty_like(self.runner.seq_lens_cpu).fill_(2)
+            query_lens = self.runner.query_lens
+            seq_lens_list = None
+
+            slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
+            attn_mask = self.runner.attn_mask
+
+            attn_metadata = AscendMetadata(
                 num_actual_tokens=num_actual_tokens,
+                block_tables=block_table,
+                query_start_loc=query_start_loc,
+                query_lens=query_lens,
+                seq_lens=seq_lens,
+                seq_lens_list=seq_lens_list,
                 max_query_len=num_scheduled_tokens.max(),
-                common_prefix_len=0,
-                common_attn_metadata=common_attn_metadata,
-            )
+                slot_mapping=slot_mapping,
+                attn_mask=attn_mask,
+                attn_state=attn_state,
+                enable_dbo_across_dp=False,
+                is_only_prefill=False)
         else:
             raise NotImplementedError(
                 "Currently we only support building dummy metadata for DecodeOnly state"
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 42dc11bb53..201250756a 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -8,7 +8,7 @@
                                               AttentionMetadata,
                                               MLAAttentionImpl)
 from vllm.attention.backends.utils import PAD_SLOT_ID
-from vllm.config import get_current_vllm_config
+from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.utils import cdiv, round_down
@@ -16,8 +16,8 @@
 from vllm_ascend import envs
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
-from vllm_ascend.attention.utils import \
-    AscendCommonAttentionMetadata as CommonAttentionMetadata
+from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
+                                         split_decodes_and_prefills)
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
@@ -173,20 +173,26 @@ class AscendMLAMetadataBuilder:
     # _attn_mask_builder = None
 
     def __init__(self,
+                 vllm_config: VllmConfig,
+                 device: torch.device,
                  runner,
                  metadata_cls: Optional[AscendMLAMetadata] = None):
         self.metadata_cls: Optional[AscendMLAMetadata] = metadata_cls \
             if metadata_cls is not None else AscendMLAMetadata  # type: ignore
         self.runner = runner
-        scheduler_config = runner.scheduler_config
-        model_config = runner.model_config
-        self.block_size = runner.block_size
-        self.chunked_prefill_enabled = runner.chunked_prefill_enabled
+        self.vllm_config = vllm_config
+        self.model_config = vllm_config.model_config
+        self.device = device
+        scheduler_config = vllm_config.scheduler_config
+        self.block_size = vllm_config.cache_config.block_size
+        self.max_blocks = (vllm_config.model_config.max_model_len +
+                           self.block_size - 1) // self.block_size
+        self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
         if self.chunked_prefill_enabled:
             self.chunked_prefill_workspace_size = min(
                 # Max sure there is enough for 8 full length request or at least
                 # 4 pages of cache per request
-                max(8 * model_config.max_model_len,
+                max(8 * self.model_config.max_model_len,
                     4 * scheduler_config.max_num_seqs * self.block_size),
                 # For long-context models try not to over-allocate limiting
                 # kv-cache space, limiting it to 64k tokens,
@@ -201,17 +207,20 @@ def __init__(self,
                 scheduler_config.max_num_seqs * self.block_size
             self.chunked_prefill_workspace = torch.empty(
                 (self.chunked_prefill_workspace_size,
-                 model_config.get_head_size()),
-                dtype=model_config.dtype,
-                device=runner.device,
+                 self.model_config.get_head_size()),
+                dtype=self.model_config.dtype,
+                device=device,
             )
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-        self.rope_dim = self.runner.model_config.hf_text_config.qk_rope_head_dim
+        self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
         self.cos_cache = None
         self.sin_cache = None
         self.prefill_attn_mask = torch.triu(
-            torch.ones(512, 512, device=runner.device, dtype=runner.dtype),
+            torch.ones(512,
+                       512,
+                       device=self.device,
+                       dtype=self.model_config.dtype),
             1)  # 512: mask only support 512
 
     def reorder_batch(self, input_batch: "InputBatch",
@@ -224,8 +233,6 @@ def reorder_batch(self, input_batch: "InputBatch",
         # better naming here)
         decodes = []
         prefills = []
-        num_decode_tokens = 0
-        num_prefill_tokens = 0
 
         for i, req_id in enumerate(input_batch.req_ids):
             num_tokens = scheduler_output.num_scheduled_tokens[req_id]
@@ -235,18 +242,14 @@ def reorder_batch(self, input_batch: "InputBatch",
             if self.torchair_graph_enabled:
                 if num_tokens - num_spec_tokens == 1:
                     decodes.append(i)
-                    num_decode_tokens += num_tokens
                 else:
                     prefills.append(i)
-                    num_prefill_tokens += num_tokens
             # For eager mode we treat spec decoding as chunked prefill.
             else:
                 if num_tokens == 1:
                     decodes.append(i)
-                    num_decode_tokens += num_tokens
                 else:
                     prefills.append(i)
-                    num_prefill_tokens += num_tokens
 
         # We hope that this is fairly minimal since decodes
         # should be around for a number of iterations so hopefully they are
@@ -274,50 +277,22 @@ def reorder_batch(self, input_batch: "InputBatch",
             else:
                 break
 
-        # Save for next `build` call
-        # TODO(lucas): this is a bit of a hack, we should probably have a
-        # better way of doing this
-        self._num_decodes = num_decodes
-        self._num_prefills = num_prefills
-        self._num_decode_tokens = num_decode_tokens
-        self._num_prefill_tokens = num_prefill_tokens
-
         return modified_batch
 
     def _get_graph_runner_block_tables(
             self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
-
-        max_batch_size, max_blocks = self.runner.graph_block_tables.shape
-        assert max_batch_size >= num_seqs, f"max_batch_size: {max_batch_size} should be bigger than cur_num_seqs: {num_seqs}"
-
-        if isinstance(self.runner.graph_block_tables, np.ndarray):
-            graph_block_tables = torch.zeros((max_batch_size, max_blocks),
-                                             dtype=block_tables.dtype,
-                                             device=block_tables.device)
-        else:
-            graph_block_tables = self.runner.graph_block_tables.to(
-                device=block_tables.device, dtype=block_tables.dtype)
-
         num_blocks = block_tables.size(1)
-        if num_blocks <= max_blocks:
-            graph_block_tables[:num_seqs, :
-                               num_blocks] = block_tables[:num_seqs, :
-                                                          num_blocks]
-        else:
-            graph_block_tables[:num_seqs, :
-                               max_blocks] = block_tables[:num_seqs, :
-                                                          max_blocks]
-
-        return graph_block_tables[:num_seqs, :max_blocks]
+        num_blocks = min(num_blocks, self.max_blocks)
+        return block_tables[:num_seqs, :num_blocks]
 
     def build_torchair_graph_dummy(
         self,
         num_reqs: int,
         num_actual_tokens: int,
     ) -> AscendMLAMetadata:
-        device = self.runner.device
-        _, max_blocks = self.runner.graph_block_tables.shape
-        block_table = torch.zeros((num_reqs, max_blocks),
+        device = self.device
+        # does block_table really need to shape of (num_reqs, self.max_blocks)
+        block_table = torch.zeros((num_reqs, self.max_blocks),
                                   dtype=torch.int32,
                                   device=device)
         block_table = self._get_graph_runner_block_tables(
@@ -336,8 +311,8 @@ def build_torchair_graph_dummy(
                                 -1,
                                 dtype=torch.int32,
                                 device=device)
-        if self.runner.speculative_config is not None and\
-            self.runner.speculative_config.method == 'deepseek_mtp':
+        if self.vllm_config.speculative_config is not None and\
+            self.vllm_config.speculative_config.method == 'deepseek_mtp':
             attn_state = AscendAttentionState.SpecDecoding
             num_decode_tokens = 2
         else:
@@ -347,13 +322,13 @@ def build_torchair_graph_dummy(
                          1,
                          1,
                          self.rope_dim,
-                         dtype=self.runner.dtype,
+                         dtype=self.model_config.dtype,
                          device=device)
         cos = torch.ones(num_tokens,
                          1,
                          1,
                          self.rope_dim,
-                         dtype=self.runner.dtype,
+                         dtype=self.model_config.dtype,
                          device=device)
         decode_metadata = AscendMLADecodeMetadata(
             input_positions=input_positions,
@@ -384,79 +359,110 @@ def build_torchair_graph_dummy(
 
     def build(
         self,
-        num_reqs: int,
-        num_actual_tokens: int,
-        max_query_len: int,
-        common_attn_metadata: CommonAttentionMetadata,
-        common_prefix_len: Optional[int] = None,
+        common_attn_metadata: AscendCommonAttentionMetadata,
         num_token_pad_size: int = -1,
         num_reqs_pad_size: int = 0,
         enable_dbo_across_dp: bool = False,
         *args,
         **kwargs,
     ) -> AscendMLAMetadata:
-        assert self._num_decodes + self._num_prefills == num_reqs
+
+        num_reqs = common_attn_metadata.num_reqs
+        num_tokens = common_attn_metadata.num_actual_tokens
+        max_query_len = common_attn_metadata.max_query_len
 
         # Note(simon): be careful about the CPU <> GPU memory movement in this
         # function. We should avoid GPU -> CPU sync as much as possible because
         # it blocks on all previous kernels.
         device = self.runner.device
-
-        block_table = (self.runner.input_batch.block_table[0].
-                       get_device_tensor()[:num_reqs])
-        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
+        block_table = common_attn_metadata.block_table_tensor[:num_reqs]
+        slot_mapping = common_attn_metadata.slot_mapping_cpu[:num_tokens].to(
             device, non_blocking=True)
-        input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
+
+        query_start_loc = common_attn_metadata.query_start_loc
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
+        seq_lens = common_attn_metadata.seq_lens_cpu
+
+        query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
+
+        num_computed_tokens_cpu = (common_attn_metadata.seq_lens_cpu -
+                                   query_seq_lens_cpu)
+
+        if self.runner.torchair_graph_enabled and self.runner.attn_state in [
+                AscendAttentionState.DecodeOnly,
+                AscendAttentionState.SpecDecoding
+        ]:
+            decode_threshold = self.runner.decode_token_per_req
+        else:
+            # TODO(xyx): remove the if condition after mla supports torch mode speculative decoding
+            decode_threshold = 1
+
+        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
+            split_decodes_and_prefills(common_attn_metadata, decode_threshold=decode_threshold)
+
+        assert num_decodes + num_prefills == num_reqs
+        assert num_decode_tokens + num_prefill_tokens == num_tokens
+
+        input_positions = self.runner.positions_cpu[:num_tokens].to(
             device, non_blocking=True).long()
+        input_positions = common_attn_metadata.positions[:num_tokens].long()
 
-        seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
-        query_lens = seq_lens_cpu - self.runner.input_batch.num_computed_tokens_cpu_tensor[:
-                                                                                            num_reqs]
-        seq_lens = seq_lens_cpu
-        max_query_len = query_lens.max().item()
         max_seq_lens = seq_lens.max().item()
-        query_start_loc = common_attn_metadata.query_start_loc
 
         if self.cos_cache is None:
             self.cos_cache = self.runner.get_model(
             ).model.layers[0].self_attn.rotary_emb.cos_cached
             self.sin_cache = self.runner.get_model(
             ).model.layers[0].self_attn.rotary_emb.sin_cached
-        if self.cos_cache.dtype != self.runner.dtype:  # type: ignore
+        if self.cos_cache.dtype != self.model_config.dtype:  # type: ignore
             self.cos_cache = self.cos_cache.to(  # type: ignore
-                self.runner.dtype)  # type: ignore
+                self.model_config.dtype)  # type: ignore
             self.sin_cache = self.sin_cache.to(  # type: ignore
-                self.runner.dtype)  # type: ignore
+                self.model_config.dtype)  # type: ignore
 
         prefill_metadata = None
-        chunked_context_metadata = None
-        if self._num_prefills > 0:
-            reqs_start = self._num_decodes  # prefill_start
-            tokens_start = self._num_decode_tokens
-            max_query_len = query_lens[tokens_start:].max().item()
-            max_seq_lens = seq_lens[tokens_start:].max().item()
-            query_start_loc = common_attn_metadata.query_start_loc
-            prefill_query_start_loc = query_start_loc[
-                reqs_start:] - query_start_loc[reqs_start]
+        if num_prefills > 0:
+            reqs_start = num_decodes  # prefill_start
 
-            context_lens_cpu = self.runner.input_batch.num_computed_tokens_cpu_tensor[
-                reqs_start:num_reqs]
+            context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
             max_context_len_cpu = context_lens_cpu.max().item()
             num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
+            prefill_query_start_loc = query_start_loc[
+                reqs_start:] - query_start_loc[reqs_start]
+
+            tokens_start = num_decode_tokens
+            chunked_context_metadata = None
+
             if self.chunked_prefill_enabled and max_context_len_cpu > 0:
+                # currently we allocate an equal amount of workspace for each
+                # prefill in the batch, we could probably use a more advanced
+                # algorithm here and allocate more workspace to prefills with
+                # longer context lengths
                 max_context_chunk = (self.chunked_prefill_workspace_size //
                                      num_prefills_with_context_cpu)
+                # align max_context_chunk to block_size by rounding down,
+                # currently the `gather_cache` kernel cannot handle
+                # `context_chunk_starts` that are not aligned to block_size
                 max_context_chunk = round_down(max_context_chunk,
                                                self.block_size)
                 assert max_context_chunk > 0
                 num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
-                chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \
-                    .unsqueeze(1).expand(-1, self._num_prefills) * max_context_chunk
+                # if `max_context_chunk = 256`, `num_chunks = 3`, and
+                # `num_prefills_with_context = 4`, create a tensor that looks
+                # like
+                # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
+                # Note(simon): this is done in CPU because of downstream's
+                # of `to_list`.
+                chunk_starts = \
+                    torch.arange(num_chunks, dtype=torch.int32) \
+                    .unsqueeze(1).expand(-1, num_prefills) \
+                    * max_context_chunk
                 chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
                                        chunk_starts + max_context_chunk)
                 chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
+
                 cu_seq_lens_cpu = torch.zeros(num_chunks,
-                                              self._num_prefills + 1,
+                                              num_prefills + 1,
                                               dtype=torch.int32,
                                               pin_memory=True)
                 torch.cumsum(chunk_seq_lens,
@@ -473,6 +479,8 @@ def build(
                     chunk_seq_lens_npu=chunk_seq_lens.npu(),
                     workspace=self.chunked_prefill_workspace,
                 )
+            assert max(chunked_context_metadata.max_seq_lens) <= \
+                self.chunked_prefill_workspace_size
 
             prefill_input_positions = input_positions[tokens_start:]
             cos = self.cos_cache[
                 prefill_input_positions].unsqueeze(  # type: ignore
@@ -482,8 +490,8 @@ def build(
                     1).unsqueeze(2)
             prefill_metadata = AscendMLAPrefillMetadata(
                 attn_mask=self.prefill_attn_mask,
-                query_lens=query_lens[tokens_start:],
-                seq_lens=seq_lens,
+                query_lens=query_seq_lens_cpu[tokens_start:],
+                seq_lens=seq_lens[reqs_start:],
                 context_lens=seq_lens[tokens_start:],
                 input_positions=prefill_input_positions,
                 block_table=block_table[reqs_start:, ...],
@@ -497,12 +505,12 @@ def build(
 
         decode_metadata = None
         use_torchair_graph = num_token_pad_size != -1
-        if self._num_decodes > 0:
-            actual_seq_lengths_q = query_start_loc[1:].tolist()
-            max_seq_lens = seq_lens[:self._num_decodes].max().item()
-            seq_lens = seq_lens[:self._num_decode_tokens]
-            input_positions = input_positions[:self._num_decode_tokens]
-            block_table = block_table[:self._num_decode_tokens, ...]
+        if num_decodes > 0:
+            actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()
+            max_seq_lens = seq_lens[:num_decodes].max().item()
+            seq_lens = seq_lens[:num_decodes]
+            input_positions = input_positions[:num_decode_tokens]
+            block_table = block_table[:num_decodes, ...]
             if use_torchair_graph and self.runner.attn_state in [
                     AscendAttentionState.DecodeOnly,
                     AscendAttentionState.SpecDecoding
@@ -517,11 +525,11 @@ def build(
                 seq_lens = torch.from_numpy(
                     np.array(padded_seq_lens).astype(np.int32))
                 seq_lens_list = padded_seq_lens
-                padding = torch.full((num_token_pad_size, ),
-                                     PAD_SLOT_ID,
-                                     dtype=slot_mapping.dtype,
-                                     device=slot_mapping.device)
-                slot_mapping = torch.cat([slot_mapping, padding])
+                slot_padding = torch.full((num_token_pad_size, ),
+                                          PAD_SLOT_ID,
+                                          dtype=slot_mapping.dtype,
+                                          device=slot_mapping.device)
+                slot_mapping = torch.cat([slot_mapping, slot_padding])
                 block_table_padding = torch.zeros(
                     (num_reqs_pad_size, ) + block_table.shape[1:],
                     dtype=block_table.dtype,
@@ -530,17 +538,19 @@ def build(
                                         dim=0)
                 block_table = self._get_graph_runner_block_tables(
                     num_reqs + num_reqs_pad_size, block_table)
-                padding_0 = torch.zeros(num_token_pad_size,
-                                        dtype=input_positions.dtype,
-                                        device=input_positions.device)
-                input_positions = torch.cat([input_positions, padding_0])
+                position_padding = torch.zeros(num_token_pad_size,
+                                               dtype=input_positions.dtype,
+                                               device=input_positions.device)
+                input_positions = torch.cat(
+                    [input_positions, position_padding])
                 actual_seq_lengths_q = query_start_loc[1:].tolist(
                 ) + self.runner.actual_seq_lengths_q[num_reqs:num_reqs +
                                                      num_reqs_pad_size]
             else:
                 seq_lens_list = seq_lens.tolist()
                 # mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to batch_size(num_tokens)
-                batch_size = slot_mapping.size(0)
+                num_token_pad_size = max(0, num_token_pad_size)
+                batch_size = num_decode_tokens + num_token_pad_size
                 if actual_seq_lengths_q[-1] != batch_size \
                     and self.runner.attn_state == AscendAttentionState.SpecDecoding:
                     actual_seq_lengths_q[-1] = batch_size
@@ -562,13 +572,13 @@ def build(
                 cos=cos)
 
         return self.metadata_cls(  # type: ignore
-            num_actual_tokens=num_actual_tokens,
-            query_lens=query_lens.tolist(),
+            num_actual_tokens=num_tokens,
+            query_lens=query_seq_lens_cpu.tolist(),
             slot_mapping=slot_mapping,
             head_dim=self.runner.model_config.get_head_size(),
-            num_decodes=self._num_decodes,
-            num_decode_tokens=self._num_decode_tokens,
-            num_prefills=self._num_prefills,
+            num_decodes=num_decodes,
+            num_decode_tokens=num_decode_tokens,
+            num_prefills=num_prefills,
             attn_mask=self.runner.attn_mask,
             attn_state=self.runner.attn_state,
             prefill=prefill_metadata,
@@ -576,7 +586,7 @@ def build(
             query_start_loc=query_start_loc,
             block_tables=block_table,
             seq_lens=seq_lens,
-            enable_dbo_across_dp=enable_dbo_across_dp,
+            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
         )
 
 
@@ -945,8 +955,11 @@ def _forward_decode(
                                            self.qk_rope_head_dim)
             input_layout = "BNSD"
 
-            if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
-                assert num_tokens % self.spec_token_num == 0
+            if attn_metadata.attn_state in [
+                    AscendAttentionState.SpecDecoding,
+                    AscendAttentionState.ChunkedPrefill
+            ]:
+                assert num_tokens % (1 + self.spec_token_num) == 0
                 input_layout = "TND"
                 # [bs * q_seq_len, num_heads_per_rank, dim]
                 q_nope = q_nope.view(num_tokens, self.num_heads, -1)
diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py
index c2b7bc156a..aa675ca686 100644
--- a/vllm_ascend/attention/utils.py
+++ b/vllm_ascend/attention/utils.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Optional
+from typing import Any, Optional
 
 import torch
 
@@ -7,17 +7,86 @@
 @dataclass
 class AscendCommonAttentionMetadata:
     """
-    Attention metadata attributes that can be shared by layers in different KV
-    cache groups and thus having different block table.
+    Per-batch attention metadata, shared across layers and backends.
+    AttentionMetadataBuilder instances use it to construct per-layer metadata.
+
+    For many of the tensors we keep both GPU and CPU versions.
     """
-
-    query_start_loc: torch.Tensor = None
+    num_reqs: int
+    """Number of requests"""
+    num_actual_tokens: int
+    """Total number of tokens in batch"""
+    max_query_len: int
+    """Longest query in batch"""
+    decode_token_per_req: int
+    max_num_blocks_per_req: int
+    attn_state: Any
+    query_start_loc: torch.Tensor
+    query_start_loc_cpu: torch.Tensor
     """(batch_size + 1,), the start location of each request in query Tensor"""
-    seq_lens: Optional[torch.Tensor] = None
+
+    seq_lens: torch.Tensor = None
+    seq_lens_cpu: torch.Tensor = None
     """(batch_size,), the length of each request including both computed tokens
     and newly scheduled tokens"""
+
+    actual_seq_lengths_q: Optional[list[int]] = None
+
+    block_table_tensor: torch.Tensor = None
+    slot_mapping_cpu: torch.Tensor = None
+
+    positions: torch.Tensor = None
+
+    attn_mask: torch.Tensor = None
+    spec_attn_mask: torch.Tensor = None
+
+    enable_dbo_across_dp: bool = False
+    graph_pad_size: int = -1
     query_lens: Optional[torch.Tensor] = None
+    is_only_prefill: bool = False
     """(batch_size,), the length of each request including only the newly
     scheduled tokens"""
     seq_lens_list: Optional[list] = None
     """(num_input_tokens,), note that this is specifically for FIA kernel"""
+
+
+def split_decodes_and_prefills(
+    common_attn_metadata: AscendCommonAttentionMetadata,
+    decode_threshold: int = 1,
+) -> tuple[int, int, int, int]:
+    """
+    Assuming a reordered batch, finds the boundary between prefill and decode
+    requests.
+
+    Args:
+        common_attn_metadata: AscendCommonAttentionMetadata object containing the
+            batch metadata.
+        decode_threshold: The maximum query length to be considered a decode.
+
+    Returns:
+        num_decodes: The number of decode requests.
+        num_prefills: The number of prefill requests.
+        num_decode_tokens: The number of tokens in the decode requests.
+        num_prefill_tokens: The number of tokens in the prefill requests.
+    """
+    max_query_len = common_attn_metadata.max_query_len
+    num_reqs = common_attn_metadata.num_reqs
+    num_tokens = common_attn_metadata.num_actual_tokens
+    query_start_loc = common_attn_metadata.query_start_loc_cpu
+
+    if max_query_len <= decode_threshold:
+        return num_reqs, 0, num_tokens, 0
+
+    query_lens = query_start_loc[1:] - query_start_loc[:-1]
+    is_prefill = query_lens > decode_threshold
+    if not torch.any(is_prefill):
+        return num_reqs, 0, num_tokens, 0
+
+    first_prefill = is_prefill.int().argmax(dim=-1).item()
+    assert torch.all(query_lens[first_prefill:] >= decode_threshold)
+    assert torch.all(query_lens[:first_prefill] <= decode_threshold)
+    num_decodes = first_prefill
+    num_prefills = num_reqs - num_decodes
+    num_decode_tokens = query_start_loc[first_prefill].item()
+    num_prefill_tokens = num_tokens - num_decode_tokens
+    return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index d7301fd280..4930d80a81 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -77,8 +77,7 @@
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.attention.attention import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
-from vllm_ascend.attention.utils import \
-    AscendCommonAttentionMetadata as CommonAttentionMetadata
+from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.distributed.utils import is_lmhead_tp
 from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
 from vllm_ascend.eplb.eplb_updator import EplbUpdator
@@ -195,7 +194,9 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                 "Non-Attention backend is not supported by V1 NPUModelRunner.")
 
         self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
-            weakref.proxy(self))
+            vllm_config=self.vllm_config,
+            device=self.device,
+            runner=weakref.proxy(self))
 
         # Multi-modal data support
         self.input_registry = INPUT_REGISTRY
@@ -978,9 +979,6 @@ def _process_reqs(
         # We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
         elif np.all(num_scheduled_tokens == 1):
             attn_state = AscendAttentionState.DecodeOnly
-            if self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
-                # SpecDecoding now supports seq_len=1 and seq_len=2
-                attn_state = AscendAttentionState.SpecDecoding
         # Speculative decoding.
         elif np.all(num_valid_tokens == 1):
             attn_state = AscendAttentionState.SpecDecoding
@@ -1015,19 +1013,13 @@ def _process_reqs(
         self.seq_lens[num_reqs:].fill_(0)
         self.query_start_loc[num_reqs + 1:].fill_(-1)
 
-        query_start_loc = self.query_start_loc[:num_reqs + 1]
         # Use host tensor, other wise error: tensor.hostData is null
-        common_attn_metadata = CommonAttentionMetadata(
-            query_start_loc=query_start_loc,
-            seq_lens=self.seq_lens_cpu[:num_reqs])
-        self.common_attn_metadata = common_attn_metadata
         self.seq_lens_list = self.seq_lens_np.tolist()[:num_input_tokens]
         with_prefill = attn_state not in [
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
 
         is_only_prefill = bool(np.all(num_valid_tokens != 1))
-        extra_builder_kwargs['is_only_prefill'] = is_only_prefill
 
         enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
                                               attn_state,
@@ -1041,7 +1033,29 @@ def _process_reqs(
          enable_dbo) = self._get_forward_metadata_across_dp(
              maybe_padded_num_tokens, total_num_scheduled_tokens, with_prefill,
              enable_dbo)
-        extra_builder_kwargs['enable_dbo_across_dp'] = enable_dbo
+
+        common_attn_metadata = AscendCommonAttentionMetadata(
+            query_start_loc=self.query_start_loc[:num_reqs + 1],
+            query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
+            seq_lens=self.seq_lens[:num_reqs],
+            seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
+            num_reqs=num_reqs,
+            num_actual_tokens=total_num_scheduled_tokens,
+            max_query_len=max_num_scheduled_tokens,
+            actual_seq_lengths_q=self.actual_seq_lengths_q,
+            block_table_tensor=self.input_batch.block_table[0].
+            get_device_tensor(),
+            slot_mapping_cpu=self.
+            slot_mapping_cpu[:total_num_scheduled_tokens],
+            positions=self.positions[:num_input_tokens],
+            attn_mask=self.attn_mask,
+            spec_attn_mask=self.spec_attn_mask,
+            attn_state=self.attn_state,  # type: ignore
+            decode_token_per_req=self.decode_token_per_req,
+            max_num_blocks_per_req=self.max_num_blocks_per_req,
+            enable_dbo_across_dp=enable_dbo,
+            is_only_prefill=is_only_prefill,
+        )
 
         # TODO(zzzzwwjj): this code need to refactor afterwards.
         self.with_prefill = with_prefill
@@ -1060,25 +1074,11 @@ def _process_reqs(
         self.extra_builder_kwargs = extra_builder_kwargs
         self.num_tokens_across_dp = num_tokens_across_dp
 
-        if self.vllm_config.model_config.use_mla:
-            attn_metadata = self.attn_metadata_builder.build(  # type: ignore
-                num_reqs=num_reqs,
-                num_actual_tokens=total_num_scheduled_tokens,
-                max_query_len=max_num_scheduled_tokens,
-                common_attn_metadata=common_attn_metadata,
-                common_prefix_len=None,
-                **extra_builder_kwargs,
-            )
-        else:
-            attn_metadata = self.attn_metadata_builder.build(  # type: ignore
-                num_reqs=num_reqs,
-                num_actual_tokens=total_num_scheduled_tokens,
-                max_query_len=max_num_scheduled_tokens,
-                common_attn_metadata=common_attn_metadata,
-                common_prefix_len=None,
-                **extra_builder_kwargs,
-            )
-        attn_metadata.num_input_tokens = num_input_tokens
+        attn_metadata = self.attn_metadata_builder.build(  # type: ignore
+            common_attn_metadata=common_attn_metadata,
+            **extra_builder_kwargs,
+        )
+        attn_metadata.num_input_tokens = padded_num_tokens_across_dp
 
         # Prepare input_ids
         token_indices = (positions_np +
@@ -2262,7 +2262,6 @@ def _generate_mtp_token_ids(
             cu_num_tokens, token_indices = self.drafter.prepare_inputs(
                 attn_metadata.query_start_loc,
                 num_rejected_tokens,
-                force_one_token=False,
                 is_torchair_graph=self.torchair_graph_enabled)
             if self.torchair_graph_enabled:
                 # the seq len of each bath is padded to 2, thus input is same as the main model
diff --git a/vllm_ascend/worker/mtp_proposer_v1.py b/vllm_ascend/worker/mtp_proposer_v1.py
index aa9612cc88..81691e09de 100644
--- a/vllm_ascend/worker/mtp_proposer_v1.py
+++ b/vllm_ascend/worker/mtp_proposer_v1.py
@@ -16,8 +16,7 @@
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
-from vllm_ascend.attention.utils import \
-    AscendCommonAttentionMetadata as CommonAttentionMetadata
+from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.distributed.utils import is_lmhead_tp
 from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
 from vllm_ascend.utils import ProfileExecuteDuration
@@ -92,7 +91,6 @@ def prepare_inputs(
         cu_target_query_lens: torch.Tensor,
         # [batch_size]
         num_rejected_tokens: torch.Tensor,
-        force_one_token: bool = False,
        is_torchair_graph: bool = False
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # cu_target_query_lens: [0, a, a + b, a + b + c]
@@ -111,14 +109,6 @@ def prepare_inputs(
             cu_num_tokens = cu_target_query_lens
             relative_index = query_len_per_req - num_rejected_tokens - 1
             token_indices = cu_num_tokens[:-1] + relative_index
-        elif force_one_token:
-            # enable force_one_token means we only focus on the last token position of each request
-            # token_indices: [batch_size]
-            cu_num_tokens = torch.arange(cu_target_query_lens.size(0),
-                                         device=cu_target_query_lens.device,
-                                         dtype=torch.int32)
-            relative_index = query_len_per_req - num_rejected_tokens - 1
-            token_indices = cu_target_query_lens[:-1] + relative_index
         else:
             cu_num_tokens = torch.empty_like(cu_target_query_lens)
             torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
@@ -126,7 +116,7 @@ def prepare_inputs(
 
         # FIXME(woosuk): Avoid synchronization.
         num_tokens = cu_num_tokens[-1].item()
-        token_indices = torch.empty(
+        token_indices = torch.zeros(
            num_tokens,
            dtype=torch.int32,
            device=cu_num_tokens.device,
@@ -170,11 +160,6 @@ def propose(
         # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
         if token_indices is not None and self.runner.torchair_graph_enabled:
             last_token_indices = token_indices
-            common_attn_metadata = self.runner.common_attn_metadata
-        else:
-            seq_lens = (target_positions[last_token_indices] + 1)
-            common_attn_metadata = CommonAttentionMetadata(
-                query_start_loc=cu_num_tokens, seq_lens=seq_lens)
 
         self.input_ids[last_token_indices] = next_token_ids
 
@@ -207,13 +192,29 @@ def propose(
             extra_builder_kwargs['num_reqs_pad_size'] = 0
             num_input_tokens = num_tokens
 
-        attn_metadata = self.runner.attn_metadata_builder.build(
+        seq_lens = target_positions[last_token_indices] + 1
+        seq_lens = seq_lens.int()
+        common_attn_metadata = AscendCommonAttentionMetadata(
+            query_start_loc=cu_num_tokens[:batch_size + 1],
+            query_start_loc_cpu=cu_num_tokens[:batch_size + 1].cpu(),
+            seq_lens=seq_lens,
+            seq_lens_cpu=seq_lens.cpu(),
             num_reqs=batch_size,
             num_actual_tokens=num_tokens,
             max_query_len=max_query_len,
-            common_prefix_len=0,
-            common_attn_metadata=common_attn_metadata,
-            **extra_builder_kwargs)
+            actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
+            block_table_tensor=self.runner.input_batch.block_table[0].
+            get_device_tensor(),
+            slot_mapping_cpu=target_slot_mapping,
+            positions=target_positions,
+            attn_mask=self.runner.attn_mask,
+            spec_attn_mask=self.runner.spec_attn_mask,
+            attn_state=self.runner.attn_state,
+            decode_token_per_req=self.runner.decode_token_per_req,
+            max_num_blocks_per_req=self.runner.max_num_blocks_per_req,
+        )
+        attn_metadata = self.runner.attn_metadata_builder.build(
+            common_attn_metadata, **extra_builder_kwargs)
 
         self.positions[:num_tokens] = target_positions
         self.hidden_states[:num_tokens] = target_hidden_states
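
Below is a minimal, illustrative sketch (not part of the patch) of the decode/prefill boundary computation that the new `split_decodes_and_prefills` helper performs on a reordered batch. It operates on a bare CPU `query_start_loc` tensor rather than the full `AscendCommonAttentionMetadata` object, and the helper name `split_from_query_start_loc` is hypothetical.

import torch


def split_from_query_start_loc(query_start_loc_cpu: torch.Tensor,
                               decode_threshold: int = 1):
    """Illustrative stand-in for split_decodes_and_prefills.

    Assumes the batch has already been reordered so that every decode
    request (query length <= decode_threshold) precedes every prefill.
    """
    num_reqs = query_start_loc_cpu.numel() - 1
    num_tokens = int(query_start_loc_cpu[-1])
    # Per-request query lengths from the cumulative offsets.
    query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
    is_prefill = query_lens > decode_threshold
    if not torch.any(is_prefill):
        # Decode-only batch: no prefill boundary to find.
        return num_reqs, 0, num_tokens, 0
    # Index of the first prefill request is the decode/prefill boundary.
    first_prefill = int(is_prefill.int().argmax())
    num_decodes = first_prefill
    num_prefills = num_reqs - num_decodes
    num_decode_tokens = int(query_start_loc_cpu[first_prefill])
    num_prefill_tokens = num_tokens - num_decode_tokens
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens


# Example: two decode requests (1 token each) followed by two prefills
# (5 and 3 tokens). query_start_loc = [0, 1, 2, 7, 10] -> (2, 2, 2, 8).
print(split_from_query_start_loc(torch.tensor([0, 1, 2, 7, 10])))

The design relies on `reorder_batch` having already moved decode requests to the front of the batch, so a single boundary index is enough to partition both the requests and their tokens.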