From e11bbd17c266c4fd38fac3dc70578622ac173a9b Mon Sep 17 00:00:00 2001 From: Jan Kaniecki Date: Tue, 5 Aug 2025 10:51:52 +0300 Subject: [PATCH 01/15] Lookahead decoding initial commit --- vllm_gaudi/v1/worker/hpu_model_runner.py | 222 +++++++++++++++++++---- 1 file changed, 191 insertions(+), 31 deletions(-) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index 60602ea3..dd15b270 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -653,6 +653,12 @@ def __init__( self.profiler = HabanaHighLevelProfiler() self.profiler_counter_helper = HabanaProfilerCounterHelper() + # Lookahead decoding + self.use_lookahead_decoding = True + # Storage for lookahead tokens that are computed but not yet scheduled + self.lookahead_tokens: dict[str, list[int]] = {} + + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ Generates the KVCacheSpec by parsing the kv cache format from each @@ -705,6 +711,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: # Remove finished requests from the cached states. for req_id in scheduler_output.finished_req_ids: self.requests.pop(req_id, None) + # Pop stored lookahead tokens for finished requests - dummy tokens at the end + self.lookahead_tokens.pop(req_id, None) # Remove the finished requests from the persistent batch. # NOTE(woosuk): There could be an edge case where finished_req_ids and # scheduled_req_ids overlap. This happens when a request is aborted and @@ -787,7 +795,17 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: elif num_new_tokens > 0: req_state.output_token_ids.extend( new_token_ids[-num_new_tokens:]) - + + # Check if we already have lookahead tokens for this request + if (self.use_lookahead_decoding and + req_id in self.lookahead_tokens and + len(self.lookahead_tokens[req_id]) > 0): + # Use the first available lookahead token + lookahead_token = self.lookahead_tokens[req_id].pop(0) + req_state.output_token_ids.append(lookahead_token) + # Clean up empty lookahead token lists + if len(self.lookahead_tokens[req_id]) == 0: + del self.lookahead_tokens[req_id] # Update the block IDs. if not resumed_from_preemption: for block_ids, new_ids in zip(req_state.block_ids, @@ -899,6 +917,7 @@ def _get_prompts_and_decodes( # Traverse prompts prompt_req_ids = [] prompt_scheduled_tokens = [] + lookahead_decode_req_ids = [] for i in range(len(decode_req_ids), num_reqs): req_id = self.input_batch.req_ids[i] assert req_id is not None @@ -916,8 +935,14 @@ def _get_prompts_and_decodes( prompt_req_ids.append(req_id) prompt_scheduled_tokens.append(num_scheduled_tokens) - - return PromptDecodeInfo(prompt_req_ids, decode_req_ids, + + # Schedule lookahead decode for this prompt + if self.use_lookahead_decoding: + lookahead_decode_req_ids.append(req_id) + + all_decode_req_ids = decode_req_ids + lookahead_decode_req_ids + + return PromptDecodeInfo(prompt_req_ids, all_decode_req_ids, prompt_scheduled_tokens) def _prepare_sampling(self, @@ -1039,15 +1064,26 @@ def _extract_prefill_batch_contents(self, num_prefills, num_decodes, num_scheduled_tokens): # DECODES are the first num_decodes REQUESTS. # PREFILLS are the next num_reqs - num_decodes REQUESTS. 
- num_reqs = num_prefills + num_decodes + if self.use_lookahead_decoding: + # When lookahead decoding is used, num_decodes includes lookahead decodes + # but we need to find where actual prefills start in the batch + actual_num_decodes = num_decodes - num_prefills + idx_bias = num_prefills + else: + actual_num_decodes = num_decodes + idx_bias = 0 block_table_cpu_tensor = self.input_batch.block_table[ 0].get_cpu_tensor() all_batch_contents = [BatchContents()] - for batch_idx in range(num_decodes, num_reqs): + # Prefills start after the actual decode requests in the input_batch + prefill_start = actual_num_decodes + prefill_end = prefill_start + num_prefills + + for batch_idx in range(prefill_start, prefill_end): req_id = self.input_batch.req_ids[batch_idx] context_len = self.input_batch.num_computed_tokens_cpu[batch_idx] - query_len = num_scheduled_tokens[batch_idx] + query_len = num_scheduled_tokens[batch_idx+idx_bias] token_ids = self.input_batch.token_ids_cpu[ batch_idx, context_len:context_len + query_len].tolist() @@ -1229,7 +1265,7 @@ def _prepare_prefill_inputs( merge_contents(all_batches[0], *all_batches[1:]) return all_batches[0] - def _prepare_decode_inputs(self, num_decodes, + def _prepare_decode_inputs(self, num_prefills, num_decodes, num_scheduled_tokens) -> DecodeInputData: # Decodes run as one single padded batch with shape [batch, 1] # @@ -1242,8 +1278,27 @@ def _prepare_decode_inputs(self, num_decodes, 0].get_cpu_tensor() if num_decodes == 0: return DecodeInputData(num_decodes=0) + + # Calculate original decodes count (before lookahead) + original_num_decodes = num_decodes + if self.use_lookahead_decoding: + original_num_decodes = num_decodes - num_prefills + # BLOCK_TABLE [batch, max_num_blocks_per_req] - context_lens = self.input_batch.num_computed_tokens_cpu[:num_decodes] + context_lens = [] + + # Handle regular decodes + for i in range(original_num_decodes): + context_lens.append(self.input_batch.num_computed_tokens_cpu[i]) + + # Handle lookahead decodes + if self.use_lookahead_decoding: + # For lookahead decodes use the prompt length + 1 (the newly generated token) + prompt_start_idx = original_num_decodes + for i in range(prompt_start_idx, self.input_batch.num_reqs): + context_lens.append(self.input_batch.num_prompt_tokens[i] + 1) + + context_lens = np.array(context_lens) # NOTE(kzawora): the +1 is what causes this entire thing to work, # as in the paged attention, we don't fetch just the context from cache, @@ -1265,10 +1320,18 @@ def _prepare_decode_inputs(self, num_decodes, # POSITIONS. [batch, 1] # We slice at the end, since we use the positions for gathering. 
positions = torch.zeros((padded_batch_size, 1), dtype=torch.int32) - positions[:num_decodes] = torch.from_numpy( - self.input_batch.num_computed_tokens_cpu.reshape(-1, - 1)[:num_decodes]) - positions = positions[:padded_batch_size] + + # Handle regular decodes + for i in range(original_num_decodes): + positions[i, 0] = self.input_batch.num_computed_tokens_cpu[i] + + # Handle lookahead decodes positions + if self.use_lookahead_decoding: + prompt_start_idx = original_num_decodes + for i in range(prompt_start_idx, num_decodes): + batch_idx = original_num_decodes + (i - prompt_start_idx) + # Position should be at the newly generated token (prompt_len) + positions[i, 0] = self.input_batch.num_prompt_tokens[batch_idx] padded_index = torch.zeros((padded_batch_size, 1), dtype=torch.int64) index = positions.to(torch.int64)[:num_decodes] @@ -1276,10 +1339,19 @@ def _prepare_decode_inputs(self, num_decodes, # TOKEN_IDS. [batch, 1] token_ids = torch.zeros((padded_batch_size, 1), dtype=torch.int32) - token_ids[:num_decodes] = torch.gather(input=torch.from_numpy( - self.input_batch.token_ids_cpu), - dim=1, - index=index) + + # Handle regular decodes + for i in range(original_num_decodes): + pos = int(positions[i, 0]) + token_ids[i, 0] = self.input_batch.token_ids_cpu[i, pos] + + # Handle lookahead decode tokens (from prompts) + if self.use_lookahead_decoding: + prompt_start_idx = original_num_decodes + for i in range(prompt_start_idx, num_decodes): + batch_idx = original_num_decodes + (i - prompt_start_idx) + # For lookahead use the newly generated token from prefill phase + token_ids[i, 0] = self.input_batch.token_ids_cpu[batch_idx, prompt_start_idx] # SLOT_MAPPING [batch, 1] # The "slot" is the "physical index" of a token in the KV cache. @@ -1345,6 +1417,41 @@ def _prepare_decode_inputs(self, num_decodes, block_size=self.block_size, )) + def update_lookahead_decode_inputs(self, decode_data, num_prefills, num_decodes, + prefill_sampled_tokens, prefill_sampled_requests): + """ + Update decode_data for lookahead decoding to use newly generated tokens from prefill phase + instead of stale prompt tokens as inputs. 
+ """ + if not self.use_lookahead_decoding or num_prefills == 0 or not prefill_sampled_tokens: + return + + # Get the original decode count (before lookahead decodes) + original_num_decodes = num_decodes - num_prefills + + # Create a mapping from request_id to newly generated token + req_to_token = {} + for token_batch, req_batch in zip(prefill_sampled_tokens, prefill_sampled_requests): + tokens = token_batch.cpu() if hasattr(token_batch, 'cpu') else token_batch + for token, req_id in zip(tokens, req_batch): + req_to_token[req_id] = token + + # For lookahead decodes (which start after original decodes), update their input tokens + for i in range(original_num_decodes, num_decodes): + # Map lookahead decode index to corresponding prefill request + prefill_idx = i - original_num_decodes + batch_idx = original_num_decodes + prefill_idx + req_id = self.input_batch.req_ids[batch_idx] + + if req_id in req_to_token: + # Update the decode input to use the newly generated token + new_token = req_to_token[req_id] + decode_data.token_ids[i, 0] = new_token + + # Update position to point to the correct location (after the generated token) + prompt_len = self.input_batch.num_prompt_tokens[batch_idx] + decode_data.position_ids[i, 0] = prompt_len + def _prepare_inputs( self, scheduler_output: "SchedulerOutput", @@ -1357,6 +1464,10 @@ def _prepare_inputs( num_reqs = num_prefills + num_decodes + actual_num_decodes = num_decodes + if self.use_lookahead_decoding: + actual_num_decodes = num_decodes - num_prefills + # Get the number of scheduled tokens for each request. # TODO: The Python loop can be slow. Optimize. num_scheduled_tokens = [] @@ -1369,11 +1480,17 @@ def _prepare_inputs( num_scheduled_tokens.append(seq_num_scheduled_tokens) num_prompt_tokens.append(seq_num_prompt_tokens) # NOTE: assert that all the decodes are "decodes". - if idx < num_decodes: + if idx < actual_num_decodes: assert seq_num_scheduled_tokens == 1 + + if self.use_lookahead_decoding: + # Insert scheduled tokens for lookahead decodes (always 1 token per decode) + for _ in range(num_prefills): + num_scheduled_tokens.insert(0, 1) return (self._prepare_prefill_inputs(num_prefills, num_decodes, num_scheduled_tokens), - self._prepare_decode_inputs(num_decodes, num_scheduled_tokens)) + self._prepare_decode_inputs(num_prefills, num_decodes, + num_scheduled_tokens)) def _seq_len(self, attn_metadata): return attn_metadata.slot_mapping.size(-1) @@ -1488,7 +1605,7 @@ def _get_prompt_logprobs_dict( # Get the "target" tokens for each index. For prompt at index i, # the token at prompt index i+1 is the "sampled" token we want - # to gather the logprob for. + # to gather the logprobs for. tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits] # Compute prompt logprobs. @@ -1576,7 +1693,6 @@ def execute_model( # Transfer [tokD0, tokD1, tokD2, 0, tokP0, tokP1, tokP2, 0] to CPU # On CPU, sanitize [tokD0, tokD1, tokD2, 0, tokP0, tokP1, tokP2, 0] -> [tokD0, tokD1, tokD2, tokP0, tokP1, tokP2] # noqa # Return [tokD0, tokD1, tokD2, tokP0, tokP1, tokP2] - batch_changed = self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: # Return empty ModelRunnerOuptut if there's no work to do. 
@@ -1641,6 +1757,8 @@ def execute_model( ######################### DECODES ######################### # Decodes run as one single batch with [padded_decode_bs, 1] if num_decodes > 0: + self.update_lookahead_decode_inputs(decode_data, num_prefills, num_decodes, + prefill_sampled_token_ids, prefill_sampled_requests) self.event_start = self.profiler.get_timestamp_us() self.profiler.start("internal", "decode") assert decode_data is not None @@ -1659,8 +1777,19 @@ def execute_model( logits=logits_device, sampling_metadata=sampling_metadata) decode_sampled_token_ids.append( sampler_output.sampled_token_ids.flatten()) - decode_sampled_requests.extend( - self.input_batch.req_ids[:num_decodes]) + + # TODO remove lookahead syncs with cpu + if self.use_lookahead_decoding: + original_num_decodes = num_decodes - num_prefills + # Regular decode requests + decode_sampled_requests.extend( + self.input_batch.req_ids[:original_num_decodes]) + # Lookahead decode requests + for i in range(num_prefills): + decode_sampled_requests.append(pd_info.prompt_req_ids[i]) + else: + decode_sampled_requests.extend( + self.input_batch.req_ids[:num_decodes]) htorch.core.mark_step() if self.is_driver_worker and self.profiler.enabled: # Stop recording 'execute_model' event @@ -1697,25 +1826,52 @@ def execute_model( postprocessed_sampled_token_ids = [[] for _ in range(max_req_index + 1)] + + lookahead_token_mapping = {} # req_id : token_id + scheduled_token_mapping = {} # req_id : [token_ids] + + # Lookahead tokens storing - should be done without host syncs for tok_id, req_id in zip(sampled_token_ids_list, sampled_token_requests): - postprocessed_sampled_token_ids[ - self.input_batch.req_id_to_index[req_id]].append(tok_id) + req_index = self.input_batch.req_id_to_index[req_id] + + if (self.use_lookahead_decoding and + req_id in pd_info.prompt_req_ids and + req_id in decode_sampled_requests): + if req_id not in lookahead_token_mapping: + lookahead_token_mapping[req_id] = [] + lookahead_token_mapping[req_id].append(tok_id) + else: + postprocessed_sampled_token_ids[req_index].append(tok_id) + if req_id not in scheduled_token_mapping: + scheduled_token_mapping[req_id] = [] + scheduled_token_mapping[req_id].append(tok_id) # NOTE(kzawora): idk what happens if part of batch doesn't have logprobs + # Store lookahead tokens for future use - can be done better + if self.use_lookahead_decoding: + for req_id, tokens in lookahead_token_mapping.items(): + if req_id not in self.lookahead_tokens: + self.lookahead_tokens[req_id] = [] + self.lookahead_tokens[req_id].extend(tokens) + ######### UPDATE REQUEST STATE WITH GENERATED TOKENS ######### for req_id in self.input_batch.req_ids[:num_reqs]: + if req_id not in scheduled_token_mapping: + continue + req_state = self.requests[req_id] i = self.input_batch.req_id_to_index[req_id] seq_len = (req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id]) - token_ids = postprocessed_sampled_token_ids[i] + token_ids = scheduled_token_mapping[req_id] num_tokens = len(token_ids) - self.input_batch.token_ids_cpu[i, seq_len:seq_len + - num_tokens] = token_ids - self.input_batch.num_tokens[i] += len(token_ids) - req_state.output_token_ids.extend(token_ids) + if num_tokens > 0: + self.input_batch.token_ids_cpu[i, seq_len:seq_len + + num_tokens] = token_ids + self.input_batch.num_tokens[i] += len(token_ids) + req_state.output_token_ids.extend(token_ids) # NOTE(chendi): enable cache based on PR(#20291) # Cache the sampled tokens in the model runner, so that the scheduler @@ -1744,12 +1900,16 
@@ def execute_model( req_state.output_token_ids.extend(sampled_ids) ################## RETURN ################## # Create output. - all_req_ids = pd_info.decode_req_ids + pd_info.prompt_req_ids + # Only return the originally scheduled requests, not the lookahead requests + original_decode_req_ids = pd_info.decode_req_ids + if self.use_lookahead_decoding: + original_decode_req_ids = pd_info.decode_req_ids[:-num_prefills] if num_prefills > 0 else pd_info.decode_req_ids + + all_req_ids = original_decode_req_ids + pd_info.prompt_req_ids #prompt_logprobs_dict: dict[ # str, Optional[LogprobsTensors]] = self._get_prompt_logprobs_dict( # prefill_hidden_states_device, scheduler_output) prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {} - all_req_ids = pd_info.decode_req_ids + pd_info.prompt_req_ids logprobs = None model_runner_output = ModelRunnerOutput( From c14369dc5088184d44bcf54497cb2902c3769543 Mon Sep 17 00:00:00 2001 From: Jan Kaniecki Date: Mon, 11 Aug 2025 12:44:26 +0300 Subject: [PATCH 02/15] Lookahead decoding patch 2 --- vllm_gaudi/v1/worker/hpu_model_runner.py | 134 +++++++++++------------ 1 file changed, 61 insertions(+), 73 deletions(-) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index dd15b270..02845746 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -656,7 +656,7 @@ def __init__( # Lookahead decoding self.use_lookahead_decoding = True # Storage for lookahead tokens that are computed but not yet scheduled - self.lookahead_tokens: dict[str, list[int]] = {} + self.lookahead_tokens: dict = {} def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: @@ -795,17 +795,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: elif num_new_tokens > 0: req_state.output_token_ids.extend( new_token_ids[-num_new_tokens:]) - - # Check if we already have lookahead tokens for this request - if (self.use_lookahead_decoding and - req_id in self.lookahead_tokens and - len(self.lookahead_tokens[req_id]) > 0): - # Use the first available lookahead token - lookahead_token = self.lookahead_tokens[req_id].pop(0) - req_state.output_token_ids.append(lookahead_token) - # Clean up empty lookahead token lists - if len(self.lookahead_tokens[req_id]) == 0: - del self.lookahead_tokens[req_id] + # Update the block IDs. if not resumed_from_preemption: for block_ids, new_ids in zip(req_state.block_ids, @@ -1431,10 +1421,9 @@ def update_lookahead_decode_inputs(self, decode_data, num_prefills, num_decodes, # Create a mapping from request_id to newly generated token req_to_token = {} - for token_batch, req_batch in zip(prefill_sampled_tokens, prefill_sampled_requests): - tokens = token_batch.cpu() if hasattr(token_batch, 'cpu') else token_batch - for token, req_id in zip(tokens, req_batch): - req_to_token[req_id] = token + #### ACCURACY THREAT ????????????? 
+ for token, req_id in zip(prefill_sampled_tokens, prefill_sampled_requests): + req_to_token[req_id] = token # For lookahead decodes (which start after original decodes), update their input tokens for i in range(original_num_decodes, num_decodes): @@ -1451,6 +1440,12 @@ def update_lookahead_decode_inputs(self, decode_data, num_prefills, num_decodes, # Update position to point to the correct location (after the generated token) prompt_len = self.input_batch.num_prompt_tokens[batch_idx] decode_data.position_ids[i, 0] = prompt_len + + # Replace tokens in regular decodes with lookahead stored tokens + for i in range(0, original_num_decodes): + req_id = self.input_batch.req_ids[i] + decode_data.token_ids[i, 0] = self.lookahead_tokens.get(req_id, 0)[0] + def _prepare_inputs( self, @@ -1634,6 +1629,13 @@ def _is_quant_with_inc(self): quant_config = os.getenv("QUANT_CONFIG", None) is not None return (self.model_config.quantization == "inc" or quant_config) + def is_chunked_prefill_dummy_output_token( + self, + req_id, + prefill_sampled_requests, + prompt_req_ids) -> bool: + return (req_id in prompt_req_ids) and (req_id not in prefill_sampled_requests) + @torch.inference_mode() def execute_model( self, @@ -1703,7 +1705,10 @@ def execute_model( pd_info = self._get_prompts_and_decodes(scheduler_output) num_decodes = len(pd_info.decode_req_ids) num_prefills = len(pd_info.prompt_req_ids) - num_reqs = num_decodes + num_prefills + if num_decodes > num_prefills: + pass + original_num_decodes = num_decodes - num_prefills + num_reqs = original_num_decodes + num_prefills with self.profiler.record_event('internal', 'prepare_input_tensors'): prefill_data, decode_data = self._prepare_inputs( scheduler_output, num_prefills, num_decodes) @@ -1711,7 +1716,6 @@ def execute_model( # later. prefill_sampled_token_ids = [] prefill_sampled_requests = [] - decode_sampled_token_ids = [] decode_sampled_requests = [] ######################### PREFILLS ######################### if num_prefills > 0: @@ -1775,21 +1779,25 @@ def execute_model( pad_to=logits_device.shape[0]) sampler_output = self.sampler( logits=logits_device, sampling_metadata=sampling_metadata) - decode_sampled_token_ids.append( - sampler_output.sampled_token_ids.flatten()) - + decode_sampled_token_ids = \ + sampler_output.sampled_token_ids.flatten() + + #if num_decodes > num_prefills: + # import fpdb + # fpdb.ForkedPdb().set_trace() + for req_id, token_ids in zip( + pd_info.decode_req_ids, + decode_sampled_token_ids[:num_decodes].split(1)): + if not self.is_chunked_prefill_dummy_output_token(req_id, + prefill_sampled_requests, + pd_info.prompt_req_ids): + if not req_id in self.lookahead_tokens: + self.lookahead_tokens[req_id] = [] + self.lookahead_tokens[req_id].append(token_ids) # TODO remove lookahead syncs with cpu - if self.use_lookahead_decoding: - original_num_decodes = num_decodes - num_prefills - # Regular decode requests - decode_sampled_requests.extend( - self.input_batch.req_ids[:original_num_decodes]) - # Lookahead decode requests - for i in range(num_prefills): - decode_sampled_requests.append(pd_info.prompt_req_ids[i]) - else: - decode_sampled_requests.extend( - self.input_batch.req_ids[:num_decodes]) + decode_sampled_requests.extend( + self.input_batch.req_ids[:(num_decodes-num_prefills)]) + htorch.core.mark_step() if self.is_driver_worker and self.profiler.enabled: # Stop recording 'execute_model' event @@ -1810,17 +1818,9 @@ def execute_model( # We already have tokens. Let's copy the data to # CPU as is, and then discard padded tokens. 
with self.profiler.record_event('internal', "sampler_postprocessing"): - prefill_sampled_token_ids = [ + prefill_sampled_token_ids = torch.cat([ tensor.cpu() for tensor in prefill_sampled_token_ids - ] - decode_sampled_token_ids = [ - tensor.cpu()[:num_decodes] - for tensor in decode_sampled_token_ids - ] - sampled_token_ids_list = torch.cat( - decode_sampled_token_ids + prefill_sampled_token_ids).tolist() - sampled_token_requests = \ - decode_sampled_requests + prefill_sampled_requests + ]).tolist() max_req_index = max(self.input_batch.req_id_to_index.values()) postprocessed_sampled_token_ids: list[list] postprocessed_sampled_token_ids = [[] @@ -1829,49 +1829,37 @@ def execute_model( lookahead_token_mapping = {} # req_id : token_id scheduled_token_mapping = {} # req_id : [token_ids] - - # Lookahead tokens storing - should be done without host syncs - for tok_id, req_id in zip(sampled_token_ids_list, - sampled_token_requests): + for req_id in decode_sampled_requests: req_index = self.input_batch.req_id_to_index[req_id] + tok_id = self.lookahead_tokens[req_id].pop(0).item() + postprocessed_sampled_token_ids[req_index].append(tok_id) + if req_id not in scheduled_token_mapping: + scheduled_token_mapping[req_id] = [] + scheduled_token_mapping[req_id].append(tok_id) + + for tok_id, req_id in zip(prefill_sampled_token_ids, + prefill_sampled_requests): + req_index = self.input_batch.req_id_to_index[req_id] + postprocessed_sampled_token_ids[req_index].append(tok_id) + if req_id not in scheduled_token_mapping: + scheduled_token_mapping[req_id] = [] + scheduled_token_mapping[req_id].append(tok_id) - if (self.use_lookahead_decoding and - req_id in pd_info.prompt_req_ids and - req_id in decode_sampled_requests): - if req_id not in lookahead_token_mapping: - lookahead_token_mapping[req_id] = [] - lookahead_token_mapping[req_id].append(tok_id) - else: - postprocessed_sampled_token_ids[req_index].append(tok_id) - if req_id not in scheduled_token_mapping: - scheduled_token_mapping[req_id] = [] - scheduled_token_mapping[req_id].append(tok_id) - # NOTE(kzawora): idk what happens if part of batch doesn't have logprobs - # Store lookahead tokens for future use - can be done better - if self.use_lookahead_decoding: - for req_id, tokens in lookahead_token_mapping.items(): - if req_id not in self.lookahead_tokens: - self.lookahead_tokens[req_id] = [] - self.lookahead_tokens[req_id].extend(tokens) ######### UPDATE REQUEST STATE WITH GENERATED TOKENS ######### for req_id in self.input_batch.req_ids[:num_reqs]: - if req_id not in scheduled_token_mapping: - continue - req_state = self.requests[req_id] i = self.input_batch.req_id_to_index[req_id] seq_len = (req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id]) - token_ids = scheduled_token_mapping[req_id] + token_ids = postprocessed_sampled_token_ids[i] num_tokens = len(token_ids) - if num_tokens > 0: - self.input_batch.token_ids_cpu[i, seq_len:seq_len + - num_tokens] = token_ids - self.input_batch.num_tokens[i] += len(token_ids) - req_state.output_token_ids.extend(token_ids) + self.input_batch.token_ids_cpu[i, seq_len:seq_len + + num_tokens] = token_ids + self.input_batch.num_tokens[i] += len(token_ids) + req_state.output_token_ids.extend(token_ids) # NOTE(chendi): enable cache based on PR(#20291) # Cache the sampled tokens in the model runner, so that the scheduler From 5e195774490172aa46da8a8f6bcf2a1d3d14fd02 Mon Sep 17 00:00:00 2001 From: Jan Kaniecki Date: Mon, 11 Aug 2025 14:03:29 +0300 Subject: [PATCH 03/15] Lookahead decoding part 3 
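
The changes below mainly harden the lookahead path against empty prefill sampler outputs (for example, chunked prefills that have not produced a token yet): zero-length tensors are filtered out before building the request-to-token map, and prefill results are only concatenated when num_prefills > 0. A rough, self-contained illustration of the shape-based filtering they rely on, using toy tensors rather than the runner's real sampler outputs:

    import torch

    # Per-request sampler outputs from the prefill pass; a chunked prefill that
    # has not finished its prompt yet contributes an empty tensor.
    prefill_sampled_tokens = [
        torch.tensor([101]),
        torch.tensor([], dtype=torch.long),  # nothing sampled for this request yet
        torch.tensor([202]),
    ]

    # Drop the empty outputs before using them as lookahead decode inputs.
    non_empty = [t for t in prefill_sampled_tokens if t.shape != torch.Size([0])]
    flat = torch.cat(non_empty).tolist() if non_empty else []
    print(flat)  # [101, 202]
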
--- vllm_gaudi/v1/worker/hpu_model_runner.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index 02845746..d93c1802 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -1422,7 +1422,8 @@ def update_lookahead_decode_inputs(self, decode_data, num_prefills, num_decodes, # Create a mapping from request_id to newly generated token req_to_token = {} #### ACCURACY THREAT ????????????? - for token, req_id in zip(prefill_sampled_tokens, prefill_sampled_requests): + tokens = [token for token in prefill_sampled_tokens if token.shape != torch.Size([0])] + for token, req_id in zip(tokens, prefill_sampled_requests): req_to_token[req_id] = token # For lookahead decodes (which start after original decodes), update their input tokens @@ -1431,12 +1432,10 @@ def update_lookahead_decode_inputs(self, decode_data, num_prefills, num_decodes, prefill_idx = i - original_num_decodes batch_idx = original_num_decodes + prefill_idx req_id = self.input_batch.req_ids[batch_idx] - if req_id in req_to_token: # Update the decode input to use the newly generated token new_token = req_to_token[req_id] decode_data.token_ids[i, 0] = new_token - # Update position to point to the correct location (after the generated token) prompt_len = self.input_batch.num_prompt_tokens[batch_idx] decode_data.position_ids[i, 0] = prompt_len @@ -1705,8 +1704,6 @@ def execute_model( pd_info = self._get_prompts_and_decodes(scheduler_output) num_decodes = len(pd_info.decode_req_ids) num_prefills = len(pd_info.prompt_req_ids) - if num_decodes > num_prefills: - pass original_num_decodes = num_decodes - num_prefills num_reqs = original_num_decodes + num_prefills with self.profiler.record_event('internal', 'prepare_input_tensors'): @@ -1782,9 +1779,6 @@ def execute_model( decode_sampled_token_ids = \ sampler_output.sampled_token_ids.flatten() - #if num_decodes > num_prefills: - # import fpdb - # fpdb.ForkedPdb().set_trace() for req_id, token_ids in zip( pd_info.decode_req_ids, decode_sampled_token_ids[:num_decodes].split(1)): @@ -1818,9 +1812,10 @@ def execute_model( # We already have tokens. Let's copy the data to # CPU as is, and then discard padded tokens. with self.profiler.record_event('internal', "sampler_postprocessing"): - prefill_sampled_token_ids = torch.cat([ - tensor.cpu() for tensor in prefill_sampled_token_ids - ]).tolist() + if num_prefills > 0: + prefill_sampled_token_ids = torch.cat([ + tensor.cpu() for tensor in prefill_sampled_token_ids + ]).tolist() max_req_index = max(self.input_batch.req_id_to_index.values()) postprocessed_sampled_token_ids: list[list] postprocessed_sampled_token_ids = [[] From 0151e1936c3d183e7e8a2b45147c4af3221692bc Mon Sep 17 00:00:00 2001 From: Jan Kaniecki Date: Mon, 11 Aug 2025 16:34:52 +0300 Subject: [PATCH 04/15] Acc fix --- vllm_gaudi/v1/worker/hpu_model_runner.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index d93c1802..9a0b4c5c 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -1413,7 +1413,7 @@ def update_lookahead_decode_inputs(self, decode_data, num_prefills, num_decodes, Update decode_data for lookahead decoding to use newly generated tokens from prefill phase instead of stale prompt tokens as inputs. 
""" - if not self.use_lookahead_decoding or num_prefills == 0 or not prefill_sampled_tokens: + if not self.use_lookahead_decoding: return # Get the original decode count (before lookahead decodes) @@ -1439,13 +1439,12 @@ def update_lookahead_decode_inputs(self, decode_data, num_prefills, num_decodes, # Update position to point to the correct location (after the generated token) prompt_len = self.input_batch.num_prompt_tokens[batch_idx] decode_data.position_ids[i, 0] = prompt_len - + # Replace tokens in regular decodes with lookahead stored tokens for i in range(0, original_num_decodes): req_id = self.input_batch.req_ids[i] decode_data.token_ids[i, 0] = self.lookahead_tokens.get(req_id, 0)[0] - def _prepare_inputs( self, scheduler_output: "SchedulerOutput", @@ -1822,23 +1821,15 @@ def execute_model( for _ in range(max_req_index + 1)] - lookahead_token_mapping = {} # req_id : token_id - scheduled_token_mapping = {} # req_id : [token_ids] for req_id in decode_sampled_requests: req_index = self.input_batch.req_id_to_index[req_id] tok_id = self.lookahead_tokens[req_id].pop(0).item() postprocessed_sampled_token_ids[req_index].append(tok_id) - if req_id not in scheduled_token_mapping: - scheduled_token_mapping[req_id] = [] - scheduled_token_mapping[req_id].append(tok_id) for tok_id, req_id in zip(prefill_sampled_token_ids, prefill_sampled_requests): req_index = self.input_batch.req_id_to_index[req_id] postprocessed_sampled_token_ids[req_index].append(tok_id) - if req_id not in scheduled_token_mapping: - scheduled_token_mapping[req_id] = [] - scheduled_token_mapping[req_id].append(tok_id) # NOTE(kzawora): idk what happens if part of batch doesn't have logprobs From 6fbe4632f9fc96370a6094185c9cf1bd5f7d976f Mon Sep 17 00:00:00 2001 From: Jan Kaniecki Date: Tue, 19 Aug 2025 14:28:23 +0300 Subject: [PATCH 05/15] Make lookahead scheduling optional --- vllm_gaudi/envs.py | 7 -- vllm_gaudi/extension/features.py | 1 + vllm_gaudi/v1/worker/hpu_model_runner.py | 125 +++++++++++++---------- 3 files changed, 72 insertions(+), 61 deletions(-) diff --git a/vllm_gaudi/envs.py b/vllm_gaudi/envs.py index b2a82b3a..051e79ad 100644 --- a/vllm_gaudi/envs.py +++ b/vllm_gaudi/envs.py @@ -5,7 +5,6 @@ if TYPE_CHECKING: VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True - VLLM_HPU_USE_DELAYED_SAMPLING: bool = False VLLM_HPU_FORCE_CHANNEL_FP8: bool = True # The begin-* and end* here are used by the documentation generator @@ -20,12 +19,6 @@ lambda: os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() in ("1", "true"), - # Use delayed sampling for HPU to reduce host cpu overhead - # between each step. 
- "VLLM_HPU_USE_DELAYED_SAMPLING": - lambda: os.environ.get("VLLM_DELAYED_SAMPLING", "false").lower() in - ("1", "true"), - # Convert block fp8 to channel fp8 for HPU "VLLM_HPU_FORCE_CHANNEL_FP8": lambda: os.environ.get("VLLM_HPU_FORCE_CHANNEL_FP8", "true").lower() in diff --git a/vllm_gaudi/extension/features.py b/vllm_gaudi/extension/features.py index f6da3421..0da1ccd9 100644 --- a/vllm_gaudi/extension/features.py +++ b/vllm_gaudi/extension/features.py @@ -71,5 +71,6 @@ def get_features(): Value('exponential_bucketing', True, env_var='VLLM_EXPONENTIAL_BUCKETING'), Value('linear_bucketing', True), Value('bucketing_strategy', FirstEnabled(*bucketing_strategies), env_var_type=choice(*bucketing_strategies)), + Value('lookahead_decoding', False, env_var='VLLM_USE_LOOKAHEAD_DECODING') ] return split_values_and_flags(features) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index 9a0b4c5c..bad48600 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import collections import contextlib +import copy import functools import itertools import math @@ -541,6 +542,7 @@ def __init__( ): # TODO: use ModelRunnerBase.__init__(self, vllm_config=vllm_config) environment.set_vllm_config(vllm_config) + self.count = 0 self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config @@ -654,7 +656,7 @@ def __init__( self.profiler_counter_helper = HabanaProfilerCounterHelper() # Lookahead decoding - self.use_lookahead_decoding = True + self.use_lookahead_decoding = get_config().lookahead_decoding # Storage for lookahead tokens that are computed but not yet scheduled self.lookahead_tokens: dict = {} @@ -711,8 +713,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: # Remove finished requests from the cached states. for req_id in scheduler_output.finished_req_ids: self.requests.pop(req_id, None) - # Pop stored lookahead tokens for finished requests - dummy tokens at the end - self.lookahead_tokens.pop(req_id, None) + if self.use_lookahead_decoding: + # Pop stored lookahead tokens for finished requests - dummy tokens at the end + self.lookahead_tokens.pop(req_id, None) # Remove the finished requests from the persistent batch. # NOTE(woosuk): There could be an edge case where finished_req_ids and # scheduled_req_ids overlap. 
This happens when a request is aborted and @@ -1284,8 +1287,7 @@ def _prepare_decode_inputs(self, num_prefills, num_decodes, # Handle lookahead decodes if self.use_lookahead_decoding: # For lookahead decodes use the prompt length + 1 (the newly generated token) - prompt_start_idx = original_num_decodes - for i in range(prompt_start_idx, self.input_batch.num_reqs): + for i in range(original_num_decodes, num_decodes): context_lens.append(self.input_batch.num_prompt_tokens[i] + 1) context_lens = np.array(context_lens) @@ -1317,31 +1319,21 @@ def _prepare_decode_inputs(self, num_prefills, num_decodes, # Handle lookahead decodes positions if self.use_lookahead_decoding: - prompt_start_idx = original_num_decodes - for i in range(prompt_start_idx, num_decodes): - batch_idx = original_num_decodes + (i - prompt_start_idx) + for i in range(original_num_decodes, num_decodes): # Position should be at the newly generated token (prompt_len) - positions[i, 0] = self.input_batch.num_prompt_tokens[batch_idx] + positions[i, 0] = self.input_batch.num_prompt_tokens[i] + positions = positions[:padded_batch_size] padded_index = torch.zeros((padded_batch_size, 1), dtype=torch.int64) index = positions.to(torch.int64)[:num_decodes] padded_index[:num_decodes] = index # TOKEN_IDS. [batch, 1] token_ids = torch.zeros((padded_batch_size, 1), dtype=torch.int32) - - # Handle regular decodes - for i in range(original_num_decodes): - pos = int(positions[i, 0]) - token_ids[i, 0] = self.input_batch.token_ids_cpu[i, pos] - - # Handle lookahead decode tokens (from prompts) - if self.use_lookahead_decoding: - prompt_start_idx = original_num_decodes - for i in range(prompt_start_idx, num_decodes): - batch_idx = original_num_decodes + (i - prompt_start_idx) - # For lookahead use the newly generated token from prefill phase - token_ids[i, 0] = self.input_batch.token_ids_cpu[batch_idx, prompt_start_idx] + token_ids[:num_decodes] = torch.gather(input=torch.from_numpy( + self.input_batch.token_ids_cpu), + dim=1, + index=index) # SLOT_MAPPING [batch, 1] # The "slot" is the "physical index" of a token in the KV cache. @@ -1413,8 +1405,6 @@ def update_lookahead_decode_inputs(self, decode_data, num_prefills, num_decodes, Update decode_data for lookahead decoding to use newly generated tokens from prefill phase instead of stale prompt tokens as inputs. 
""" - if not self.use_lookahead_decoding: - return # Get the original decode count (before lookahead decodes) original_num_decodes = num_decodes - num_prefills @@ -1435,15 +1425,17 @@ def update_lookahead_decode_inputs(self, decode_data, num_prefills, num_decodes, if req_id in req_to_token: # Update the decode input to use the newly generated token new_token = req_to_token[req_id] - decode_data.token_ids[i, 0] = new_token + decode_data.token_ids[i] = new_token # Update position to point to the correct location (after the generated token) prompt_len = self.input_batch.num_prompt_tokens[batch_idx] - decode_data.position_ids[i, 0] = prompt_len + decode_data.position_ids[i, 0] = torch.tensor(prompt_len, dtype=torch.int, device=self.device) # Replace tokens in regular decodes with lookahead stored tokens for i in range(0, original_num_decodes): req_id = self.input_batch.req_ids[i] - decode_data.token_ids[i, 0] = self.lookahead_tokens.get(req_id, 0)[0] + decode_data.token_ids[i] = self.lookahead_tokens.get(req_id, 0)[0] + + return copy.deepcopy(decode_data) def _prepare_inputs( self, @@ -1703,7 +1695,9 @@ def execute_model( pd_info = self._get_prompts_and_decodes(scheduler_output) num_decodes = len(pd_info.decode_req_ids) num_prefills = len(pd_info.prompt_req_ids) - original_num_decodes = num_decodes - num_prefills + original_num_decodes = num_decodes + if self.use_lookahead_decoding: + original_num_decodes = num_decodes - num_prefills num_reqs = original_num_decodes + num_prefills with self.profiler.record_event('internal', 'prepare_input_tensors'): prefill_data, decode_data = self._prepare_inputs( @@ -1713,6 +1707,7 @@ def execute_model( prefill_sampled_token_ids = [] prefill_sampled_requests = [] decode_sampled_requests = [] + decode_sampled_token_ids = [] ######################### PREFILLS ######################### if num_prefills > 0: htorch.core.mark_step() @@ -1757,8 +1752,9 @@ def execute_model( ######################### DECODES ######################### # Decodes run as one single batch with [padded_decode_bs, 1] if num_decodes > 0: - self.update_lookahead_decode_inputs(decode_data, num_prefills, num_decodes, - prefill_sampled_token_ids, prefill_sampled_requests) + if self.use_lookahead_decoding: + decode_data = self.update_lookahead_decode_inputs(decode_data, num_prefills, num_decodes, + prefill_sampled_token_ids, prefill_sampled_requests) self.event_start = self.profiler.get_timestamp_us() self.profiler.start("internal", "decode") assert decode_data is not None @@ -1775,23 +1771,26 @@ def execute_model( pad_to=logits_device.shape[0]) sampler_output = self.sampler( logits=logits_device, sampling_metadata=sampling_metadata) - decode_sampled_token_ids = \ - sampler_output.sampled_token_ids.flatten() - + if self.use_lookahead_decoding: + decode_sampled_token_ids = \ + sampler_output.sampled_token_ids.flatten() + else: + decode_sampled_token_ids.append( + sampler_output.sampled_token_ids.flatten()) + decode_sampled_requests.extend( + self.input_batch.req_ids[:original_num_decodes]) + htorch.core.mark_step() + if self.use_lookahead_decoding: for req_id, token_ids in zip( pd_info.decode_req_ids, decode_sampled_token_ids[:num_decodes].split(1)): if not self.is_chunked_prefill_dummy_output_token(req_id, prefill_sampled_requests, - pd_info.prompt_req_ids): + pd_info.prompt_req_ids): if not req_id in self.lookahead_tokens: self.lookahead_tokens[req_id] = [] self.lookahead_tokens[req_id].append(token_ids) - # TODO remove lookahead syncs with cpu - decode_sampled_requests.extend( - 
self.input_batch.req_ids[:(num_decodes-num_prefills)]) - htorch.core.mark_step() if self.is_driver_worker and self.profiler.enabled: # Stop recording 'execute_model' event self.profiler.end() @@ -1811,25 +1810,43 @@ def execute_model( # We already have tokens. Let's copy the data to # CPU as is, and then discard padded tokens. with self.profiler.record_event('internal', "sampler_postprocessing"): - if num_prefills > 0: - prefill_sampled_token_ids = torch.cat([ - tensor.cpu() for tensor in prefill_sampled_token_ids - ]).tolist() max_req_index = max(self.input_batch.req_id_to_index.values()) postprocessed_sampled_token_ids: list[list] postprocessed_sampled_token_ids = [[] - for _ in range(max_req_index + - 1)] - - for req_id in decode_sampled_requests: - req_index = self.input_batch.req_id_to_index[req_id] - tok_id = self.lookahead_tokens[req_id].pop(0).item() - postprocessed_sampled_token_ids[req_index].append(tok_id) - - for tok_id, req_id in zip(prefill_sampled_token_ids, - prefill_sampled_requests): - req_index = self.input_batch.req_id_to_index[req_id] - postprocessed_sampled_token_ids[req_index].append(tok_id) + for _ in range(max_req_index + + 1)] + if self.use_lookahead_decoding: + if num_prefills > 0: + prefill_sampled_token_ids = torch.cat([ + tensor.cpu() for tensor in prefill_sampled_token_ids + ]).tolist() + + for req_id in decode_sampled_requests: + req_index = self.input_batch.req_id_to_index[req_id] + tok_id = self.lookahead_tokens[req_id].pop(0).item() + postprocessed_sampled_token_ids[req_index].append(tok_id) + + for tok_id, req_id in zip(prefill_sampled_token_ids, + prefill_sampled_requests): + req_index = self.input_batch.req_id_to_index[req_id] + postprocessed_sampled_token_ids[req_index].append(tok_id) + else: + prefill_sampled_token_ids = [ + tensor.cpu() for tensor in prefill_sampled_token_ids + ] + decode_sampled_token_ids = [ + tensor.cpu()[:num_decodes] + for tensor in decode_sampled_token_ids + ] + sampled_token_ids_list = torch.cat( + decode_sampled_token_ids + prefill_sampled_token_ids).tolist() + sampled_token_requests = \ + decode_sampled_requests + prefill_sampled_requests + + for tok_id, req_id in zip(sampled_token_ids_list, + sampled_token_requests): + postprocessed_sampled_token_ids[ + self.input_batch.req_id_to_index[req_id]].append(tok_id) # NOTE(kzawora): idk what happens if part of batch doesn't have logprobs From 509f100b3e3bc61e884ce82a46ae83fae0cf562e Mon Sep 17 00:00:00 2001 From: Jan Kaniecki Date: Tue, 19 Aug 2025 13:36:48 +0200 Subject: [PATCH 06/15] Update hpu_model_runner.py --- vllm_gaudi/v1/worker/hpu_model_runner.py | 30 ++++++++++++------------ 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index d66e81b0..e5f441e9 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -1891,21 +1891,21 @@ def execute_model( self.input_batch.req_ids[:num_decodes]) else: with self.profiler.record_event('internal', "sampler"): - sampling_metadata = self._prepare_sampling( - batch_changed, - pd_info.decode_req_ids, - pad_to=logits_device.shape[0]) - sampler_output = self.sampler( - logits=logits_device, sampling_metadata=sampling_metadata) - if self.use_lookahead_decoding: - decode_sampled_token_ids = \ - sampler_output.sampled_token_ids.flatten() - else: - decode_sampled_token_ids.append( - sampler_output.sampled_token_ids.flatten()) - decode_sampled_requests.extend( - 
self.input_batch.req_ids[:original_num_decodes]) - htorch.core.mark_step() + sampling_metadata = self._prepare_sampling( + batch_changed, + pd_info.decode_req_ids, + pad_to=logits_device.shape[0]) + sampler_output = self.sampler( + logits=logits_device, sampling_metadata=sampling_metadata) + if self.use_lookahead_decoding: + decode_sampled_token_ids = \ + sampler_output.sampled_token_ids.flatten() + else: + decode_sampled_token_ids.append( + sampler_output.sampled_token_ids.flatten()) + decode_sampled_requests.extend( + self.input_batch.req_ids[:original_num_decodes]) + htorch.core.mark_step() if self.use_lookahead_decoding: for req_id, token_ids in zip( pd_info.decode_req_ids, From 54dab88cdc4713d344500d773663fe06ead11b22 Mon Sep 17 00:00:00 2001 From: Jan Kaniecki Date: Tue, 19 Aug 2025 13:41:32 +0200 Subject: [PATCH 07/15] Update features.py --- vllm_gaudi/extension/features.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_gaudi/extension/features.py b/vllm_gaudi/extension/features.py index 68625910..8122b2af 100644 --- a/vllm_gaudi/extension/features.py +++ b/vllm_gaudi/extension/features.py @@ -71,7 +71,7 @@ def get_features(): Value('exponential_bucketing', True, env_var='VLLM_EXPONENTIAL_BUCKETING'), Value('linear_bucketing', True), Value('bucketing_strategy', FirstEnabled(*bucketing_strategies), env_var_type=choice(*bucketing_strategies)), - Value('lookahead_decoding', False, env_var='VLLM_USE_LOOKAHEAD_DECODING') + Value('lookahead_decoding', False, env_var='VLLM_USE_LOOKAHEAD_DECODING'), Value('regional_compilation', True, env_var='VLLM_T_COMPILE_REGIONAL_COMPILATION', env_var_type=boolean), Value('dynamic_shapes_compilation', False, env_var='VLLM_T_COMPILE_DYNAMIC_SHAPES', env_var_type=boolean), Value('fullgraph_compilation', False, env_var='VLLM_T_COMPILE_FULLGRAPH', env_var_type=boolean), From e46a6ecfdd6ace850744ece6c102ebeacc7c0f55 Mon Sep 17 00:00:00 2001 From: Jan Kaniecki Date: Wed, 20 Aug 2025 17:15:15 +0300 Subject: [PATCH 08/15] Remove unnecessary host syncs --- vllm_gaudi/v1/worker/hpu_input_batch.py | 9 +-------- vllm_gaudi/v1/worker/hpu_model_runner.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/vllm_gaudi/v1/worker/hpu_input_batch.py b/vllm_gaudi/v1/worker/hpu_input_batch.py index 28dc1215..7852a3be 100644 --- a/vllm_gaudi/v1/worker/hpu_input_batch.py +++ b/vllm_gaudi/v1/worker/hpu_input_batch.py @@ -605,14 +605,7 @@ def make_selective_sampling_metadata( for req_id, _ in req_id_output_token_ids ] prompt_token_ids = None - if not skip_copy: - self.temperature[req_indices].copy_( - self.temperature_cpu_tensor[req_indices], non_blocking=True) - self.top_p[req_indices].copy_(self.top_p_cpu_tensor[req_indices], - non_blocking=True) - self.top_k[req_indices].copy_(self.top_k_cpu_tensor[req_indices], - non_blocking=True) - if not self.no_penalties: + if not skip_copy and not self.no_penalties: # Since syncing these tensors is expensive only copy them # if necessary i.e. if there are requests which require # penalties to be applied during sampling. diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index e5f441e9..6fcca9b9 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -1427,7 +1427,7 @@ def update_lookahead_decode_inputs(self, decode_data, num_prefills, num_decodes, # Create a mapping from request_id to newly generated token req_to_token = {} - #### ACCURACY THREAT ????????????? 
+ tokens = [token for token in prefill_sampled_tokens if token.shape != torch.Size([0])] for token, req_id in zip(tokens, prefill_sampled_requests): req_to_token[req_id] = token @@ -1451,7 +1451,13 @@ def update_lookahead_decode_inputs(self, decode_data, num_prefills, num_decodes, req_id = self.input_batch.req_ids[i] decode_data.token_ids[i] = self.lookahead_tokens.get(req_id, 0)[0] - return copy.deepcopy(decode_data) + return DecodeInputData( + num_decodes=decode_data.num_decodes, + token_ids=decode_data.token_ids.clone(), + position_ids=decode_data.position_ids.clone(), + logits_indices=decode_data.logits_indices, + attn_metadata=decode_data.attn_metadata, + ) def _prepare_inputs( self, From 686ef180cb51d57a17c13f3e53f16a651884072b Mon Sep 17 00:00:00 2001 From: Jan Kaniecki Date: Wed, 20 Aug 2025 18:56:06 +0300 Subject: [PATCH 09/15] Fix position_ids when using lookahead decoding --- vllm_gaudi/v1/worker/hpu_model_runner.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index 6fcca9b9..75a38bde 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -1331,7 +1331,9 @@ def _prepare_decode_inputs(self, num_prefills, num_decodes, # Handle regular decodes for i in range(original_num_decodes): - positions[i, 0] = self.input_batch.num_computed_tokens_cpu[i] + num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i] + positions[i, 0] = num_computed_tokens + 1 if self.use_lookahead_decoding \ + else num_computed_tokens # Handle lookahead decodes positions if self.use_lookahead_decoding: From 06df7941cecccb6e8b982a7f7ea6120fa8295c6e Mon Sep 17 00:00:00 2001 From: Jan Kaniecki Date: Thu, 21 Aug 2025 14:57:38 +0300 Subject: [PATCH 10/15] Remove delayed sampling mentions --- docs/features/supported_features.md | 1 - vllm_gaudi/extension/features.py | 1 - 2 files changed, 2 deletions(-) diff --git a/docs/features/supported_features.md b/docs/features/supported_features.md index 19fb4a2a..121cb5ea 100644 --- a/docs/features/supported_features.md +++ b/docs/features/supported_features.md @@ -27,5 +27,4 @@ title: Supported Features | Multinode support | vLLM HPU backend supports distributed, multiple-node inference with Ray. | | | vLLM v1 architecture (early release) | V1 architecture is now available for the HPU backend, and will gradually enable it for every use case we plan to support. | [Documentation](https://docs.vllm.ai/en/latest/serving/distributed_serving.html) | | Guided decode | vLLM HPU supports a guided decoding backend for generating structured outputs. | [Documentation](https://docs.vllm.ai/en/latest/features/structured_outputs.html) | -| Delayed Sampling (experimental) | vLLM HPU supports delayed sampling scheduling for asynchronous execution, enabled by `VLLM_DELAYED_SAMPLING=true` environment variable. | N/A | | Exponential bucketing | vLLM HPU supports exponential bucketing spacing instead of linear to automate configuration of bucketing mechanism, enabled by default. It can be disabled via `VLLM_EXPONENTIAL_BUCKETING=false` environment variable. 
| N/A | diff --git a/vllm_gaudi/extension/features.py b/vllm_gaudi/extension/features.py index 8122b2af..6fd3ed01 100644 --- a/vllm_gaudi/extension/features.py +++ b/vllm_gaudi/extension/features.py @@ -66,7 +66,6 @@ def get_features(): Value('skip_warmup', False), Value('merged_prefill', False), Value('use_contiguous_pa', Disabled('prefix_caching'), env_var='VLLM_CONTIGUOUS_PA'), - Value('use_delayed_sampling', Engine('v0'), env_var='VLLM_DELAYED_SAMPLING'), Value('use_bucketing', True, env_var='VLLM_ENABLE_BUCKETING'), Value('exponential_bucketing', True, env_var='VLLM_EXPONENTIAL_BUCKETING'), Value('linear_bucketing', True), From a001384c722f7e2dbb39036bca9d8e6f9585cb1b Mon Sep 17 00:00:00 2001 From: Jan Kaniecki Date: Mon, 25 Aug 2025 22:22:59 +0300 Subject: [PATCH 11/15] Add block borrowing mechanism and fix sampler acc issues --- vllm_gaudi/v1/worker/hpu_model_runner.py | 69 ++++++++++++++++++++---- vllm_gaudi/v1/worker/hpu_worker.py | 11 ++-- 2 files changed, 65 insertions(+), 15 deletions(-) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index 1ba05569..be373aac 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -32,6 +32,7 @@ from vllm.attention.selector import get_attn_backend from vllm.config import (VllmConfig, update_config) from vllm.forward_context import set_forward_context +from vllm.model_executor import set_random_seed from vllm.model_executor.layers.fused_moe.layer import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.sampler import get_sampler @@ -681,11 +682,12 @@ def __init__( # High-level profiler self.profiler = HabanaHighLevelProfiler() self.profiler_counter_helper = HabanaProfilerCounterHelper() - # Lookahead decoding self.use_lookahead_decoding = get_config().lookahead_decoding # Storage for lookahead tokens that are computed but not yet scheduled self.lookahead_tokens: dict = {} + self.borrowed_blocks_mapping_fwd = {} + self.borrowed_blocks_mapping_bwd = {} self.defragmenter = OnlineDefragmenter() self.debug_fwd = init_debug_logger('fwd') @@ -1307,6 +1309,28 @@ def _get_prompt_bucketing_fn(self): else: return self._bucketize_2d_prompt + def borrow_block(self, block_ids_key): + # Borrow a block from the free block pool + for block_id, req_hash in self.borrowed_blocks_mapping_fwd.items(): + # Take first free block + if len(req_hash) == 0: + self.borrowed_blocks_mapping_fwd[block_id] = block_ids_key + self.borrowed_blocks_mapping_bwd[tuple(block_ids_key)] = block_id + return block_id + + def free_borrowed_block(self, block_ids): + # Free a borrowed block + block_ids_hash = block_ids[:-1] + new_block = block_ids[-1] + borrowed_block_id = self.borrowed_blocks_mapping_bwd[tuple(block_ids_hash)] + self.borrowed_blocks_mapping_fwd[borrowed_block_id] = [] + del self.borrowed_blocks_mapping_bwd[tuple(block_ids_hash)] + self.attn_backend.copy_blocks(self.kv_caches, torch.tensor([[borrowed_block_id * self.block_size, + new_block * self.block_size]], + dtype=torch.long, + device='hpu')) + + def _can_merge_prefill_contents(self, lhs, rhs): combined_num_tokens = lhs.get_num_tokens() + rhs.get_num_tokens() bucketing_fn = self._get_prompt_bucketing_fn() @@ -1564,16 +1588,15 @@ def _prepare_decode_inputs(self, num_prefills, num_decodes, # BLOCK_TABLE [batch, max_num_blocks_per_req] context_lens = [] - # Handle regular decodes for i in range(original_num_decodes): - 
context_lens.append(self.input_batch.num_computed_tokens_cpu[i]) + context_lens.append(self.input_batch.num_computed_tokens_cpu[i] + int(self.use_lookahead_decoding)) # Handle lookahead decodes if self.use_lookahead_decoding: # For lookahead decodes use the prompt length + 1 (the newly generated token) for i in range(original_num_decodes, num_decodes): - context_lens.append(self.input_batch.num_prompt_tokens[i] + 1) + context_lens.append(self.input_batch.num_prompt_tokens[i]) context_lens = np.array(context_lens) @@ -1591,6 +1614,12 @@ def _prepare_decode_inputs(self, num_prefills, num_decodes, block_tables_list = [] for i, n in enumerate(num_blocks): seq_block_table = block_table_cpu_tensor[i, :n].tolist() + if self.use_lookahead_decoding: + if seq_block_table[-1] == 0: + seq_block_table[-1] = self.borrow_block(seq_block_table[:-1]) + block_table_cpu_tensor[i, n-1] = seq_block_table[-1] + elif tuple(seq_block_table[:-1]) in self.borrowed_blocks_mapping_bwd.keys(): + self.free_borrowed_block(seq_block_table) assert len(seq_block_table) == n block_tables_list.append(seq_block_table) @@ -1720,10 +1749,10 @@ def _prepare_decode_inputs(self, num_prefills, num_decodes, )) def update_lookahead_decode_inputs(self, decode_data, num_prefills, num_decodes, - prefill_sampled_tokens, prefill_sampled_requests): + prefill_sampled_tokens, prefill_sampled_requests, pd_info): """ Update decode_data for lookahead decoding to use newly generated tokens from prefill phase - instead of stale prompt tokens as inputs. + and lookahead stored tokens for originally scheduled decodes """ # Get the original decode count (before lookahead decodes) @@ -1749,11 +1778,19 @@ def update_lookahead_decode_inputs(self, decode_data, num_prefills, num_decodes, # Update position to point to the correct location (after the generated token) prompt_len = self.input_batch.num_prompt_tokens[batch_idx] decode_data.position_ids[i, 0] = torch.tensor(prompt_len, dtype=torch.int, device=self.device) + # We need to update self.requests for preapre_sampling here + if not self.is_chunked_prefill_dummy_output_token(req_id, + prefill_sampled_requests, + pd_info.prompt_req_ids): + self.requests[req_id].output_token_ids.append(new_token.item()) # Replace tokens in regular decodes with lookahead stored tokens for i in range(0, original_num_decodes): req_id = self.input_batch.req_ids[i] - decode_data.token_ids[i] = self.lookahead_tokens.get(req_id, 0)[0] + token = self.lookahead_tokens.get(req_id, 0)[0] + decode_data.token_ids[i] = token + # We need to update self.requests for preapre_sampling here + self.requests[req_id].output_token_ids.append(token.item()) return DecodeInputData( num_decodes=decode_data.num_decodes, @@ -2205,6 +2242,7 @@ def execute_model( batch_changed, req_id, pad_to=logits_device.shape[0]) + set_random_seed(self.model_config.seed) sampler_output = self.sampler( logits=logits_device, sampling_metadata=sampling_metadata) @@ -2233,7 +2271,7 @@ def execute_model( if num_decodes > 0: if self.use_lookahead_decoding: decode_data = self.update_lookahead_decode_inputs(decode_data, num_prefills, num_decodes, - prefill_sampled_token_ids, prefill_sampled_requests) + prefill_sampled_token_ids, prefill_sampled_requests, pd_info) self.event_start = self.profiler.get_timestamp_us() self.profiler.start("internal", "decode") assert decode_data is not None @@ -2257,6 +2295,7 @@ def execute_model( batch_changed, pd_info.decode_req_ids, pad_to=logits_device.shape[0]) + set_random_seed(self.model_config.seed) sampler_output = self.sampler( 
logits=logits_device, sampling_metadata=sampling_metadata) if self.use_lookahead_decoding: @@ -2362,7 +2401,7 @@ def execute_model( ######### UPDATE REQUEST STATE WITH GENERATED TOKENS ######### - for req_id in self.input_batch.req_ids[:num_reqs]: + for n, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): req_state = self.requests[req_id] i = self.input_batch.req_id_to_index[req_id] seq_len = (req_state.num_computed_tokens + @@ -2372,7 +2411,13 @@ def execute_model( self.input_batch.token_ids_cpu[i, seq_len:seq_len + num_tokens] = token_ids self.input_batch.num_tokens[i] += len(token_ids) - req_state.output_token_ids.extend(token_ids) + # With lookahead decoding output token ids for decodes were already updated in + # update_lookahead_decode_inputs() + if self.use_lookahead_decoding: + if n > num_decodes - 1: + req_state.output_token_ids.extend(token_ids) + else: + req_state.output_token_ids.extend(token_ids) # NOTE(chendi): enable cache based on PR(#20291) # Cache the sampled tokens in the model runner, so that the scheduler @@ -2402,7 +2447,6 @@ def execute_model( self.input_batch.num_tokens[req_idx] = end_idx req_id = self.input_batch.req_ids[req_idx] req_state = self.requests[req_id] - req_state.output_token_ids.extend(sampled_ids) ################## RETURN ################## # Create output. # Only return the originally scheduled requests, not the lookahead requests @@ -2933,6 +2977,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: self._PAD_SLOT_ID = num_blocks * self.block_size htorch.hpu.synchronize() + if self.use_lookahead_decoding: + self.borrowed_blocks_mapping_fwd = {block_id : [] for block_id in range(num_blocks-(self.max_num_seqs+2), num_blocks-2)} + def get_supported_generation_tasks(self) -> list[GenerationTask]: model = self.get_model() diff --git a/vllm_gaudi/v1/worker/hpu_worker.py b/vllm_gaudi/v1/worker/hpu_worker.py index d5b14a3a..90d9ff21 100644 --- a/vllm_gaudi/v1/worker/hpu_worker.py +++ b/vllm_gaudi/v1/worker/hpu_worker.py @@ -221,20 +221,23 @@ def determine_available_memory(self) -> int: self.model_runner.mem_margin = hpu_memory_margin cache_size_bytes = available_hpu_memory * graph_headroom graph_headroom_bytes = available_hpu_memory * (1 - graph_headroom) - dummy_block_headroom = single_kv_block_size_bytes + num_dummy_blocks = 1 + if self.model_runner.use_lookahead_decoding: + num_dummy_blocks += self.scheduler_config.max_num_seqs + dummy_blocks_headroom = single_kv_block_size_bytes * num_dummy_blocks msg = ( f"Free device memory: {format_bytes(free_hpu_memory)}, " f"{format_bytes(available_hpu_memory)} usable " f"(gpu_memory_utilization={self.cache_config.gpu_memory_utilization})," f" {format_bytes(graph_headroom_bytes)} reserved for HPUGraphs " f"(VLLM_GRAPH_RESERVED_MEM={graph_reserved_mem}), " - f"{format_bytes(dummy_block_headroom)} reserved for KV cache dummy " - f"block {format_bytes(cache_size_bytes-dummy_block_headroom)} " + f"{format_bytes(dummy_blocks_headroom)} reserved for KV cache dummy blocks" + f"block {format_bytes(cache_size_bytes-dummy_blocks_headroom)} " "reserved for usable KV cache") logger.info(msg) gc.collect() - return cache_size_bytes - dummy_block_headroom + return cache_size_bytes - dummy_blocks_headroom def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: From b45d758ec535d6449fc5bea3a58f36431d0f5a54 Mon Sep 17 00:00:00 2001 From: Jan Kaniecki Date: Mon, 25 Aug 2025 21:33:09 +0200 Subject: [PATCH 12/15] Update hpu_model_runner.py --- vllm_gaudi/v1/worker/hpu_model_runner.py 
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py
index be373aac..09b242b9 100644
--- a/vllm_gaudi/v1/worker/hpu_model_runner.py
+++ b/vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 import collections
 import contextlib
-import copy
 import functools
 import itertools
 import math
@@ -1330,7 +1329,6 @@ def free_borrowed_block(self, block_ids):
                                              dtype=torch.long,
                                              device='hpu'))
 
-
     def _can_merge_prefill_contents(self, lhs, rhs):
         combined_num_tokens = lhs.get_num_tokens() + rhs.get_num_tokens()
         bucketing_fn = self._get_prompt_bucketing_fn()
@@ -2134,6 +2132,7 @@ def execute_model(
         # Transfer [tokD0, tokD1, tokD2, 0, tokP0, tokP1, tokP2, 0] to CPU
         # On CPU, sanitize [tokD0, tokD1, tokD2, 0, tokP0, tokP1, tokP2, 0] -> [tokD0, tokD1, tokD2, tokP0, tokP1, tokP2] # noqa
         # Return [tokD0, tokD1, tokD2, tokP0, tokP1, tokP2]
+
         if self.defragmenter.enabled and self.kv_caches:
             new = {
                 req.req_id: flatten(req.block_ids)
@@ -2150,6 +2149,7 @@ def execute_model(
             self.defragmenter.update_state(new | cached,
                                            scheduler_output.finished_req_ids)
             self.defragmenter.defragment()
+
         batch_changed = self._update_states(scheduler_output)
         if not scheduler_output.total_num_scheduled_tokens:
             # Return empty ModelRunnerOuptut if there's no work to do.
@@ -2171,8 +2171,8 @@ def execute_model(
         # later.
         prefill_sampled_token_ids = []
         prefill_sampled_requests = []
-        decode_sampled_requests = []
         decode_sampled_token_ids = []
+        decode_sampled_requests = []
         # NOTE(tianmu-li): For structured output, combine logits before
         # postprocessing. Should it be done for all requests?
         structured_output = False
@@ -2180,7 +2180,7 @@ def execute_model(
             logits_prompt = []
             logits_decode = []
             structured_output = True
-        
+
         ######################### PREFILLS #########################
         if num_prefills > 0:
             htorch.core.mark_step()
@@ -2399,7 +2399,6 @@ def execute_model(
         # NOTE(kzawora): idk what happens if part of batch doesn't have logprobs
 
 
-
         ######### UPDATE REQUEST STATE WITH GENERATED TOKENS #########
         for n, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
             req_state = self.requests[req_id]

From b4f08521d133535ef969429e33f246de9336ad02 Mon Sep 17 00:00:00 2001
From: Jan Kaniecki
Date: Tue, 26 Aug 2025 11:00:24 +0300
Subject: [PATCH 13/15] Disable lookahead warmup, remove .item() calls

---
 vllm_gaudi/v1/worker/hpu_model_runner.py | 58 +++++++++++++++---------
 1 file changed, 36 insertions(+), 22 deletions(-)

diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py
index 09b242b9..a1388264 100644
--- a/vllm_gaudi/v1/worker/hpu_model_runner.py
+++ b/vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -684,6 +684,7 @@ def __init__(
         # Lookahead decoding
         self.use_lookahead_decoding = get_config().lookahead_decoding
         # Storage for lookahead tokens that are computed but not yet scheduled
+        self.lookahead_tokens_tensors: dict = {}
         self.lookahead_tokens: dict = {}
         self.borrowed_blocks_mapping_fwd = {}
         self.borrowed_blocks_mapping_bwd = {}
@@ -746,6 +747,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
             if self.use_lookahead_decoding:
                 # Pop stored lookahead tokens for finished requests - dummy tokens at the end
                 self.lookahead_tokens.pop(req_id, None)
+                self.lookahead_tokens_tensors.pop(req_id, None)
         # Remove the finished requests from the persistent batch.
         # NOTE(woosuk): There could be an edge case where finished_req_ids and
@@ -1747,7 +1749,7 @@ def _prepare_decode_inputs(self, num_prefills, num_decodes,
         ))
 
    def update_lookahead_decode_inputs(self, decode_data, num_prefills, num_decodes,
-                                       prefill_sampled_tokens, prefill_sampled_requests, pd_info):
+                                       prefill_sampled_tokens, prefill_sampled_requests, pd_info, sampling_preparation):
        """
        Update decode_data for lookahead decoding to use newly generated tokens from prefill phase
        and lookahead stored tokens for originally scheduled decodes
@@ -1772,23 +1774,30 @@ def update_lookahead_decode_inputs(self, decode_data, num_prefills, num_decodes,
            if req_id in req_to_token:
                # Update the decode input to use the newly generated token
                new_token = req_to_token[req_id]
-                decode_data.token_ids[i] = new_token
-                # Update position to point to the correct location (after the generated token)
-                prompt_len = self.input_batch.num_prompt_tokens[batch_idx]
-                decode_data.position_ids[i, 0] = torch.tensor(prompt_len, dtype=torch.int, device=self.device)
-                # We need to update self.requests for preapre_sampling here
-                if not self.is_chunked_prefill_dummy_output_token(req_id,
-                                                                   prefill_sampled_requests,
-                                                                   pd_info.prompt_req_ids):
-                    self.requests[req_id].output_token_ids.append(new_token.item())
+                # We need to update self.requests for preapre_sampling here
+                if sampling_preparation:
+                    if not self.is_chunked_prefill_dummy_output_token(req_id,
+                                                                       prefill_sampled_requests,
+                                                                       pd_info.prompt_req_ids):
+                        self.requests[req_id].output_token_ids.append(new_token.cpu().tolist())
+                else:
+                    decode_data.token_ids[i] = new_token
+                    # Update position to point to the correct location (after the generated token)
+                    prompt_len = self.input_batch.num_prompt_tokens[batch_idx]
+                    decode_data.position_ids[i, 0] = torch.tensor(prompt_len, dtype=torch.int, device=self.device)
 
        # Replace tokens in regular decodes with lookahead stored tokens
        for i in range(0, original_num_decodes):
            req_id = self.input_batch.req_ids[i]
-            token = self.lookahead_tokens.get(req_id, 0)[0]
-            decode_data.token_ids[i] = token
            # We need to update self.requests for preapre_sampling here
-            self.requests[req_id].output_token_ids.append(token.item())
+            if sampling_preparation:
+                token = self.lookahead_tokens_tensors.get(req_id, 0)[0].item()
+                if not req_id in self.lookahead_tokens:
+                    self.lookahead_tokens[req_id] = []
+                self.lookahead_tokens[req_id].append(token)
+                self.requests[req_id].output_token_ids.append(token)
+            else:
+                decode_data.token_ids[i] = self.lookahead_tokens_tensors.get(req_id, 0)[0]
 
        return DecodeInputData(
            num_decodes=decode_data.num_decodes,
@@ -1796,7 +1805,7 @@ def update_lookahead_decode_inputs(self, decode_data, num_prefills, num_decodes,
            token_ids=decode_data.token_ids.clone(),
            position_ids=decode_data.position_ids.clone(),
            logits_indices=decode_data.logits_indices,
            attn_metadata=decode_data.attn_metadata,
-        )
+        ) if not sampling_preparation else None
 
    def _prepare_inputs(
        self,
@@ -2242,7 +2251,6 @@ def execute_model(
                        batch_changed, req_id,
                        pad_to=logits_device.shape[0])
-                    set_random_seed(self.model_config.seed)
                    sampler_output = self.sampler(
                        logits=logits_device, sampling_metadata=sampling_metadata)
 
@@ -2271,7 +2279,7 @@ def execute_model(
        if num_decodes > 0:
            if self.use_lookahead_decoding:
                decode_data = self.update_lookahead_decode_inputs(decode_data, num_prefills, num_decodes,
-                                                                  prefill_sampled_token_ids, prefill_sampled_requests, pd_info)
+                                                                  prefill_sampled_token_ids, prefill_sampled_requests, pd_info, False)
            self.event_start = self.profiler.get_timestamp_us()
            self.profiler.start("internal", "decode")
"decode") assert decode_data is not None @@ -2291,11 +2299,14 @@ def execute_model( self.input_batch.req_ids[:num_decodes]) else: with self.profiler.record_event('internal', "sampler"): + if self.use_lookahead_decoding: + _ = self.update_lookahead_decode_inputs(decode_data, num_prefills, num_decodes, + prefill_sampled_token_ids, prefill_sampled_requests, pd_info, True) sampling_metadata = self._prepare_sampling( batch_changed, pd_info.decode_req_ids, pad_to=logits_device.shape[0]) - set_random_seed(self.model_config.seed) + torch.manual_seed(self.model_config.seed) sampler_output = self.sampler( logits=logits_device, sampling_metadata=sampling_metadata) if self.use_lookahead_decoding: @@ -2314,9 +2325,9 @@ def execute_model( if not self.is_chunked_prefill_dummy_output_token(req_id, prefill_sampled_requests, pd_info.prompt_req_ids): - if not req_id in self.lookahead_tokens: - self.lookahead_tokens[req_id] = [] - self.lookahead_tokens[req_id].append(token_ids) + if not req_id in self.lookahead_tokens_tensors: + self.lookahead_tokens_tensors[req_id] = [] + self.lookahead_tokens_tensors[req_id].append(token_ids) if self.is_driver_worker and self.profiler.enabled: # Stop recording 'execute_model' event @@ -2372,7 +2383,8 @@ def execute_model( for req_id in decode_sampled_requests: req_index = self.input_batch.req_id_to_index[req_id] - tok_id = self.lookahead_tokens[req_id].pop(0).item() + tok_id = self.lookahead_tokens[req_id].pop(0) + _ = self.lookahead_tokens_tensors[req_id].pop(0) postprocessed_sampled_token_ids[req_index].append(tok_id) for tok_id, req_id in zip(prefill_sampled_token_ids, @@ -2410,7 +2422,7 @@ def execute_model( self.input_batch.token_ids_cpu[i, seq_len:seq_len + num_tokens] = token_ids self.input_batch.num_tokens[i] += len(token_ids) - # With lookahead decoding output token ids for decodes were already updated in + # With lookahead decoding output token ids for decodes were already updated in # update_lookahead_decode_inputs() if self.use_lookahead_decoding: if n > num_decodes - 1: @@ -2821,6 +2833,7 @@ def warmup_model(self) -> None: logger.info("Skipping warmup...") return + self.use_lookahead_decoding = False self.profiler.start('internal', 'warmup') start_mem = HabanaMemoryProfiler.current_device_memory_usage() start_time = time.perf_counter() @@ -2863,6 +2876,7 @@ def warmup_model(self) -> None: end_time = time.perf_counter() end_mem = HabanaMemoryProfiler.current_device_memory_usage() + self.use_lookahead_decoding = get_config().lookahead_decoding if os.getenv('VLLM_FULL_WARMUP', 'false').strip().lower() in ("1", "true"): # Since the model is warmed up for all possible tensor sizes, From 1ab4f86c9cad28f9bfcc37b08d621695eb963b6d Mon Sep 17 00:00:00 2001 From: Jan Kaniecki Date: Tue, 26 Aug 2025 16:52:04 +0300 Subject: [PATCH 14/15] Fix accuracy issue with block borrowing --- vllm_gaudi/v1/worker/hpu_model_runner.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index a1388264..33c7f910 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -31,7 +31,6 @@ from vllm.attention.selector import get_attn_backend from vllm.config import (VllmConfig, update_config) from vllm.forward_context import set_forward_context -from vllm.model_executor import set_random_seed from vllm.model_executor.layers.fused_moe.layer import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.sampler 
@@ -1615,7 +1614,7 @@ def _prepare_decode_inputs(self, num_prefills, num_decodes,
         for i, n in enumerate(num_blocks):
             seq_block_table = block_table_cpu_tensor[i, :n].tolist()
             if self.use_lookahead_decoding:
-                if seq_block_table[-1] == 0:
+                if (context_lens[i] + 1) % self.block_size == 1:
                     seq_block_table[-1] = self.borrow_block(seq_block_table[:-1])
                     block_table_cpu_tensor[i, n-1] = seq_block_table[-1]
                 elif tuple(seq_block_table[:-1]) in self.borrowed_blocks_mapping_bwd.keys():
@@ -2306,7 +2305,6 @@ def execute_model(
                            batch_changed, pd_info.decode_req_ids,
                            pad_to=logits_device.shape[0])
-                        torch.manual_seed(self.model_config.seed)
                        sampler_output = self.sampler(
                            logits=logits_device, sampling_metadata=sampling_metadata)
                        if self.use_lookahead_decoding:

From aaf60d7445c8e11971f8fcb7028c5cdbc7677067 Mon Sep 17 00:00:00 2001
From: Jan Kaniecki
Date: Wed, 27 Aug 2025 16:14:57 +0300
Subject: [PATCH 15/15] Remove recompilations part 1

---
 vllm_gaudi/v1/worker/hpu_model_runner.py | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py
index bf04fed6..9f7f097f 100644
--- a/vllm_gaudi/v1/worker/hpu_model_runner.py
+++ b/vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -1774,9 +1774,6 @@ def update_lookahead_decode_inputs(self, decode_data, num_prefills, num_decodes,
                        self.requests[req_id].output_token_ids.append(new_token.cpu().tolist())
                else:
                    decode_data.token_ids[i] = new_token
-                    # Update position to point to the correct location (after the generated token)
-                    prompt_len = self.input_batch.num_prompt_tokens[batch_idx]
-                    decode_data.position_ids[i, 0] = torch.tensor(prompt_len, dtype=torch.int, device=self.device)
 
        # Replace tokens in regular decodes with lookahead stored tokens
        for i in range(0, original_num_decodes):
@@ -1794,7 +1791,7 @@ def update_lookahead_decode_inputs(self, decode_data, num_prefills, num_decodes,
        return DecodeInputData(
            num_decodes=decode_data.num_decodes,
            token_ids=decode_data.token_ids.clone(),
-            position_ids=decode_data.position_ids.clone(),
+            position_ids=decode_data.position_ids,
            logits_indices=decode_data.logits_indices,
            attn_metadata=decode_data.attn_metadata,
        ) if not sampling_preparation else None
@@ -2310,12 +2307,12 @@ def execute_model(
                            self.input_batch.req_ids[:original_num_decodes])
                htorch.core.mark_step()
                if self.use_lookahead_decoding:
-                    for req_id, token_ids in zip(
+                    for i, (req_id, token_ids) in enumerate(zip(
                            pd_info.decode_req_ids,
-                            decode_sampled_token_ids[:num_decodes].split(1)):
+                            decode_sampled_token_ids.split(1))):
                        if not self.is_chunked_prefill_dummy_output_token(req_id,
                                                                           prefill_sampled_requests,
-                                                                           pd_info.prompt_req_ids):
+                                                                           pd_info.prompt_req_ids) and i < num_decodes:
                            if not req_id in self.lookahead_tokens_tensors:
                                self.lookahead_tokens_tensors[req_id] = []
                            self.lookahead_tokens_tensors[req_id].append(token_ids)
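
The series touches many places at once, so a compact illustration of the core idea may help reviewers. The sketch below is not part of the patches and uses only hypothetical names (LookaheadStore, sample_next_token, run_engine_steps); it models just the token bookkeeping that the patches implement around lookahead_tokens / lookahead_tokens_tensors: each engine step computes one extra (lookahead) token per running decode and stashes it, the next step consumes the stashed token before sampling again, and finished requests drop their stash. Scheduling, KV-cache block borrowing and HPU tensor handling are deliberately omitted.

# Minimal, self-contained sketch of the lookahead-token flow
# (hypothetical names, not the vllm_gaudi API).
from __future__ import annotations

from collections import defaultdict


class LookaheadStore:
    """Per-request stash of tokens computed ahead of the scheduler."""

    def __init__(self) -> None:
        # req_id -> tokens computed but not yet handed back to the scheduler
        self.pending: dict[str, list[int]] = defaultdict(list)

    def stash(self, req_id: str, token: int) -> None:
        # Analogous to storing a token after the decode pass.
        self.pending[req_id].append(token)

    def drain(self, req_id: str) -> int | None:
        # Analogous to popping a stored token when the request is scheduled again.
        return self.pending[req_id].pop(0) if self.pending[req_id] else None

    def finish(self, req_id: str) -> None:
        # Analogous to dropping the stash for finished requests.
        self.pending.pop(req_id, None)


def sample_next_token(req_id: str, step: int) -> int:
    # Stand-in for the real sampler: deterministic pseudo token ids.
    return (hash((req_id, step)) & 0x7FFFFFFF) % 32000


def run_engine_steps(req_ids: list[str], num_steps: int) -> dict[str, list[int]]:
    store = LookaheadStore()
    outputs: dict[str, list[int]] = {r: [] for r in req_ids}
    for step in range(num_steps):
        for req_id in req_ids:
            # 1) Reuse the token the previous step already computed, if any.
            stored = store.drain(req_id)
            if stored is not None:
                outputs[req_id].append(stored)
            # 2) Run this step's decode and keep one extra token for later,
            #    so the next step does not have to wait for it.
            store.stash(req_id, sample_next_token(req_id, step))
    for req_id in req_ids:
        store.finish(req_id)
    return outputs


if __name__ == "__main__":
    # Two toy requests, four steps: each step consumes last step's lookahead
    # token and produces one new token ahead of schedule.
    print(run_engine_steps(["req-0", "req-1"], num_steps=4))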