last_chunk_indices calculations fix #1023
jbyczkow wants to merge 2 commits into vllm-project:releases/v0.16.0 from
Conversation
Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai>
Pull request overview
This PR aims to fix how last_chunk_indices are computed for Mamba/Mamba2 chunked prefill in the HPU model runner, which is used to select the correct end-of-sequence SSM state to persist into the cache.
Changes:
- Replaces per-request chunk boundary construction with a simplified `target_seq`/`chunk_size`-based computation.
- Adds assertions around `chunk_size` and the derived number of physical chunks.
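The simplified computation can be sketched standalone. This assumes, as the new code appears to, that every request in the padded prefill batch spans the same `target_seq` tokens; the function name and literal values below are illustrative, not from the PR:

```python
# Sketch of the simplified last_chunk_indices computation this PR introduces,
# under the assumption that all requests are padded to the same target_seq.
def last_chunk_indices_simplified(target_seq: int, chunk_size: int,
                                  n_reqs: int) -> list[int]:
    assert chunk_size > 0
    nphysical_chunks = target_seq // chunk_size
    assert nphysical_chunks > 0, (
        f"target_seq={target_seq} must be >= chunk_size={chunk_size}")
    # Every padded request ends in the same (last) chunk position.
    return [nphysical_chunks - 1 for _ in range(n_reqs)]

print(last_chunk_indices_simplified(target_seq=256, chunk_size=64, n_reqs=3))
# -> [3, 3, 3]
```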
```python
assert this_new_tokens == 0
last_chunk_indices.append(last_chunk_index)
assert chunk_size > 0
```
The calculation of `nphysical_chunks` uses floor division (`target_seq // chunk_size`) without asserting that `target_seq` is chunk-aligned. The underlying mamba chunking implementation reshapes by `chunk_size` (and will truncate or mis-shape if `target_seq` isn't a multiple), so this should either enforce `target_seq % chunk_size == 0`, or round up and ensure inputs/padding follow the same rule, to avoid silent mis-indexing or crashes in the mamba kernels.
Suggested change:
```diff
 assert chunk_size > 0
+assert target_seq % chunk_size == 0, (
+    f"target_seq={target_seq} must be a multiple of chunk_size={chunk_size}"
+)
```
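The failure mode the reviewer describes can be shown numerically. The values below are illustrative, not taken from the PR:

```python
# With floor division alone, a target_seq that is not a multiple of
# chunk_size silently drops the trailing partial chunk, so the "last chunk"
# no longer covers the final tokens of the sequence.
chunk_size = 64
target_seq = 200          # NOT a multiple of 64

nphysical_chunks = target_seq // chunk_size       # 3 -> covers only 192 tokens
uncovered = target_seq - nphysical_chunks * chunk_size
print(nphysical_chunks, uncovered)                # 3 8 -> 8 tokens outside any chunk

# The suggested guard turns this silent truncation into a loud failure:
try:
    assert target_seq % chunk_size == 0, (
        f"target_seq={target_seq} must be a multiple of chunk_size={chunk_size}")
except AssertionError as e:
    print(e)
```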
```diff
 chunk_size = self.model_config.get_mamba_chunk_size()
-for req_idx in range(len(contents.req_ids)):
-    this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
-    this_new_tokens = (query_start_loc_p_cpu[req_idx + 1].item() -
-                       query_start_loc_p_cpu[req_idx].item())
-
-    # if computed tokens are not chunk-aligned, use the first
-    # chunk to finish it off
-    if this_num_computed % chunk_size != 0:
-        # how many tokens to finish the chunk?
-        last_chunk_index += 1
-        chunk_len = (cdiv(this_num_computed, chunk_size) * chunk_size -
-                     this_num_computed)
-        # we can only use at most this_new_tokens
-        chunk_len = min(chunk_len, this_new_tokens)
-        seqlen_pos += chunk_len
-        this_new_tokens -= chunk_len
-
-    n_chunks = cdiv(this_new_tokens, chunk_size)
-    for chunk in range(n_chunks):
-        last_chunk_index += 1
-        chunk_len = min(chunk_size, this_new_tokens)
-        seqlen_pos += chunk_len
-        this_new_tokens -= chunk_len
-
-    assert this_new_tokens == 0
-    last_chunk_indices.append(last_chunk_index)
+assert chunk_size > 0
+nphysical_chunks = target_seq // chunk_size
+assert nphysical_chunks > 0, (
+    f"target_seq={target_seq} must be >= chunk_size={chunk_size}")
+last_chunk_indices = [nphysical_chunks - 1 for _ in range(len(contents.req_ids))]
```
This change alters how mamba chunk boundaries are derived, but there are no unit tests asserting `last_chunk_indices_p` behavior for multi-request prefill (especially with `num_computed_tokens > 0` / chunk misalignment). Adding a focused test that exercises `_form_prefill_batch` for a small synthetic batch and validates the produced indices would help prevent regressions here.
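Such a test could be sketched against a stand-alone reimplementation of the per-request loop from the left side of the diff; the real test would drive `_form_prefill_batch` on the HPU model runner, so the helper below and its inputs are illustrative assumptions:

```python
from math import ceil

def cdiv(a: int, b: int) -> int:
    """Ceiling division, as used in the runner code."""
    return ceil(a / b)

def per_request_last_chunk_indices(num_computed: list[int],
                                   new_tokens: list[int],
                                   chunk_size: int) -> list[int]:
    """Reference reimplementation of the removed per-request loop."""
    last_chunk_index = -1
    out = []
    for computed, fresh in zip(num_computed, new_tokens):
        if computed % chunk_size != 0:
            # Use one chunk to finish off the misaligned prefix first.
            last_chunk_index += 1
            fill = cdiv(computed, chunk_size) * chunk_size - computed
            fresh -= min(fill, fresh)
        # Remaining tokens consume one chunk index per (partial) chunk.
        last_chunk_index += cdiv(fresh, chunk_size)
        out.append(last_chunk_index)
    return out

# Two requests, chunk_size=4: req0 has 6 fresh tokens (2 chunks); req1 starts
# misaligned at 3 computed tokens and adds 5 more (1 partial + 1 full chunk).
print(per_request_last_chunk_indices([0, 3], [6, 5], 4))
# -> [1, 3]
```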
```python
assert chunk_size > 0
nphysical_chunks = target_seq // chunk_size
assert nphysical_chunks > 0, (
    f"target_seq={target_seq} must be >= chunk_size={chunk_size}")
last_chunk_indices = [nphysical_chunks - 1 for _ in range(len(contents.req_ids))]
```
`last_chunk_indices` is now set to the same value for every request (`nphysical_chunks - 1`). In the mamba path, `hpu_mamba_chunk_scan_combined_varlen(...)[last_chunk_indices_p]` indexes the returned per-chunk state tensor; indices therefore must point at each sequence's last chunk in the flattened chunk stream (and must differ per request when there are multiple sequences and/or non-zero `num_computed_tokens`). With the current code, all requests will read the state from the same chunk (typically the end of the first sequence), leading to incorrect `ssm_state` updates.
Suggested change:
```diff
-last_chunk_indices = [nphysical_chunks - 1 for _ in range(len(contents.req_ids))]
+# For the mamba kernel, chunks are flattened across sequences.
+# Assuming all sequences contribute `nphysical_chunks` chunks contiguously,
+# the last chunk for request i is at index:
+#     i * nphysical_chunks + (nphysical_chunks - 1).
+last_chunk_indices = [
+    i * nphysical_chunks + (nphysical_chunks - 1)
+    for i in range(len(contents.req_ids))
+]
```
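The difference between the two indexings can be demonstrated with small, illustrative numbers (and under the same assumption the suggestion states: every sequence contributes exactly `nphysical_chunks` contiguous chunks to the flat stream):

```python
# The mamba kernel flattens chunks across sequences, so last-chunk indices
# must be offset by each request's position in that flat stream.
n_reqs = 3
nphysical_chunks = 4   # illustrative values only

# Current (buggy) version: every request points at the same chunk slot.
uniform = [nphysical_chunks - 1 for _ in range(n_reqs)]

# Suggested fix: offset each request into the flattened chunk stream.
per_request = [i * nphysical_chunks + (nphysical_chunks - 1)
               for i in range(n_reqs)]

print(uniform)      # [3, 3, 3]  -> all read sequence 0's last chunk
print(per_request)  # [3, 7, 11] -> each reads its own sequence's last chunk
```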