vllm_ascend/ops/fla.py — 79 changed lines (53 additions & 26 deletions)
@@ -11,7 +11,7 @@
import triton
from vllm.model_executor.layers.fla.ops.layernorm_guard import \
layer_norm_fwd_kernel

from math import log

def _layer_norm_fwd(
x,
@@ -127,7 +127,7 @@ def torch_chunk_gated_delta_rule(
value,
g,
beta,
chunk_size=64,
chunk_size=128,
initial_state=None,
output_final_state=False,
use_qk_l2norm_in_kernel=False,
@@ -141,15 +141,16 @@ def torch_chunk_gated_delta_rule(
for x in (query, key, value, beta, g)
]

batch_size, sequence_length, num_heads, k_head_dim = key.shape
batch_size, num_qk_heads, sequence_length, k_head_dim = key.shape
num_v_heads = value.shape[1]
v_head_dim = value.shape[-1]
pad_size = (chunk_size - num_heads % chunk_size) % chunk_size
query = F.pad(query, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1)
key = F.pad(key, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1)
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
query = F.pad(query, (0, 0, 0, pad_size)).repeat_interleave(num_v_heads // num_qk_heads, dim=1)
key = F.pad(key, (0, 0, 0, pad_size)).repeat_interleave(num_v_heads // num_qk_heads, dim=1)
value = F.pad(value, (0, 0, 0, pad_size))
beta = F.pad(beta, (0, pad_size))
g = F.pad(g, (0, pad_size))
tot_heads = num_heads + pad_size
sequence_length_padded = sequence_length + pad_size
scale = 1 / (query.shape[-1]**0.5)
query = query * scale

@@ -173,15 +174,28 @@ def torch_chunk_gated_delta_rule(
g.unsqueeze(-2)).tril().exp().float()).tril()
attn = -(
(k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
for i in range(1, chunk_size):
row = attn[..., i, :i].clone()
sub = attn[..., :i, :i].clone()
attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)

lg = int(log(chunk_size, 2))
Inline review comment from a Contributor (severity: critical):

The new matrix inversion algorithm, which uses log(chunk_size, 2), implicitly assumes that chunk_size is a power of two. If a non-power-of-two chunk_size is provided, the view operation inside the loop will likely fail due to a shape mismatch, causing a runtime error. To ensure correctness and prevent such errors, it's critical to add an assertion that validates chunk_size is a power of two.

Suggested change:
-    lg = int(log(chunk_size, 2))
+    assert (chunk_size & (chunk_size - 1) == 0) and chunk_size > 0, "chunk_size must be a power of 2"
+    lg = int(log(chunk_size, 2))
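
A minimal standalone sketch of the doubling inverse that the loop below implements (illustrative only, not part of this patch; the helper name and test sizes are assumptions). It mirrors the patch's loop and checks the result against torch.linalg.inv for a unit lower-triangular input, which is what I - attn is at this point, and it only works when the matrix size is a power of two, which is why the assertion above matters:

import torch
from math import log

def invert_unit_lower_triangular(m, n):
    # m: (..., n, n) unit lower-triangular; n must be a power of two.
    assert (n & (n - 1) == 0) and n > 0, "n must be a power of 2"
    inv = torch.eye(n, dtype=m.dtype, device=m.device).repeat(m.shape[:-2] + (1, 1))
    block_size = 1
    for _ in range(int(log(n, 2))):
        block_num = n // block_size
        prod = m @ inv  # off-diagonal blocks now hold C @ A^{-1} of each 2x2 super-block
        inv_b = inv.reshape(m.shape[:-2] + (block_num, block_size, block_num, block_size)).transpose(-2, -3)
        prod_b = prod.reshape(m.shape[:-2] + (block_num, block_size, block_num, block_size)).transpose(-2, -3)
        r0 = torch.arange(block_num // 2, device=m.device) * 2
        r1 = r0 + 1
        # [[A, 0], [C, B]]^{-1} = [[A^{-1}, 0], [-B^{-1} C A^{-1}, B^{-1}]]
        inv_b[..., r1, r0, :, :] = -inv_b[..., r1, r1, :, :] @ prod_b[..., r1, r0, :, :]
        inv = inv_b.transpose(-2, -3).reshape(m.shape[:-2] + (n, n))
        block_size *= 2
    return inv

chunk = 128
strict_lower = 0.1 * torch.randn(2, 4, chunk, chunk, dtype=torch.float64).tril(-1)
m = torch.eye(chunk, dtype=torch.float64) - strict_lower
torch.testing.assert_close(invert_unit_lower_triangular(m, chunk), torch.linalg.inv(m))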


block_size = 1
attn_inv = torch.eye(chunk_size, dtype=attn.dtype, device=attn.device).repeat((tuple(attn.shape)[:-2] + (1, 1)))
attn = attn_inv - attn
for i in range(lg):
block_num = chunk_size // block_size
prod = attn @ attn_inv
attn_inv_block = attn_inv.view(tuple(attn.shape)[:-2] + (block_num, block_size, block_num, block_size)).transpose(-2, -3)
prod_block = prod.view(tuple(attn.shape)[:-2] + (block_num, block_size, block_num, block_size)).transpose(-2, -3)
r0 = torch.arange(block_num // 2, device=attn.device) * 2
r1 = r0 + 1
attn_inv_block[:, :, :, r1, r0, :, :] = -attn_inv_block[..., r1, r1, :, :] @ prod_block[..., r1, r0, :, :]
attn_inv = attn_inv_block.transpose(-2, -3).view(tuple(attn_inv_block.shape)[:-4] + (chunk_size, chunk_size))
block_size *= 2
attn = attn_inv

value = attn @ v_beta
k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))

last_recurrent_state = (torch.zeros(batch_size, sequence_length,
last_recurrent_state = (torch.zeros(batch_size, num_v_heads,
k_head_dim, v_head_dim).to(value) if
initial_state is None else initial_state.to(value))

@@ -191,28 +205,41 @@ def torch_chunk_gated_delta_rule(
dtype=torch.bool,
device=query.device),
diagonal=1)
query_view = query.reshape(query.shape[0], query.shape[1], -1, chunk_size, query.shape[-1])
key_trans = key.reshape(key.shape[0], key.shape[1], -1, chunk_size, key.shape[-1]).transpose(-1, -2)
qk = query_view @ key_trans
attn_score = qk * decay_mask.masked_fill_(mask, 0)

gexp = g[:, :, :, :, None].exp()
qgexp = query * gexp

kgexp = (g[:, :, :, -1, None] - g[:, :, :]).exp()[..., None]
kgexp = key * kgexp

k_cumdecay_qgexp = torch.cat([k_cumdecay, qgexp], dim=3)
v_new_out = torch.zeros_like(value)
attn_inter_out = torch.zeros_like(value)

# for each chunk
for i in range(0, tot_heads // chunk_size):
q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
attn = (q_i @ k_i.transpose(-1, -2) *
decay_mask[:, :, i]).masked_fill_(mask, 0)
v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
for i in range(0, sequence_length_padded // chunk_size):
v_i = value[:, :, i]
attn = attn_score[:, :, i]
v_prime_attn_inter = (k_cumdecay_qgexp[:, :, i]) @ last_recurrent_state
v_prime = v_prime_attn_inter[:, :, :chunk_size]
attn_inter = v_prime_attn_inter[:, :, chunk_size:]
v_new = v_i - v_prime
attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
core_attn_out[:, :, i] = attn_inter + attn @ v_new
last_recurrent_state = (
last_recurrent_state * g[:, :, i, -1, None, None].exp() +
(k_i *
(g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(
-1, -2) @ v_new)
v_new_out[:, :, i] = v_new
attn_inter_out[:, :, i] = attn_inter
last_recurrent_state *= gexp[:, :, i, -1, :, None]
last_recurrent_state += (kgexp[:, :, i]).transpose(-1, -2) @ v_new
core_attn_out = attn_inter_out + attn_score @ v_new_out

if not output_final_state:
last_recurrent_state = None
core_attn_out = core_attn_out.reshape(core_attn_out.shape[0],
core_attn_out.shape[1], -1,
core_attn_out.shape[-1])
core_attn_out = core_attn_out[:, :, :num_heads]
core_attn_out = core_attn_out[:, :, :sequence_length]
core_attn_out = core_attn_out.transpose(1,
2).contiguous().to(initial_dtype)
return core_attn_out, last_recurrent_state
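
For reference, a possible call pattern for the updated function (hypothetical sizes; the argument order follows the signature fragment above, and the input/output shapes are inferred from the unpacking and reshapes in this diff rather than from documented behaviour):

import torch
# num_v_heads must be a multiple of num_qk_heads for the repeat_interleave above,
# and chunk_size should be a power of two per the review comment
B, T, HQK, HV, DK, DV = 1, 300, 4, 8, 128, 128
query = torch.randn(B, T, HQK, DK)
key = torch.randn(B, T, HQK, DK)
value = torch.randn(B, T, HV, DV)
g = -torch.rand(B, T, HV)      # per-head log decay gates (assumed <= 0)
beta = torch.rand(B, T, HV)    # delta-rule mixing coefficients (assumed in (0, 1))
out, state = torch_chunk_gated_delta_rule(query, key, value, g, beta,
                                          chunk_size=128, output_final_state=True)
# expected: out has shape (B, T, HV, DV) and state has shape (B, HV, DK, DV)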