diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 594649c6d4..8d3b5566ec 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1003,19 +1003,19 @@ def _gather_mm_embeddings(
             mm_embeds.append(mm_embeds_item)
         return mm_embeds
 
-    def _process_reqs(
-        self,
-        scheduler_output: "SchedulerOutput",
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> tuple[Union[AscendMetadata, AscendMLAMetadata,
-               AscendTorchairMetadata], torch.Tensor, SpecDecodeMetadata,
-               torch.Tensor, int, torch.Tensor, torch.Tensor, np.ndarray,
-               Optional[set[str]], Optional[set[str]]]:
-        # Check input valid
+    def _prepare_inputs(
+        self,
+        scheduler_output: "SchedulerOutput",
+        intermediate_tensors: Optional[IntermediateTensors] = None):
+        # get total_num_scheduled_tokens
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
+
+        # get num_reqs
         num_reqs = self.input_batch.num_reqs
         assert num_reqs > 0
+
+        # get num_input_tokens
         if (self.use_aclgraph and
                 total_num_scheduled_tokens <= self.aclgraph_batch_sizes[-1]):
             # Add padding to the batch size.
@@ -1025,9 +1025,8 @@ def _process_reqs(
             # Eager mode.
             num_input_tokens = total_num_scheduled_tokens
 
-        modified_batch = self.attn_metadata_builder.reorder_batch(
-            self.input_batch, scheduler_output)
-        if modified_batch:
+        if self.attn_metadata_builder.reorder_batch(self.input_batch,
+                                                    scheduler_output):
             self.input_batch.refresh_sampling_metadata()
 
         # OPTIMIZATION: Start copying the block table first.
@@ -1057,9 +1056,6 @@ def _process_reqs(
         cu_num_tokens = np.cumsum(num_scheduled_tokens)
         cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
                                     num_scheduled_tokens)
-        logits_indices = cu_num_tokens - 1
-        logits_indices = torch.from_numpy(logits_indices).to(self.device,
-                                                             non_blocking=True)
         arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets
 
         positions_np = self.positions_np[:total_num_scheduled_tokens]
@@ -1067,27 +1063,13 @@ def _process_reqs(
                arange,
                out=positions_np)
 
-        # Calculate M-RoPE positions.
-        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
-        if self.uses_mrope:
-            self._calc_mrope_positions(scheduler_output)
-
-        if self.uses_mrope:
-            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
-            self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
-                self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
-                non_blocking=True)
-
         self.positions[total_num_scheduled_tokens:num_input_tokens].zero_()
         self.positions[:total_num_scheduled_tokens].copy_(
             self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
-        positions = self.positions[:num_input_tokens]
         self.query_lens = torch.from_numpy(num_scheduled_tokens)
-
         self.seq_lens_np[:num_reqs] = (
             self.input_batch.num_computed_tokens_cpu[:num_reqs] +
             num_scheduled_tokens)
-        seq_lens = self.seq_lens_cpu[:num_reqs]
 
         block_table_indices = (req_indices * self.max_num_blocks_per_req +
                                positions_np // self.block_size)
@@ -1100,8 +1082,6 @@ def _process_reqs(
                out=self.slot_mapping_np[:total_num_scheduled_tokens])
 
         ascend_config = get_ascend_config()
-        use_spec_decode = len(
-            scheduler_output.scheduled_spec_decode_tokens) > 0
         if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
             attn_state = AscendAttentionState.PrefillNoCache
         # We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
@@ -1124,14 +1104,12 @@ def _process_reqs(
             attn_state = AscendAttentionState.PrefillCacheHit
 
         self.attn_mask = self._make_attention_mask(
-            seq_lens=seq_lens,
+            seq_lens=self.seq_lens_cpu[:num_reqs],
             query_lens=num_scheduled_tokens,
-            position=positions,
+            position=self.positions[:num_input_tokens],
             attn_state=attn_state)
         self.attn_state = attn_state  # type: ignore
 
-        extra_builder_kwargs = {}
-
         self.query_start_loc_np[0] = 0
         self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
         self.query_start_loc[:num_reqs + 1].copy_(
@@ -1147,20 +1125,17 @@ def _process_reqs(
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
 
-        is_only_prefill = bool(np.all(num_valid_tokens != 1))
-        extra_builder_kwargs['is_only_prefill'] = is_only_prefill
-
-        enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
-                                              attn_state,
-                                              total_num_scheduled_tokens)
-
         enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(), attn_state, total_num_scheduled_tokens)
 
         (padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill,
          enable_dbo) = self._get_forward_metadata_across_dp_and_pad(
              total_num_scheduled_tokens, with_prefill, enable_dbo)
+
+        extra_builder_kwargs = {}
         extra_builder_kwargs['enable_dbo_across_dp'] = enable_dbo
+        extra_builder_kwargs['is_only_prefill'] = bool(
+            np.all(num_valid_tokens != 1))
 
         self.with_prefill = with_prefill
         self.num_tokens_across_dp = num_tokens_across_dp
         if self.torchair_graph_enabled and not with_prefill:
@@ -1205,10 +1180,6 @@ def _process_reqs(
             # Run the multimodal encoder if any.
             self._execute_mm_encoder(scheduler_output)
             mm_embeds = self._gather_mm_embeddings(scheduler_output)
-        else:
-            mm_embeds = []
-
-        if self.is_multimodal_model:
             # NOTE(woosuk): To unify token ids and soft tokens (vision
             # embeddings), we always use embeddings (rather than token ids)
             # as input to the multimodal model, even when the input is text.
@@ -1228,15 +1199,22 @@ def _process_reqs(
            # While it is possible to use embeddings as input just like the
            # multimodal models, it is not desirable for performance since
            # then the embedding layer is not included in the ACL graph.
-            input_ids = self.input_ids[:num_input_tokens]
             inputs_embeds = None
+            if self.torchair_graph_enabled and not with_prefill:
+                input_ids = self.input_ids[:padded_num_tokens_across_dp]
+                positions = self.positions[:padded_num_tokens_across_dp]
+            else:
+                input_ids = self.input_ids[:num_input_tokens]
+
+        # Calculate M-RoPE positions.
+        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
         if self.uses_mrope:
+            self._calc_mrope_positions(scheduler_output)
+            self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
+                self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
+                non_blocking=True)
             positions = self.mrope_positions[:, :num_input_tokens]
 
-        if self.torchair_graph_enabled and not with_prefill:
-            input_ids = self.input_ids[:padded_num_tokens_across_dp]
-            positions = self.positions[:padded_num_tokens_across_dp]
-
         if get_pp_group().is_first_rank:
             intermediate_tensors = None
         else:
@@ -1250,79 +1228,9 @@ def _process_reqs(
                 for k, v in self.intermediate_tensors.items()
             })
 
-        # Run forward pass
-        with set_ascend_forward_context(
-                attn_metadata,
-                self.vllm_config,
-                num_tokens=padded_num_tokens_across_dp,
-                num_tokens_across_dp=num_tokens_across_dp,
-                with_prefill=with_prefill,
-                num_actual_tokens=total_num_scheduled_tokens):
-            with ProfileExecuteDuration().capture_async("forward"):
-                self.maybe_setup_kv_connector(scheduler_output)
-                model_kwargs = {}
-                if self.torchair_graph_enabled:
-                    model_kwargs["kv_caches"] = self.kv_caches
-                    model_kwargs["attn_metadata"] = attn_metadata
-                if self.torchair_graph_enabled and not with_prefill:
-                    maybe_converting_weight_acl_format(self.model,
-                                                       ACL_FORMAT_FRACTAL_NZ)
-
-                    compiled_model = self._get_torchair_lazy_compiled_model(
-                        padded_num_tokens_across_dp)
-                    hidden_states = compiled_model(
-                        input_ids=input_ids,
-                        positions=positions,
-                        intermediate_tensors=intermediate_tensors,
-                        inputs_embeds=inputs_embeds,
-                        **model_kwargs,
-                    )
-                else:
-                    assert self.model is not None
-                    maybe_converting_weight_acl_format(self.model,
-                                                       ACL_FORMAT_FRACTAL_ND)
-
-                    hidden_states = self.model(
-                        input_ids=input_ids,
-                        positions=positions,
-                        intermediate_tensors=intermediate_tensors,
-                        inputs_embeds=inputs_embeds,
-                        **model_kwargs,
-                    )
-
-            self.maybe_wait_for_kv_save()
-        finished_sending, finished_recving = self.get_finished_kv_transfer(
-            scheduler_output)
-        use_spec_decode = len(
-            scheduler_output.scheduled_spec_decode_tokens) > 0
-        if not use_spec_decode:
-            # NOTE(woosuk): Due to chunked prefills, the batch may contain
-            # partial requests. While we should not sample any token
-            # from these partial requests, we do so for simplicity.
-            # We will ignore the sampled tokens from the partial requests.
-            # TODO: Support prompt logprobs.
-            spec_decode_metadata = None
-        else:
-            # Get the number of draft tokens for each request.
-            # Iterate over the dictionary rather than all requests since not all
-            # requests have draft tokens.
-            num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
-            for req_id, draft_token_ids in (
-                    scheduler_output.scheduled_spec_decode_tokens.items()):
-                req_idx = self.input_batch.req_id_to_index[req_id]
-                num_draft_tokens[req_idx] = len(draft_token_ids)
-
-            spec_decode_metadata = self._calc_spec_decode_metadata(
-                num_draft_tokens, cu_num_tokens)
-            logits_indices = spec_decode_metadata.logits_indices
-
-        aux_hidden_states = None
-        if self.use_aux_hidden_state_outputs:
-            hidden_states, aux_hidden_states = hidden_states
-
-        return (attn_metadata, hidden_states, spec_decode_metadata, positions,
-                total_num_scheduled_tokens, logits_indices, aux_hidden_states,
-                num_scheduled_tokens, finished_sending, finished_recving)
+        return (attn_metadata, positions, num_scheduled_tokens,
+                padded_num_tokens_across_dp, num_tokens_across_dp,
+                with_prefill, input_ids, inputs_embeds, cu_num_tokens)
 
     def _get_cumsum_and_arange(
         self,
@@ -1567,28 +1475,115 @@ def _pool(
             **extra_args,
         )
 
+    def _execute_model(self, attn_metadata, padded_num_tokens_across_dp,
+                       with_prefill, scheduler_output, input_ids, positions,
+                       intermediate_tensors, inputs_embeds):
+        # Run forward pass
+        self.maybe_setup_kv_connector(scheduler_output)
+        model_kwargs = {}
+        if self.torchair_graph_enabled:
+            model_kwargs["kv_caches"] = self.kv_caches
+            model_kwargs["attn_metadata"] = attn_metadata
+        if self.torchair_graph_enabled and not with_prefill:
+            maybe_converting_weight_acl_format(self.model,
+                                               ACL_FORMAT_FRACTAL_NZ)
+
+            compiled_model = self._get_torchair_lazy_compiled_model(
+                padded_num_tokens_across_dp)
+            hidden_states = compiled_model(
+                input_ids=input_ids,
+                positions=positions,
+                intermediate_tensors=intermediate_tensors,
+                inputs_embeds=inputs_embeds,
+                **model_kwargs,
+            )
+        else:
+            assert self.model is not None
+            maybe_converting_weight_acl_format(self.model,
+                                               ACL_FORMAT_FRACTAL_ND)
+
+            hidden_states = self.model(
+                input_ids=input_ids,
+                positions=positions,
+                intermediate_tensors=intermediate_tensors,
+                inputs_embeds=inputs_embeds,
+                **model_kwargs,
+            )
+        return hidden_states
+
     @torch.inference_mode()
     def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> Union[ModelRunnerOutput, torch.Tensor]:
-        with ProfileExecuteDuration().capture_async(
-                "prepare input and forward"):
-            self._update_states(scheduler_output)
-            if not scheduler_output.total_num_scheduled_tokens:
-                if not has_kv_transfer_group():
-                    logger.debug(
-                        "skip this step for we receive the data from remote disaggregate prefill node"
-                    )
-                    # Return empty ModelRunnerOuptut if there's no work to do.
-                    return EMPTY_MODEL_RUNNER_OUTPUT
-                return self.kv_connector_no_forward(scheduler_output)
-            (attn_metadata, hidden_states, spec_decode_metadata, positions,
-             num_scheduled_tokens, logits_indices, aux_hidden_states,
-             num_scheduled_tokens_np, finished_sending,
-             finished_recving) = (self._process_reqs(scheduler_output,
-                                                     intermediate_tensors))
+        # 1. Update input batch
+        self._update_states(scheduler_output)
+
+        # If nothing to do, return directly.
+        if not scheduler_output.total_num_scheduled_tokens:
+            if not has_kv_transfer_group():
+                logger.debug(
+                    "skip this step for we receive the data from remote disaggregate prefill node"
+                )
+                # Return empty ModelRunnerOuptut if there's no work to do.
+                return EMPTY_MODEL_RUNNER_OUTPUT
+            return self.kv_connector_no_forward(scheduler_output)
+
+        # 2. Prepare forward input
+        with ProfileExecuteDuration().capture_async("prepare input"):
+            (attn_metadata, positions, num_scheduled_tokens_np,
+             padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill,
+             input_ids, inputs_embeds,
+             cu_num_tokens) = (self._prepare_inputs(scheduler_output,
                                                    intermediate_tensors))
+
+        # 3. Run forward
+        with set_ascend_forward_context(
+                attn_metadata,
+                self.vllm_config,
+                num_tokens=padded_num_tokens_across_dp,
+                num_tokens_across_dp=num_tokens_across_dp,
+                with_prefill=with_prefill,
+                num_actual_tokens=scheduler_output.total_num_scheduled_tokens):
+            with ProfileExecuteDuration().capture_async("forward"):
+                hidden_states = self._execute_model(
+                    attn_metadata, padded_num_tokens_across_dp, with_prefill,
+                    scheduler_output, input_ids, positions,
+                    intermediate_tensors, inputs_embeds)
+
+            self.maybe_wait_for_kv_save()
+            finished_sending, finished_recving = self.get_finished_kv_transfer(
+                scheduler_output)
+
+        if len(scheduler_output.scheduled_spec_decode_tokens) > 0:
+            # Get the number of draft tokens for each request.
+            # Iterate over the dictionary rather than all requests since not all
+            # requests have draft tokens.
+            num_draft_tokens = np.zeros(self.input_batch.num_reqs,
+                                        dtype=np.int32)
+            for req_id, draft_token_ids in (
+                    scheduler_output.scheduled_spec_decode_tokens.items()):
+                req_idx = self.input_batch.req_id_to_index[req_id]
+                num_draft_tokens[req_idx] = len(draft_token_ids)
+            spec_decode_metadata = self._calc_spec_decode_metadata(
+                num_draft_tokens, cu_num_tokens)
+            logits_indices = spec_decode_metadata.logits_indices
+        else:
+            # NOTE(woosuk): Due to chunked prefills, the batch may contain
+            # partial requests. While we should not sample any token
+            # from these partial requests, we do so for simplicity.
+            # We will ignore the sampled tokens from the partial requests.
+            # TODO: Support prompt logprobs.
+            spec_decode_metadata = None
+            logits_indices = torch.from_numpy(cu_num_tokens - 1).to(
+                self.device, non_blocking=True)
+
+        if self.use_aux_hidden_state_outputs:
+            hidden_states, aux_hidden_states = hidden_states
+        else:
+            aux_hidden_states = None
+
         kv_connector_output = None
         if not vllm_version_is("0.10.0"):
             if finished_sending is not None and finished_recving is not None:
@@ -1623,10 +1618,11 @@ def execute_model(
             logits = None
         else:
             if self.input_batch.pooling_params:
-                return self._pool(hidden_states, num_scheduled_tokens,
-                                  num_scheduled_tokens_np,
-                                  finished_sending, finished_recving,
-                                  kv_connector_output)
+                return self._pool(
+                    hidden_states,
+                    scheduler_output.total_num_scheduled_tokens,
+                    num_scheduled_tokens_np, finished_sending,
+                    finished_recving, kv_connector_output)
             sample_hidden_states = hidden_states[logits_indices]
             logits = self.model.compute_logits(sample_hidden_states, None)
         if broadcast_pp_output:
@@ -1702,7 +1698,7 @@ def execute_model(
 
         # Compute prompt logprobs if needed.
         prompt_logprobs_dict = self._get_prompt_logprobs_dict(
-            hidden_states[:num_scheduled_tokens],
+            hidden_states[:scheduler_output.total_num_scheduled_tokens],
            scheduler_output,
        )
@@ -1751,7 +1747,7 @@ def execute_model(
                 scheduler_output,
                 spec_decode_metadata,
                 positions,
-                num_scheduled_tokens,
+                scheduler_output.total_num_scheduled_tokens,
                 hidden_states,
                 attn_metadata,
                 aux_hidden_states,
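Reviewer note (illustrative, not part of the patch): the hunks above split the old `_process_reqs` into `_prepare_inputs` (host-side input and attention-metadata preparation) and `_execute_model` (the forward pass), with `execute_model` orchestrating both and taking over the spec-decode bookkeeping and `logits_indices` computation. The toy below is a minimal, self-contained sketch of that orchestration pattern; only the method names are taken from the diff, everything else is invented stand-in code rather than the real vllm-ascend implementation.

```python
# Toy sketch of the control flow introduced by this refactor.
# Names mirror the diff; bodies are placeholders, not real code.
class ToyModelRunner:

    def _prepare_inputs(self, scheduler_output):
        # Real code builds attn_metadata, positions, input_ids, DP padding,
        # cu_num_tokens, etc.; here we just wrap the scheduled tokens.
        return {"input_ids": list(scheduler_output)}

    def _execute_model(self, inputs):
        # Real code runs the (possibly torchair-compiled) model on NPU;
        # here we simply echo transformed values.
        return [tok * 2 for tok in inputs["input_ids"]]

    def execute_model(self, scheduler_output):
        # 1. Nothing scheduled -> return early (mirrors EMPTY_MODEL_RUNNER_OUTPUT).
        if not scheduler_output:
            return None
        # 2. Prepare forward inputs (front half of the old _process_reqs).
        inputs = self._prepare_inputs(scheduler_output)
        # 3. Run the forward pass (back half of the old _process_reqs).
        hidden_states = self._execute_model(inputs)
        # 4. Post-processing (spec-decode metadata, logits_indices, sampling)
        #    now happens here, after the forward pass.
        return hidden_states


if __name__ == "__main__":
    print(ToyModelRunner().execute_model([1, 2, 3]))  # -> [2, 4, 6]
```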