last_chunk_indices calculations fix #1024
jbyczkow wants to merge 2 commits into vllm-project:releases/v0.15.1 from
Conversation
Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai>
Pull request overview
This PR updates how last_chunk_indices are computed for Mamba chunked-prefill metadata in the HPU model runner.
Changes:
- Replaced the per-request `last_chunk_indices` computation loop with a simplified calculation derived from `target_seq` and `chunk_size`.
- Added assertions around `chunk_size` and the derived chunk count.
```python
assert chunk_size > 0
nphysical_chunks = target_seq // chunk_size
assert nphysical_chunks > 0, (f"target_seq={target_seq} must be >= chunk_size={chunk_size}")
last_chunk_indices = [nphysical_chunks - 1 for _ in range(len(contents.req_ids))]
```
`last_chunk_indices` is now set to the same value for every request (`nphysical_chunks - 1`). Downstream (in `hpu_mamba_mixer2.py`), this tensor is used to index the per-chunk SSM states, which are then written back into `ssm_state` per request, so making all requests point at the same chunk will cause incorrect state updates (and can effectively make multiple sequences share the same final state). Compute per-request chunk indices based on how tokens are laid out (e.g., in the non-merged 2D layout, the i-th sequence's last chunk should be offset by `i * nphysical_chunks`).
Suggested change:
```python
last_chunk_indices = [
    i * nphysical_chunks + (nphysical_chunks - 1)
    for i in range(len(contents.req_ids))
]
```
```python
nphysical_chunks = target_seq // chunk_size
assert nphysical_chunks > 0, (f"target_seq={target_seq} must be >= chunk_size={chunk_size}")
```
nphysical_chunks = target_seq // chunk_size uses floor division, but the mamba chunk-scan implementation reshapes based on seqlen // chunk_size and effectively requires the sequence length to be an exact multiple of chunk_size. It would be safer to assert target_seq % chunk_size == 0 (with a clear error) or explicitly round/pad target_seq up to the next multiple before computing chunk indices, otherwise disabling/altering bucketing can lead to hard-to-debug shape errors or incorrect indexing.
Suggested change:
```python
# Mamba chunk-scan requires the sequence length to be an exact multiple of chunk_size.
assert target_seq % chunk_size == 0, (
    f"Mamba chunking requires target_seq ({target_seq}) to be an exact multiple "
    f"of chunk_size ({chunk_size})."
)
nphysical_chunks = target_seq // chunk_size
assert nphysical_chunks > 0, (
    f"target_seq={target_seq} must be >= chunk_size={chunk_size}"
)
```
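For the "round/pad `target_seq` up to the next multiple" alternative mentioned in the comment, a minimal sketch is below. The helper name `pad_to_chunk_multiple` is illustrative only, not an identifier from the vLLM source.

```python
def pad_to_chunk_multiple(target_seq: int, chunk_size: int) -> int:
    # Round target_seq up to the next multiple of chunk_size so the
    # chunk-scan reshape (seqlen // chunk_size) is always exact.
    assert chunk_size > 0
    return ((target_seq + chunk_size - 1) // chunk_size) * chunk_size

print(pad_to_chunk_multiple(100, 64))  # → 128 (padded up)
print(pad_to_chunk_multiple(128, 64))  # → 128 (already aligned)
```

Padding keeps the indexing code simple at the cost of a few wasted token slots; the assert-based variant instead surfaces misaligned bucketing immediately rather than producing shape errors later.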
No description provided.