Conversation
🚧 CI Blocked: The main CI workflow was not started.
Pull request overview
This pull request simplifies the computation of `last_chunk_indices` for Mamba layer processing in the prefill batch formation code. It removes a complex per-request chunk tracking algorithm and replaces it with a uniform calculation based on the padded sequence length.
Changes:
- Removed the complex per-request chunk index tracking logic that accounted for non-aligned computed tokens and variable query lengths
- Replaced it with a simplified calculation, `target_seq // chunk_size`, applied uniformly to all requests (sketched below)
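For reference, a minimal runnable sketch of the new uniform calculation, assembled from the snippets quoted in this review. The `Contents` stub and the literal values are hypothetical stand-ins; in the real code, `target_seq`, `chunk_size`, and `contents` come from the surrounding prefill batch formation logic:

```python
# Hypothetical stand-in for the real batch metadata object.
class Contents:
    req_ids = ["req0", "req1"]

contents = Contents()
target_seq = 128   # padded sequence length (assumed example value)
chunk_size = 64    # Mamba chunk size (assumed example value)

assert chunk_size > 0
nphysical_chunks = target_seq // chunk_size
assert nphysical_chunks > 0, (
    f"target_seq={target_seq} must be >= chunk_size={chunk_size}")
# Every request is assigned the index of the last padded chunk.
last_chunk_indices = [nphysical_chunks - 1
                      for _ in range(len(contents.req_ids))]
print(last_chunk_indices)  # [1, 1]
```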
```python
last_chunk_indices = [nphysical_chunks - 1
                      for _ in range(len(contents.req_ids))]
```
The simplified logic for computing `last_chunk_indices` is incorrect. It assumes every request has the same last chunk index, derived from `target_seq`, but different requests can have different actual query lengths (before padding).
For example, with `target_seq=128` and `chunk_size=64`, the new code sets `last_chunk_indices = [1, 1, ...]` for all requests. However, if request 0 has `query_len=100` (2 chunks, last chunk index 1) and request 1 has `query_len=50` (1 chunk, last chunk index 0), then request 1 would incorrectly try to retrieve state from chunk 1, which contains only padding tokens.
The old logic correctly computed `last_chunk_indices` per request by iterating through each request's actual query length and `num_computed_tokens`. This per-request calculation is necessary because `last_chunk_indices_p` is used at line 440 in `hpu_mamba_mixer2.py` to index into `varlen_states` and retrieve the final state for each request.
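To make the mismatch concrete, here is a small self-contained reproduction of the example above (plain Python, with the batch metadata stubbed out as lists rather than the real `contents` object, and `num_computed_tokens` assumed to be 0):

```python
# Example from the comment above: target_seq=128, chunk_size=64.
target_seq, chunk_size = 128, 64
query_lens = [100, 50]  # actual (unpadded) query lengths per request

# New (uniform) logic: every request is assigned the last padded chunk.
nphysical_chunks = target_seq // chunk_size
uniform = [nphysical_chunks - 1 for _ in query_lens]

# Per-request logic: index of the chunk holding each request's
# last real token.
per_request = [(qlen - 1) // chunk_size for qlen in query_lens]

print(uniform)      # [1, 1]
print(per_request)  # [1, 0] -> request 1's state lives in chunk 0,
                    # but the uniform logic reads padded chunk 1.
```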
Suggested change:

```diff
-last_chunk_indices = [nphysical_chunks - 1
-                      for _ in range(len(contents.req_ids))]
+# Compute last_chunk_indices per request, based on each request's
+# actual (unpadded) query length and any already-computed tokens.
+# We fall back to target_seq / 0 computed tokens if per-request
+# metadata is not available, to preserve previous behavior.
+num_reqs = len(contents.req_ids)
+# Try to obtain per-request query lengths; otherwise assume target_seq.
+if hasattr(contents, "seq_lens"):
+    query_lens = contents.seq_lens
+elif hasattr(contents, "query_lens"):
+    query_lens = contents.query_lens
+else:
+    query_lens = [target_seq] * num_reqs
+# Try to obtain per-request num_computed_tokens; otherwise assume 0.
+num_computed_tokens = getattr(contents, "num_computed_tokens", None)
+if num_computed_tokens is None:
+    per_req_computed = [0] * num_reqs
+else:
+    per_req_computed = num_computed_tokens
+last_chunk_indices = []
+for i in range(num_reqs):
+    qlen = int(query_lens[i])
+    computed = int(per_req_computed[i]) if i < len(per_req_computed) else 0
+    effective_len = max(qlen - computed, 0)
+    if effective_len <= 0:
+        # No tokens to process; default to first chunk index 0.
+        last_chunk_idx = 0
+    else:
+        last_token_idx = effective_len - 1
+        last_chunk_idx = last_token_idx // chunk_size
+    last_chunk_indices.append(last_chunk_idx)
```
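Applied to the example above (query lengths 100 and 50, no computed tokens), this per-request loop yields `last_chunk_indices == [1, 0]`, so each request reads its final state from the chunk that actually holds its last real token.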
```diff
-        last_chunk_indices.append(last_chunk_index)
+assert chunk_size > 0
+nphysical_chunks = target_seq // chunk_size
+assert nphysical_chunks > 0, (
```
formatting issues:

```python
assert nphysical_chunks > 0, (
    f"target_seq={target_seq} must be >= chunk_size={chunk_size}")
last_chunk_indices = [nphysical_chunks - 1
                      for _ in range(len(contents.req_ids))]
```
Signed-off-by: Jakub Byczkowski <jbyczkowski@habana.ai>
@adobrzyn please cherry-pick to 0.16/0.15.1