diff --git a/docs/features/supported_features.md b/docs/features/supported_features.md
index 19fb4a2a..121cb5ea 100644
--- a/docs/features/supported_features.md
+++ b/docs/features/supported_features.md
@@ -27,5 +27,4 @@ title: Supported Features
 | Multinode support | vLLM HPU backend supports distributed, multiple-node inference with Ray. | |
 | vLLM v1 architecture (early release) | V1 architecture is now available for the HPU backend and will gradually be enabled for every use case we plan to support. | [Documentation](https://docs.vllm.ai/en/latest/serving/distributed_serving.html) |
 | Guided decode | vLLM HPU supports a guided decoding backend for generating structured outputs. | [Documentation](https://docs.vllm.ai/en/latest/features/structured_outputs.html) |
-| Delayed Sampling (experimental) | vLLM HPU supports delayed sampling scheduling for asynchronous execution, enabled by `VLLM_DELAYED_SAMPLING=true` environment variable. | N/A |
 | Exponential bucketing | vLLM HPU supports exponential bucketing spacing instead of linear to automate configuration of bucketing mechanism, enabled by default. It can be disabled via `VLLM_EXPONENTIAL_BUCKETING=false` environment variable. | N/A |
diff --git a/vllm_gaudi/envs.py b/vllm_gaudi/envs.py
index b2a82b3a..051e79ad 100644
--- a/vllm_gaudi/envs.py
+++ b/vllm_gaudi/envs.py
@@ -5,7 +5,6 @@ if TYPE_CHECKING:
     VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True
-    VLLM_HPU_USE_DELAYED_SAMPLING: bool = False
     VLLM_HPU_FORCE_CHANNEL_FP8: bool = True
 
 # The begin-* and end* here are used by the documentation generator
@@ -20,12 +19,6 @@
     lambda: os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() in
     ("1", "true"),
 
-    # Use delayed sampling for HPU to reduce host cpu overhead
-    # between each step.
-    "VLLM_HPU_USE_DELAYED_SAMPLING":
-    lambda: os.environ.get("VLLM_DELAYED_SAMPLING", "false").lower() in
-    ("1", "true"),
-
     # Convert block fp8 to channel fp8 for HPU
     "VLLM_HPU_FORCE_CHANNEL_FP8":
     lambda: os.environ.get("VLLM_HPU_FORCE_CHANNEL_FP8", "true").lower() in
diff --git a/vllm_gaudi/extension/features.py b/vllm_gaudi/extension/features.py
index 7a6eb45d..2a1a399e 100644
--- a/vllm_gaudi/extension/features.py
+++ b/vllm_gaudi/extension/features.py
@@ -71,14 +71,15 @@ def get_features():
         Value('skip_warmup', False),
         Value('merged_prefill', False),
         Value('use_contiguous_pa', Disabled('prefix_caching'), env_var='VLLM_CONTIGUOUS_PA'),
-        Value('use_delayed_sampling', Engine('v0'), env_var='VLLM_DELAYED_SAMPLING'),
         Value('use_bucketing', True, env_var='VLLM_ENABLE_BUCKETING'),
         Value('exponential_bucketing', True),
         Value('linear_bucketing', True),
+        Value('lookahead_decoding', False, env_var='VLLM_USE_LOOKAHEAD_DECODING'),
         ValueFromList('bucketing_strategy', bucketing_strategies),
         Value('defrag', False),
         Value('regional_compilation', True, env_var='VLLM_T_COMPILE_REGIONAL_COMPILATION', env_var_type=boolean),
         Value('dynamic_shapes_compilation', True, env_var='VLLM_T_COMPILE_DYNAMIC_SHAPES', env_var_type=boolean),
         Value('fullgraph_compilation', False, env_var='VLLM_T_COMPILE_FULLGRAPH', env_var_type=boolean),
+
     ]
     return split_values_and_flags(features)
diff --git a/vllm_gaudi/v1/worker/hpu_input_batch.py b/vllm_gaudi/v1/worker/hpu_input_batch.py
index 14922a91..8666dd30 100644
--- a/vllm_gaudi/v1/worker/hpu_input_batch.py
+++ b/vllm_gaudi/v1/worker/hpu_input_batch.py
@@ -605,14 +605,7 @@ def make_selective_sampling_metadata(
             for req_id, _ in req_id_output_token_ids
         ]
         prompt_token_ids = None
-        if not skip_copy:
-            self.temperature[req_indices].copy_(
-                self.temperature_cpu_tensor[req_indices], non_blocking=True)
-            self.top_p[req_indices].copy_(self.top_p_cpu_tensor[req_indices],
-                                          non_blocking=True)
-            self.top_k[req_indices].copy_(self.top_k_cpu_tensor[req_indices],
-                                          non_blocking=True)
-            if not self.no_penalties:
+        if not skip_copy and not self.no_penalties:
             # Since syncing these tensors is expensive only copy them
             # if necessary i.e. if there are requests which require
             # penalties to be applied during sampling.
diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py
index da1cab1a..9f7f097f 100644
--- a/vllm_gaudi/v1/worker/hpu_model_runner.py
+++ b/vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -667,6 +667,13 @@ def __init__(
         # High-level profiler
         self.profiler = HabanaHighLevelProfiler()
         self.profiler_counter_helper = HabanaProfilerCounterHelper()
+        # Lookahead decoding
+        self.use_lookahead_decoding = get_config().lookahead_decoding
+        # Storage for lookahead tokens that are computed but not yet scheduled
+        self.lookahead_tokens_tensors: dict = {}
+        self.lookahead_tokens: dict = {}
+        self.borrowed_blocks_mapping_fwd = {}
+        self.borrowed_blocks_mapping_bwd = {}
         self.defragmenter = OnlineDefragmenter()
         self.debug_fwd = init_debug_logger('fwd')
 
@@ -723,6 +730,10 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
         # Remove finished requests from the cached states.
         for req_id in scheduler_output.finished_req_ids:
             self.requests.pop(req_id, None)
+            if self.use_lookahead_decoding:
+                # Pop stored lookahead tokens for finished requests - dummy tokens at the end
+                self.lookahead_tokens.pop(req_id, None)
+                self.lookahead_tokens_tensors.pop(req_id, None)
 
         # Remove the finished requests from the persistent batch.
         # NOTE(woosuk): There could be an edge case where finished_req_ids and
@@ -1126,6 +1137,7 @@ def _get_prompts_and_decodes(
         # Traverse prompts
         prompt_req_ids = []
         prompt_scheduled_tokens = []
+        lookahead_decode_req_ids = []
         for i in range(len(decode_req_ids), num_reqs):
             req_id = self.input_batch.req_ids[i]
             assert req_id is not None
@@ -1143,8 +1155,14 @@ def _get_prompts_and_decodes(
             prompt_req_ids.append(req_id)
             prompt_scheduled_tokens.append(num_scheduled_tokens)
-
-        return PromptDecodeInfo(prompt_req_ids, decode_req_ids,
+
+            # Schedule lookahead decode for this prompt
+            if self.use_lookahead_decoding:
+                lookahead_decode_req_ids.append(req_id)
+
+        all_decode_req_ids = decode_req_ids + lookahead_decode_req_ids
+
+        return PromptDecodeInfo(prompt_req_ids, all_decode_req_ids,
                                 prompt_scheduled_tokens)
 
     def _prepare_sampling(self,
@@ -1284,6 +1302,27 @@ def _get_prompt_bucketing_fn(self):
         else:
             return self._bucketize_2d_prompt
 
+    def borrow_block(self, block_ids_key):
+        # Borrow a block from the free block pool
+        for block_id, req_hash in self.borrowed_blocks_mapping_fwd.items():
+            # Take first free block
+            if len(req_hash) == 0:
+                self.borrowed_blocks_mapping_fwd[block_id] = block_ids_key
+                self.borrowed_blocks_mapping_bwd[tuple(block_ids_key)] = block_id
+                return block_id
+
+    def free_borrowed_block(self, block_ids):
+        # Free a borrowed block
+        block_ids_hash = block_ids[:-1]
+        new_block = block_ids[-1]
+        borrowed_block_id = self.borrowed_blocks_mapping_bwd[tuple(block_ids_hash)]
+        self.borrowed_blocks_mapping_fwd[borrowed_block_id] = []
+        del self.borrowed_blocks_mapping_bwd[tuple(block_ids_hash)]
+        self.attn_backend.copy_blocks(self.kv_caches, torch.tensor([[borrowed_block_id * self.block_size,
+                                                                     new_block * self.block_size]],
+                                                                   dtype=torch.long,
+                                                                   device='hpu'))
+
     def _can_merge_prefill_contents(self, lhs, rhs):
         combined_num_tokens = lhs.get_num_tokens() + rhs.get_num_tokens()
         bucketing_fn = self._get_prompt_bucketing_fn()
@@ -1301,15 +1340,26 @@ def _extract_prefill_batch_contents(self, num_prefills, num_decodes,
                                         num_scheduled_tokens):
         # DECODES are the first num_decodes REQUESTS.
         # PREFILLS are the next num_reqs - num_decodes REQUESTS.
-        num_reqs = num_prefills + num_decodes
+        if self.use_lookahead_decoding:
+            # When lookahead decoding is used, num_decodes includes lookahead decodes
+            # but we need to find where actual prefills start in the batch
+            actual_num_decodes = num_decodes - num_prefills
+            idx_bias = num_prefills
+        else:
+            actual_num_decodes = num_decodes
+            idx_bias = 0
         block_table_cpu_tensor = self.input_batch.block_table[
             0].get_cpu_tensor()
         all_batch_contents = [BatchContents()]
-        for batch_idx in range(num_decodes, num_reqs):
+        # Prefills start after the actual decode requests in the input_batch
+        prefill_start = actual_num_decodes
+        prefill_end = prefill_start + num_prefills
+
+        for batch_idx in range(prefill_start, prefill_end):
             req_id = self.input_batch.req_ids[batch_idx]
             context_len = self.input_batch.num_computed_tokens_cpu[batch_idx]
-            query_len = num_scheduled_tokens[batch_idx]
+            query_len = num_scheduled_tokens[batch_idx + idx_bias]
 
             token_ids = self.input_batch.token_ids_cpu[
                 batch_idx, context_len:context_len + query_len].tolist()
@@ -1509,7 +1559,7 @@ def _prepare_prefill_inputs(
         merge_contents(all_batches[0], *all_batches[1:])
         return all_batches[0]
 
-    def _prepare_decode_inputs(self, num_decodes,
+    def _prepare_decode_inputs(self, num_prefills, num_decodes,
                                num_scheduled_tokens) -> DecodeInputData:
         # Decodes run as one single padded batch with shape [batch, 1]
         #
@@ -1522,8 +1572,25 @@ def _prepare_decode_inputs(self, num_decodes,
             0].get_cpu_tensor()
         if num_decodes == 0:
             return DecodeInputData(num_decodes=0)
+
+        # Calculate original decodes count (before lookahead)
+        original_num_decodes = num_decodes
+        if self.use_lookahead_decoding:
+            original_num_decodes = num_decodes - num_prefills
+
         # BLOCK_TABLE [batch, max_num_blocks_per_req]
-        context_lens = self.input_batch.num_computed_tokens_cpu[:num_decodes]
+        context_lens = []
+        # Handle regular decodes
+        for i in range(original_num_decodes):
+            context_lens.append(self.input_batch.num_computed_tokens_cpu[i] + int(self.use_lookahead_decoding))
+
+        # Handle lookahead decodes
+        if self.use_lookahead_decoding:
+            # For lookahead decodes use the prompt length + 1 (the newly generated token)
+            for i in range(original_num_decodes, num_decodes):
+                context_lens.append(self.input_batch.num_prompt_tokens[i])
+
+        context_lens = np.array(context_lens)
 
         # NOTE(kzawora): the +1 is what causes this entire thing to work,
         # as in the paged attention, we don't fetch just the context from cache,
@@ -1539,17 +1606,32 @@ def _prepare_decode_inputs(self, num_decodes,
         block_tables_list = []
         for i, n in enumerate(num_blocks):
             seq_block_table = block_table_cpu_tensor[i, :n].tolist()
+            if self.use_lookahead_decoding:
+                if (context_lens[i] + 1) % self.block_size == 1:
+                    seq_block_table[-1] = self.borrow_block(seq_block_table[:-1])
+                    block_table_cpu_tensor[i, n - 1] = seq_block_table[-1]
+                elif tuple(seq_block_table[:-1]) in self.borrowed_blocks_mapping_bwd.keys():
+                    self.free_borrowed_block(seq_block_table)
             assert len(seq_block_table) == n
             block_tables_list.append(seq_block_table)
 
         # POSITIONS. [batch, 1]
         # We slice at the end, since we use the positions for gathering.
         positions = torch.zeros((padded_batch_size, 1), dtype=torch.int32)
-        positions[:num_decodes] = torch.from_numpy(
-            self.input_batch.num_computed_tokens_cpu.reshape(-1,
-                                                             1)[:num_decodes])
-        positions = positions[:padded_batch_size]
+
+        # Handle regular decodes
+        for i in range(original_num_decodes):
+            num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
+            positions[i, 0] = num_computed_tokens + 1 if self.use_lookahead_decoding \
+                else num_computed_tokens
+
+        # Handle lookahead decodes positions
+        if self.use_lookahead_decoding:
+            for i in range(original_num_decodes, num_decodes):
+                # Position should be at the newly generated token (prompt_len)
+                positions[i, 0] = self.input_batch.num_prompt_tokens[i]
+        positions = positions[:padded_batch_size]
 
         padded_index = torch.zeros((padded_batch_size, 1), dtype=torch.int64)
         index = positions.to(torch.int64)[:num_decodes]
         padded_index[:num_decodes] = index
@@ -1658,6 +1740,62 @@ def _prepare_decode_inputs(self, num_decodes,
                 block_size=self.block_size,
             ))
 
+    def update_lookahead_decode_inputs(self, decode_data, num_prefills, num_decodes,
+                                       prefill_sampled_tokens, prefill_sampled_requests, pd_info, sampling_preparation):
+        """
+        Update decode_data for lookahead decoding: use the newly generated tokens from the prefill phase
+        and the stored lookahead tokens for the originally scheduled decodes.
+        """
+
+        # Get the original decode count (before lookahead decodes)
+        original_num_decodes = num_decodes - num_prefills
+
+        # Create a mapping from request_id to newly generated token
+        req_to_token = {}
+
+        tokens = [token for token in prefill_sampled_tokens if token.shape != torch.Size([0])]
+        for token, req_id in zip(tokens, prefill_sampled_requests):
+            req_to_token[req_id] = token
+
+        # For lookahead decodes (which start after original decodes), update their input tokens
+        for i in range(original_num_decodes, num_decodes):
+            # Map lookahead decode index to corresponding prefill request
+            prefill_idx = i - original_num_decodes
+            batch_idx = original_num_decodes + prefill_idx
+            req_id = self.input_batch.req_ids[batch_idx]
+            if req_id in req_to_token:
+                # Update the decode input to use the newly generated token
+                new_token = req_to_token[req_id]
+                # We need to update self.requests for prepare_sampling here
+                if sampling_preparation:
+                    if not self.is_chunked_prefill_dummy_output_token(req_id,
+                                                                      prefill_sampled_requests,
+                                                                      pd_info.prompt_req_ids):
+                        self.requests[req_id].output_token_ids.append(new_token.cpu().tolist())
+                else:
+                    decode_data.token_ids[i] = new_token
+
+        # Replace tokens in regular decodes with lookahead stored tokens
+        for i in range(0, original_num_decodes):
+            req_id = self.input_batch.req_ids[i]
+            # We need to update self.requests for prepare_sampling here
+            if sampling_preparation:
+                token = self.lookahead_tokens_tensors.get(req_id, 0)[0].item()
+                if req_id not in self.lookahead_tokens:
+                    self.lookahead_tokens[req_id] = []
+                self.lookahead_tokens[req_id].append(token)
+                self.requests[req_id].output_token_ids.append(token)
+            else:
+                decode_data.token_ids[i] = self.lookahead_tokens_tensors.get(req_id, 0)[0]
+
+        return DecodeInputData(
+            num_decodes=decode_data.num_decodes,
+            token_ids=decode_data.token_ids.clone(),
+            position_ids=decode_data.position_ids,
+            logits_indices=decode_data.logits_indices,
+            attn_metadata=decode_data.attn_metadata,
+        ) if not sampling_preparation else None
+
     def _prepare_inputs(
         self,
         scheduler_output: "SchedulerOutput",
@@ -1670,6 +1808,10 @@ def _prepare_inputs(
 
         num_reqs = num_prefills + num_decodes
 
+        actual_num_decodes = num_decodes
+        if self.use_lookahead_decoding:
+            actual_num_decodes = num_decodes - num_prefills
+
         # Get the number of scheduled tokens for each request.
         # TODO: The Python loop can be slow. Optimize.
         num_scheduled_tokens = []
@@ -1682,11 +1824,17 @@ def _prepare_inputs(
             num_scheduled_tokens.append(seq_num_scheduled_tokens)
             num_prompt_tokens.append(seq_num_prompt_tokens)
             # NOTE: assert that all the decodes are "decodes".
-            if idx < num_decodes:
+            if idx < actual_num_decodes:
                 assert seq_num_scheduled_tokens == 1
+
+        if self.use_lookahead_decoding:
+            # Insert scheduled tokens for lookahead decodes (always 1 token per decode)
+            for _ in range(num_prefills):
+                num_scheduled_tokens.insert(0, 1)
 
         return (self._prepare_prefill_inputs(num_prefills, num_decodes,
                                              num_scheduled_tokens),
-                self._prepare_decode_inputs(num_decodes, num_scheduled_tokens))
+                self._prepare_decode_inputs(num_prefills, num_decodes,
+                                            num_scheduled_tokens))
 
     def _seq_len(self, attn_metadata):
         return attn_metadata.slot_mapping.size(-1)
@@ -1805,7 +1953,7 @@ def _get_prompt_logprobs_dict(
 
             # Get the "target" tokens for each index. For prompt at index i,
             # the token at prompt index i+1 is the "sampled" token we want
-            # to gather the logprob for.
+            # to gather the logprobs for.
             tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits]
 
             # Compute prompt logprobs.
@@ -1834,6 +1982,13 @@ def _is_quant_with_inc(self):
         quant_config = os.getenv("QUANT_CONFIG", None) is not None
         return (self.model_config.quantization == "inc" or quant_config)
 
+    def is_chunked_prefill_dummy_output_token(
+            self,
+            req_id,
+            prefill_sampled_requests,
+            prompt_req_ids) -> bool:
+        return (req_id in prompt_req_ids) and (req_id not in prefill_sampled_requests)
+
     # Copied from vllm/v1/worker/gpu_model_runner.py
     def apply_grammar_bitmask(
         self,
@@ -1992,7 +2147,7 @@ def execute_model(
         self.defragmenter.update_state(new | cached,
                                        scheduler_output.finished_req_ids)
         self.defragmenter.defragment()
-
+
         batch_changed = self._update_states(scheduler_output)
         if not scheduler_output.total_num_scheduled_tokens:
             # Return empty ModelRunnerOuptut if there's no work to do.
@@ -2003,7 +2158,10 @@ def execute_model(
         pd_info = self._get_prompts_and_decodes(scheduler_output)
         num_decodes = len(pd_info.decode_req_ids)
         num_prefills = len(pd_info.prompt_req_ids)
-        num_reqs = num_decodes + num_prefills
+        original_num_decodes = num_decodes
+        if self.use_lookahead_decoding:
+            original_num_decodes = num_decodes - num_prefills
+        num_reqs = original_num_decodes + num_prefills
         with self.profiler.record_event('internal', 'prepare_input_tensors'):
             prefill_data, decode_data = self._prepare_inputs(
                 scheduler_output, num_prefills, num_decodes)
@@ -2108,6 +2266,9 @@ def execute_model(
         ######################### DECODES #########################
         # Decodes run as one single batch with [padded_decode_bs, 1]
         if num_decodes > 0:
+            if self.use_lookahead_decoding:
+                decode_data = self.update_lookahead_decode_inputs(decode_data, num_prefills, num_decodes,
+                                                                  prefill_sampled_token_ids, prefill_sampled_requests, pd_info, False)
             self.event_start = self.profiler.get_timestamp_us()
             self.profiler.start("internal", "decode")
             assert decode_data is not None
@@ -2127,18 +2288,35 @@ def execute_model(
                         self.input_batch.req_ids[:num_decodes])
                 else:
                     with self.profiler.record_event('internal', "sampler"):
+                        if self.use_lookahead_decoding:
+                            _ = self.update_lookahead_decode_inputs(decode_data, num_prefills, num_decodes,
+                                                                    prefill_sampled_token_ids, prefill_sampled_requests, pd_info, True)
                         sampling_metadata = self._prepare_sampling(
                             batch_changed,
                             pd_info.decode_req_ids,
                             pad_to=logits_device.shape[0])
                         sampler_output = self.sampler(
-                            logits=logits_device,
-                            sampling_metadata=sampling_metadata)
-                        decode_sampled_token_ids.append(
-                            sampler_output.sampled_token_ids.flatten())
+                            logits=logits_device, sampling_metadata=sampling_metadata)
+                        if self.use_lookahead_decoding:
+                            decode_sampled_token_ids = \
+                                sampler_output.sampled_token_ids.flatten()
+                        else:
+                            decode_sampled_token_ids.append(
+                                sampler_output.sampled_token_ids.flatten())
                         decode_sampled_requests.extend(
-                            self.input_batch.req_ids[:num_decodes])
+                            self.input_batch.req_ids[:original_num_decodes])
                 htorch.core.mark_step()
+                if self.use_lookahead_decoding:
+                    for i, (req_id, token_ids) in enumerate(zip(
+                            pd_info.decode_req_ids,
+                            decode_sampled_token_ids.split(1))):
+                        if not self.is_chunked_prefill_dummy_output_token(req_id,
+                                                                          prefill_sampled_requests,
+                                                                          pd_info.prompt_req_ids) and i < num_decodes:
+                            if req_id not in self.lookahead_tokens_tensors:
+                                self.lookahead_tokens_tensors[req_id] = []
+                            self.lookahead_tokens_tensors[req_id].append(token_ids)
+
             if self.is_driver_worker and self.profiler.enabled:
                 # Stop recording 'execute_model' event
                 self.profiler.end()
@@ -2180,31 +2358,49 @@ def execute_model(
             # We already have tokens. Let's copy the data to
             # CPU as is, and then discard padded tokens.
             with self.profiler.record_event('internal', "sampler_postprocessing"):
-                prefill_sampled_token_ids = [
-                    tensor.cpu() for tensor in prefill_sampled_token_ids
-                ]
-                decode_sampled_token_ids = [
-                    tensor.cpu()[:num_decodes]
-                    for tensor in decode_sampled_token_ids
-                ]
-                sampled_token_ids_list = torch.cat(
-                    decode_sampled_token_ids + prefill_sampled_token_ids).tolist()
-                sampled_token_requests = \
-                    decode_sampled_requests + prefill_sampled_requests
                 max_req_index = max(self.input_batch.req_id_to_index.values())
                 postprocessed_sampled_token_ids: list[list]
                 postprocessed_sampled_token_ids = [[]
-                                                   for _ in range(max_req_index +
-                                                                  1)]
-                for tok_id, req_id in zip(sampled_token_ids_list,
-                                          sampled_token_requests):
-                    postprocessed_sampled_token_ids[
-                        self.input_batch.req_id_to_index[req_id]].append(tok_id)
-
+                                                    for _ in range(max_req_index +
+                                                                   1)]
+                if self.use_lookahead_decoding:
+                    if num_prefills > 0:
+                        prefill_sampled_token_ids = torch.cat([
+                            tensor.cpu() for tensor in prefill_sampled_token_ids
+                        ]).tolist()
+
+                    for req_id in decode_sampled_requests:
+                        req_index = self.input_batch.req_id_to_index[req_id]
+                        tok_id = self.lookahead_tokens[req_id].pop(0)
+                        _ = self.lookahead_tokens_tensors[req_id].pop(0)
+                        postprocessed_sampled_token_ids[req_index].append(tok_id)
+
+                    for tok_id, req_id in zip(prefill_sampled_token_ids,
+                                              prefill_sampled_requests):
+                        req_index = self.input_batch.req_id_to_index[req_id]
+                        postprocessed_sampled_token_ids[req_index].append(tok_id)
+                else:
+                    prefill_sampled_token_ids = [
+                        tensor.cpu() for tensor in prefill_sampled_token_ids
+                    ]
+                    decode_sampled_token_ids = [
+                        tensor.cpu()[:num_decodes]
+                        for tensor in decode_sampled_token_ids
+                    ]
+                    sampled_token_ids_list = torch.cat(
+                        decode_sampled_token_ids + prefill_sampled_token_ids).tolist()
+                    sampled_token_requests = \
+                        decode_sampled_requests + prefill_sampled_requests
+
+                    for tok_id, req_id in zip(sampled_token_ids_list,
+                                              sampled_token_requests):
+                        postprocessed_sampled_token_ids[
+                            self.input_batch.req_id_to_index[req_id]].append(tok_id)
+
             # NOTE(kzawora): idk what happens if part of batch doesn't have logprobs
 
             ######### UPDATE REQUEST STATE WITH GENERATED TOKENS #########
-            for req_id in self.input_batch.req_ids[:num_reqs]:
+            for n, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
                 req_state = self.requests[req_id]
                 i = self.input_batch.req_id_to_index[req_id]
                 seq_len = (req_state.num_computed_tokens +
@@ -2214,7 +2410,13 @@ def execute_model(
                 self.input_batch.token_ids_cpu[i, seq_len:seq_len +
                                                num_tokens] = token_ids
                 self.input_batch.num_tokens[i] += len(token_ids)
-                req_state.output_token_ids.extend(token_ids)
+                # With lookahead decoding, output token ids for decodes were already updated in
+                # update_lookahead_decode_inputs()
+                if self.use_lookahead_decoding:
+                    if n > num_decodes - 1:
+                        req_state.output_token_ids.extend(token_ids)
+                else:
+                    req_state.output_token_ids.extend(token_ids)
 
             # NOTE(chendi): enable cache based on PR(#20291)
             # Cache the sampled tokens in the model runner, so that the scheduler
@@ -2244,15 +2446,18 @@ def execute_model(
                     self.input_batch.num_tokens[req_idx] = end_idx
                     req_id = self.input_batch.req_ids[req_idx]
                     req_state = self.requests[req_id]
-                    req_state.output_token_ids.extend(sampled_ids)
 
         ################## RETURN ##################
         # Create output.
-        all_req_ids = pd_info.decode_req_ids + pd_info.prompt_req_ids
-        # prompt_logprobs_dict: dict[
+        # Only return the originally scheduled requests, not the lookahead requests
+        original_decode_req_ids = pd_info.decode_req_ids
+        if self.use_lookahead_decoding:
+            original_decode_req_ids = pd_info.decode_req_ids[:-num_prefills] if num_prefills > 0 else pd_info.decode_req_ids
+
+        all_req_ids = original_decode_req_ids + pd_info.prompt_req_ids
+        #prompt_logprobs_dict: dict[
         #     str, Optional[LogprobsTensors]] = self._get_prompt_logprobs_dict(
         #         prefill_hidden_states_device, scheduler_output)
         prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}
-        all_req_ids = pd_info.decode_req_ids + pd_info.prompt_req_ids
         logprobs = None
 
         model_runner_output = ModelRunnerOutput(
@@ -2616,6 +2821,7 @@ def warmup_model(self) -> None:
             logger.info("Skipping warmup...")
             return
 
+        self.use_lookahead_decoding = False
         self.profiler.start('internal', 'warmup')
         start_mem = HabanaMemoryProfiler.current_device_memory_usage()
         start_time = time.perf_counter()
@@ -2658,6 +2864,7 @@ def warmup_model(self) -> None:
 
         end_time = time.perf_counter()
         end_mem = HabanaMemoryProfiler.current_device_memory_usage()
+        self.use_lookahead_decoding = get_config().lookahead_decoding
         if os.getenv('VLLM_FULL_WARMUP',
                      'false').strip().lower() in ("1", "true"):
             # Since the model is warmed up for all possible tensor sizes,
@@ -2771,6 +2978,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         self._PAD_SLOT_ID = num_blocks * self.block_size
         htorch.hpu.synchronize()
 
+        if self.use_lookahead_decoding:
+            self.borrowed_blocks_mapping_fwd = {block_id: [] for block_id in range(num_blocks - (self.max_num_seqs + 2), num_blocks - 2)}
+
     def get_supported_generation_tasks(self) -> list[GenerationTask]:
         model = self.get_model()
diff --git a/vllm_gaudi/v1/worker/hpu_worker.py b/vllm_gaudi/v1/worker/hpu_worker.py
index 3a1de664..9b236915 100644
--- a/vllm_gaudi/v1/worker/hpu_worker.py
+++ b/vllm_gaudi/v1/worker/hpu_worker.py
@@ -221,20 +221,23 @@ def determine_available_memory(self) -> int:
         self.model_runner.mem_margin = hpu_memory_margin
         cache_size_bytes = available_hpu_memory * graph_headroom
         graph_headroom_bytes = available_hpu_memory * (1 - graph_headroom)
-        dummy_block_headroom = single_kv_block_size_bytes
+        num_dummy_blocks = 1
+        if self.model_runner.use_lookahead_decoding:
+            num_dummy_blocks += self.scheduler_config.max_num_seqs
+        dummy_blocks_headroom = single_kv_block_size_bytes * num_dummy_blocks
         msg = (
             f"Free device memory: {format_bytes(free_hpu_memory)}, "
             f"{format_bytes(available_hpu_memory)} usable "
            f"(gpu_memory_utilization={self.cache_config.gpu_memory_utilization}),"
             f" {format_bytes(graph_headroom_bytes)} reserved for HPUGraphs "
             f"(VLLM_GRAPH_RESERVED_MEM={graph_reserved_mem}), "
-            f"{format_bytes(dummy_block_headroom)} reserved for KV cache dummy "
-            f"block {format_bytes(cache_size_bytes - dummy_block_headroom)} "
+            f"{format_bytes(dummy_blocks_headroom)} reserved for KV cache dummy blocks "
+            f"{format_bytes(cache_size_bytes - dummy_blocks_headroom)} "
             "reserved for usable KV cache")
         logger.info(msg)
         gc.collect()
-        return cache_size_bytes - dummy_block_headroom
+        return cache_size_bytes - dummy_blocks_headroom
 
     def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
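The deferred-token bookkeeping this patch introduces is easier to see in isolation. The sketch below is illustrative only and is not vLLM or vllm-gaudi code: `LookaheadBuffer`, `step`, and `sample` are hypothetical names standing in for `self.lookahead_tokens` / `self.lookahead_tokens_tensors` and the sampler. It models the flow the diff implements: a prompt scheduled in step N also gets an extra "lookahead" decode appended to the decode batch, and the token sampled for a running decode in step N is buffered and only reported in step N+1 (finished requests simply drop any buffered dummy token).

```python
from collections import defaultdict, deque
from typing import Optional

class LookaheadBuffer:
    """Toy stand-in for the lookahead_tokens bookkeeping in the patch."""

    def __init__(self):
        # req_id -> tokens sampled ahead of schedule, not yet reported
        self.pending = defaultdict(deque)

    def stash(self, req_id: str, token: int) -> None:
        self.pending[req_id].append(token)

    def pop_ready(self, req_id: str) -> Optional[int]:
        q = self.pending.get(req_id)
        return q.popleft() if q else None

    def drop(self, req_id: str) -> None:
        # Finished requests discard any buffered (dummy) lookahead tokens.
        self.pending.pop(req_id, None)


def step(buffer, running_decodes, new_prompts, sample):
    """One simplified model-runner step.

    running_decodes: req_ids that already produced at least one token.
    new_prompts: req_ids whose prefill runs this step; each also gets an
    extra lookahead decode in the same step.
    sample: callable(req_id) -> next token id (stands in for the sampler).
    """
    outputs = {}

    # Regular decodes: report the token that was sampled one step earlier...
    for req_id in running_decodes:
        ready = buffer.pop_ready(req_id)
        if ready is not None:
            outputs[req_id] = ready
        # ...and immediately sample the next one for the following step.
        buffer.stash(req_id, sample(req_id))

    # Prompts: the prefill token is reported now, and the appended lookahead
    # decode pre-computes the second token for the next step.
    for req_id in new_prompts:
        outputs[req_id] = sample(req_id)       # first token (prefill)
        buffer.stash(req_id, sample(req_id))   # lookahead token

    return outputs


if __name__ == "__main__":
    buf = LookaheadBuffer()
    counter = iter(range(100))
    sample = lambda req_id: next(counter)

    print(step(buf, running_decodes=[], new_prompts=["a"]))  # {'a': 0}, token 1 buffered
    print(step(buf, running_decodes=["a"], new_prompts=[]))  # {'a': 1}, token 2 buffered
    buf.drop("a")                                            # finished: discard buffered token
```

The design intent, as far as the diff shows, is that the decode for the next step can be prepared from the buffered tokens without waiting on the previous step's host-side postprocessing; the borrowed KV-cache blocks and the extra dummy-block headroom in `determine_available_memory()` exist so those speculative lookahead decodes have somewhere to write before the scheduler has allocated their real blocks.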