
Conversation


@xyz2606 xyz2606 commented Oct 2, 2025

What this PR does / why we need it?

A faster chunked gated delta rule implementation built from basic torch operators:
(1) a faster inverse of a lower-triangular (tril) matrix (see the sketch below)
(2) faster computation of the attention score and recurrent state for each chunk
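To make point (1) concrete, here is a minimal sketch of inverting a unit lower-triangular matrix by blockwise doubling. It is an illustration of the technique under stated assumptions, not the PR's actual code: the function name tril_inverse_pow2 is invented, and the inner loop over diagonal blocks is written out explicitly where a fully vectorized version would batch it with tensor views (which is why the size must be a power of two).

import torch

def tril_inverse_pow2(T: torch.Tensor) -> torch.Tensor:
    # Sketch: T has shape (..., n, n), is unit lower-triangular, and n is a power of two.
    n = T.shape[-1]
    assert n > 0 and (n & (n - 1)) == 0, "n must be a power of 2"
    Tinv = torch.zeros_like(T)
    idx = torch.arange(n, device=T.device)
    Tinv[..., idx, idx] = 1.0  # 1x1 diagonal blocks of a unit-tril matrix are their own inverse
    m = 1
    while m < n:  # log2(n) doubling levels instead of n sequential forward-substitution steps
        for i in range(0, n, 2 * m):  # illustrative loop; a vectorized version batches this via a view
            A_inv = Tinv[..., i:i + m, i:i + m]
            B_inv = Tinv[..., i + m:i + 2 * m, i + m:i + 2 * m]
            C = T[..., i + m:i + 2 * m, i:i + m]
            # [[A, 0], [C, B]]^-1 = [[A^-1, 0], [-B^-1 @ C @ A^-1, B^-1]]
            Tinv[..., i + m:i + 2 * m, i:i + m] = -B_inv @ C @ A_inv
        m *= 2
    return Tinv

# Quick self-check against torch.linalg.inv:
torch.manual_seed(0)
L = torch.tril(torch.randn(8, 8, dtype=torch.float64), diagonal=-1) + torch.eye(8, dtype=torch.float64)
assert torch.allclose(tril_inverse_pow2(L), torch.linalg.inv(L))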

  • Fixes #
    num_qk_heads and sequence_length were reversed
    wrong initialization of last_recurrent_state when initial_state is None (the number of q heads was used instead of the number of v heads); a sketch of the corrected initialization follows
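For the second fix above, a hedged sketch of the corrected initialization. The tensor names and shapes here are assumptions for illustration, not the PR's exact code:

# Assumed shapes: key is (batch, num_qk_heads, seq_len, k_head_dim),
# value is (batch, num_v_heads, seq_len, v_head_dim).
if initial_state is None:
    # The recurrent state holds one (k_head_dim x v_head_dim) matrix per *value* head,
    # so its head dimension must come from value, not from query/key.
    last_recurrent_state = value.new_zeros(
        value.shape[0], value.shape[1], key.shape[-1], value.shape[-1])
else:
    last_recurrent_state = initial_state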

Does this PR introduce any user-facing change?

No

How was this patch tested?

Tested with random input of batch_size = 1 and seq_len = 8192; a sketch of such a test is given below.
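A hedged sketch of that kind of test. The call signature of torch_chunk_gated_delta_rule and the reference helper name are assumptions for illustration, not the project's actual API:

import torch

torch.manual_seed(0)
B, H, S, Dk, Dv = 1, 4, 8192, 64, 64    # batch_size = 1, seq_len = 8192 (head counts/dims assumed)
q = torch.randn(B, H, S, Dk)
k = torch.randn(B, H, S, Dk)
v = torch.randn(B, H, S, Dv)
g = -torch.rand(B, H, S) * 0.1           # per-step log decay, small and negative (assumed)
beta = torch.rand(B, H, S)               # per-step write strength (assumed)

out_new, state_new = torch_chunk_gated_delta_rule(q, k, v, g, beta)      # assumed signature
out_ref, state_ref = chunk_gated_delta_rule_reference(q, k, v, g, beta)  # hypothetical pre-refactor reference
assert torch.allclose(out_new, out_ref, atol=1e-4)
assert torch.allclose(state_new, state_ref, atol=1e-4)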


github-actions bot commented Oct 2, 2025

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling out the PR description, to help reviewers and future developers understand.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request significantly refactors the torch_chunk_gated_delta_rule function for performance and correctness. It introduces a faster, parallel algorithm for triangular matrix inversion and vectorizes computations within the chunk processing loop. The changes also include several important bug fixes related to tensor shapes, padding, and state initialization. Overall, this is a great improvement. I have one critical piece of feedback regarding an assertion needed to ensure the new matrix inversion algorithm's correctness.

attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)  # update the strictly-lower entries of row i
attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)  # add the unit diagonal

lg = int(log(chunk_size, 2))

critical

The new matrix inversion algorithm, which uses log(chunk_size, 2), implicitly assumes that chunk_size is a power of two. If a non-power-of-two chunk_size is provided, the view operation inside the loop will likely fail due to a shape mismatch, causing a runtime error. To ensure correctness and prevent such errors, it's critical to add an assertion that validates chunk_size is a power of two.

Suggested change
-lg = int(log(chunk_size, 2))
+assert (chunk_size & (chunk_size - 1) == 0) and chunk_size > 0, "chunk_size must be a power of 2"
+lg = int(log(chunk_size, 2))
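For reference, the bit trick in the suggested assert works because a power of two has exactly one set bit, so chunk_size & (chunk_size - 1) clears it to zero. A quick illustrative check:

for n in (1, 2, 64, 96, 8192):
    print(n, n > 0 and (n & (n - 1)) == 0)   # True only for powers of two; 96 -> False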
