33 changes: 5 additions & 28 deletions vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -2163,35 +2163,12 @@ def _form_prefill_batch(self, contents):
         # of prefix caching for mamba2 (wip). We need to take care
         # of the interaction with chunked prefill in order to
         # satisfy constraint (2).
-        # TODO (tdoublep): This code could probably be optimized.
-        last_chunk_indices = []
-        last_chunk_index = -1
-        seqlen_pos = 0
 
         chunk_size = self.model_config.get_mamba_chunk_size()
-        for req_idx in range(len(contents.req_ids)):
-            this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
-            this_new_tokens = (query_start_loc_p_cpu[req_idx + 1].item() - query_start_loc_p_cpu[req_idx].item())
-
-            # if computed tokens are not chunk-aligned, use the first
-            # chunk to finish it off
-            if this_num_computed % chunk_size != 0:
-                # how many tokens to finish the chunk?
-                last_chunk_index += 1
-                chunk_len = (cdiv(this_num_computed, chunk_size) * chunk_size - this_num_computed)
-                # we can only use at most this_new_tokens
-                chunk_len = min(chunk_len, this_new_tokens)
-                seqlen_pos += chunk_len
-                this_new_tokens -= chunk_len
-
-            n_chunks = cdiv(this_new_tokens, chunk_size)
-            for chunk in range(n_chunks):
-                last_chunk_index += 1
-                chunk_len = min(chunk_size, this_new_tokens)
-                seqlen_pos += chunk_len
-                this_new_tokens -= chunk_len
-
-            assert this_new_tokens == 0
-            last_chunk_indices.append(last_chunk_index)
+        assert chunk_size > 0
+        nphysical_chunks = target_seq // chunk_size
+        assert nphysical_chunks > 0, (f"target_seq={target_seq} must be >= chunk_size={chunk_size}")
+        last_chunk_indices = [nphysical_chunks - 1 for _ in range(len(contents.req_ids))]
+
         num_prefill_reqs = len(contents.req_ids)
         all_state_indices_cpu = []
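For reviewers, here is a minimal, self-contained sketch of the arithmetic this hunk changes, runnable outside the runner. It condenses the removed per-request loop (dropping the `seqlen_pos` bookkeeping, which does not affect the resulting indices) and mirrors the added fixed-bucket computation. `cdiv` is reproduced locally to match vLLM's helper; the `chunk_size`, `target_seq`, and per-request token counts are hypothetical sample values, not taken from this PR.

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division, matching vLLM's cdiv helper.
    return -(a // -b)

chunk_size = 256   # assumed mamba2 chunk size
target_seq = 1024  # assumed bucketed (padded) prefill length

def last_chunk_indices_old(num_computed: list[int],
                           num_new: list[int]) -> list[int]:
    """Removed logic, condensed: walk each request, counting chunk
    indices cumulatively across the concatenated batch. A computed
    prefix that is not chunk-aligned first consumes one partial chunk
    to re-align before full chunks are counted."""
    indices, last = [], -1
    for computed, new in zip(num_computed, num_new):
        if computed % chunk_size != 0:
            last += 1  # partial chunk that finishes the alignment
            fill = cdiv(computed, chunk_size) * chunk_size - computed
            new -= min(fill, new)  # at most the new tokens available
        last += cdiv(new, chunk_size)  # remaining (full or partial) chunks
        indices.append(last)
    return indices

def last_chunk_indices_new(num_reqs: int) -> list[int]:
    """Added logic: with every prefill padded to target_seq, each
    request spans the same number of physical chunks, so the last
    chunk index is identical for all requests."""
    nphysical_chunks = target_seq // chunk_size
    assert nphysical_chunks > 0
    return [nphysical_chunks - 1] * num_reqs

print(last_chunk_indices_old([0, 300], [1024, 724]))  # [3, 6]
print(last_chunk_indices_new(2))                      # [3, 3]
```

The simplification appears to rely on the HPU runner padding every prefill to the same bucketed `target_seq`, so each request covers an identical number of physical chunks and the per-request walk becomes unnecessary.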