diff --git a/vllm_ascend/ops/fla.py b/vllm_ascend/ops/fla.py
index b200c6747d..641398689d 100644
--- a/vllm_ascend/ops/fla.py
+++ b/vllm_ascend/ops/fla.py
@@ -11,7 +11,7 @@ import triton
 from vllm.model_executor.layers.fla.ops.layernorm_guard import \
     layer_norm_fwd_kernel
-
+from math import log
 
 
 def _layer_norm_fwd(
     x,
@@ -127,7 +127,7 @@ def torch_chunk_gated_delta_rule(
     value,
     g,
     beta,
-    chunk_size=64,
+    chunk_size=128,
     initial_state=None,
     output_final_state=False,
     use_qk_l2norm_in_kernel=False,
@@ -141,15 +141,16 @@ def torch_chunk_gated_delta_rule(
         for x in (query, key, value, beta, g)
     ]
 
-    batch_size, sequence_length, num_heads, k_head_dim = key.shape
+    batch_size, num_qk_heads, sequence_length, k_head_dim = key.shape
+    num_v_heads = value.shape[1]
     v_head_dim = value.shape[-1]
-    pad_size = (chunk_size - num_heads % chunk_size) % chunk_size
-    query = F.pad(query, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1)
-    key = F.pad(key, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1)
+    pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
+    query = F.pad(query, (0, 0, 0, pad_size)).repeat_interleave(
+        num_v_heads // num_qk_heads, dim=1)
+    key = F.pad(key, (0, 0, 0, pad_size)).repeat_interleave(
+        num_v_heads // num_qk_heads, dim=1)
     value = F.pad(value, (0, 0, 0, pad_size))
     beta = F.pad(beta, (0, pad_size))
     g = F.pad(g, (0, pad_size))
-    tot_heads = num_heads + pad_size
+    sequence_length_padded = sequence_length + pad_size
     scale = 1 / (query.shape[-1]**0.5)
     query = query * scale
@@ -173,15 +174,28 @@ def torch_chunk_gated_delta_rule(
                    g.unsqueeze(-2)).tril().exp().float()).tril()
     attn = -(
         (k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
-    for i in range(1, chunk_size):
-        row = attn[..., i, :i].clone()
-        sub = attn[..., :i, :i].clone()
-        attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
-    attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
+
+    lg = int(log(chunk_size, 2))
+
+    block_size = 1
+    attn_inv = torch.eye(chunk_size, dtype=attn.dtype,
+                         device=attn.device).repeat(
+                             (tuple(attn.shape)[:-2] + (1, 1)))
+    attn = attn_inv - attn
+    for i in range(lg):
+        block_num = chunk_size // block_size
+        prod = attn @ attn_inv
+        attn_inv_block = attn_inv.view(
+            tuple(attn.shape)[:-2] +
+            (block_num, block_size, block_num, block_size)).transpose(-2, -3)
+        prod_block = prod.view(
+            tuple(attn.shape)[:-2] +
+            (block_num, block_size, block_num, block_size)).transpose(-2, -3)
+        r0 = torch.arange(block_num // 2, device=attn.device) * 2
+        r1 = r0 + 1
+        attn_inv_block[:, :, :, r1, r0, :, :] = -attn_inv_block[
+            ..., r1, r1, :, :] @ prod_block[..., r1, r0, :, :]
+        attn_inv = attn_inv_block.transpose(-2, -3).view(
+            tuple(attn_inv_block.shape)[:-4] + (chunk_size, chunk_size))
+        block_size *= 2
+    attn = attn_inv
+
     value = attn @ v_beta
     k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
-    last_recurrent_state = (torch.zeros(batch_size, sequence_length,
+    last_recurrent_state = (torch.zeros(batch_size, num_v_heads,
                                         k_head_dim, v_head_dim).to(value)
                             if initial_state is None else
                             initial_state.to(value))
@@ -191,28 +205,41 @@
                                  dtype=torch.bool,
                                  device=query.device),
                       diagonal=1)
+    query_view = query.reshape(query.shape[0], query.shape[1], -1, chunk_size,
+                               query.shape[-1])
+    key_trans = key.reshape(key.shape[0], key.shape[1], -1, chunk_size,
+                            key.shape[-1]).transpose(-1, -2)
+    qk = query_view @ key_trans
+    attn_score = qk * decay_mask.masked_fill_(mask, 0)
+
+    gexp = g[:, :, :, :, None].exp()
+    qgexp = query * gexp
+
+    kgexp = (g[:, :, :, -1, None] -
+             g[:, :, :]).exp()[..., None]
+    kgexp = key * kgexp
+
+    k_cumdecay_qgexp = torch.cat([k_cumdecay, qgexp], dim=3)
+    v_new_out = torch.zeros_like(value)
+    attn_inter_out = torch.zeros_like(value)
     # for each chunk
-    for i in range(0, tot_heads // chunk_size):
-        q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
-        attn = (q_i @ k_i.transpose(-1, -2) *
-                decay_mask[:, :, i]).masked_fill_(mask, 0)
-        v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
+    for i in range(0, sequence_length_padded // chunk_size):
+        v_i = value[:, :, i]
+        attn = attn_score[:, :, i]
+        v_prime_attn_inter = (k_cumdecay_qgexp[:, :, i]) @ last_recurrent_state
+        v_prime = v_prime_attn_inter[:, :, :chunk_size]
+        attn_inter = v_prime_attn_inter[:, :, chunk_size:]
         v_new = v_i - v_prime
-        attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
-        core_attn_out[:, :, i] = attn_inter + attn @ v_new
-        last_recurrent_state = (
-            last_recurrent_state * g[:, :, i, -1, None, None].exp() +
-            (k_i *
-             (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(
-                 -1, -2) @ v_new)
+        v_new_out[:, :, i] = v_new
+        attn_inter_out[:, :, i] = attn_inter
+        last_recurrent_state *= gexp[:, :, i, -1, :, None]
+        last_recurrent_state += (kgexp[:, :, i]).transpose(-1, -2) @ v_new
+    core_attn_out = attn_inter_out + attn_score @ v_new_out
 
     if not output_final_state:
         last_recurrent_state = None
     core_attn_out = core_attn_out.reshape(core_attn_out.shape[0],
                                           core_attn_out.shape[1], -1,
                                           core_attn_out.shape[-1])
-    core_attn_out = core_attn_out[:, :, :num_heads]
+    core_attn_out = core_attn_out[:, :, :sequence_length]
     core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
     return core_attn_out, last_recurrent_state
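
A note on the layout fixes in the first hunks: after the transpose at the top of torch_chunk_gated_delta_rule the tensors are laid out as [batch, num_heads, seq_len, head_dim], so the renamed unpacking now pads and truncates along the sequence dimension, derives the GQA expansion factor from num_v_heads // num_qk_heads instead of hard-coding 2, and allocates the recurrent state with num_v_heads. A minimal shape sketch of that expansion, with made-up sizes (2 query/key heads shared by 4 value heads), purely for illustration:

import torch
import torch.nn.functional as F

# Hypothetical GQA shapes: 2 query/key heads shared by 4 value heads.
batch, num_qk_heads, num_v_heads = 1, 2, 4
seq, dim, chunk = 100, 16, 64
key = torch.randn(batch, num_qk_heads, seq, dim)
value = torch.randn(batch, num_v_heads, seq, dim)

# Pad the sequence dimension (second to last) up to a chunk multiple, then
# repeat each qk head so it lines up with its group of value heads.
pad = (chunk - seq % chunk) % chunk
key = F.pad(key, (0, 0, 0, pad)).repeat_interleave(num_v_heads // num_qk_heads, dim=1)
value = F.pad(value, (0, 0, 0, pad))
assert key.shape == value.shape == (batch, num_v_heads, seq + pad, dim)

The old repeat_interleave(2, dim=1) only matched the value heads when num_v_heads happened to be exactly twice num_qk_heads.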
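The forward-substitution loop removed in the third hunk inverted the unit lower-triangular matrix (I - attn) one row at a time, i.e. chunk_size sequential steps; the replacement builds the same inverse from block-diagonal pieces in log2(chunk_size) rounds of batched matmuls. Below is a self-contained sketch of that block-doubling scheme (hypothetical helper names, float64 inputs for a tight comparison), checked against the removed loop and torch.linalg.inv:

import torch

def inv_forward_sub(a):
    # Reference: the removed row-by-row substitution. `a` is strictly lower
    # triangular; the result is (I - a)^-1, which is unit lower triangular.
    n = a.shape[-1]
    t = a.clone()
    for i in range(1, n):
        row = t[..., i, :i].clone()
        sub = t[..., :i, :i].clone()
        t[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
    return t + torch.eye(n, dtype=a.dtype, device=a.device)

def inv_blockwise(a):
    # Block-doubling inverse of M = I - a. At block size b, `inv` holds the
    # inverses of M's b x b diagonal blocks; each round fills the sub-diagonal
    # block of every 2b x 2b diagonal block with -D1^-1 @ (M @ inv)[r1, r0],
    # i.e. -D1^-1 @ C @ D0^-1 for the block [[D0, 0], [C, D1]].
    n = a.shape[-1]
    eye = torch.eye(n, dtype=a.dtype, device=a.device)
    m = eye - a
    inv = eye.repeat(*a.shape[:-2], 1, 1)
    block = 1
    while block < n:
        nb = n // block
        prod = m @ inv
        inv_b = inv.view(*a.shape[:-2], nb, block, nb, block).transpose(-2, -3)
        prod_b = prod.view(*a.shape[:-2], nb, block, nb, block).transpose(-2, -3)
        r0 = torch.arange(nb // 2, device=a.device) * 2
        r1 = r0 + 1
        # The (r1, r1) blocks read here are not written, so the in-place
        # update through the view is safe.
        inv_b[..., r1, r0, :, :] = -inv_b[..., r1, r1, :, :] @ prod_b[..., r1, r0, :, :]
        inv = inv_b.transpose(-2, -3).reshape(*a.shape[:-2], n, n)
        block *= 2
    return inv

a = torch.randn(2, 3, 64, 64, dtype=torch.float64).tril(-1)
ref = inv_forward_sub(a)
blk = inv_blockwise(a)
exact = torch.linalg.inv(torch.eye(64, dtype=torch.float64) - a)
print((ref - blk).abs().max().item(), (blk - exact).abs().max().item())

Trading the O(chunk_size) data-dependent Python loop for log2(chunk_size) rounds of batched matmuls is presumably the point of the change on this NPU reference path.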
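In the rewritten chunk loop, k_cumdecay and query * exp(g) are concatenated along the chunk-row axis so a single matmul against last_recurrent_state produces both v_prime and attn_inter, and the attn @ v_new products are deferred to one batched attn_score @ v_new_out after the loop. Both rest on the same row-block identity, sketched here with illustrative shapes (none of these sizes come from the patch):

import torch

# Illustrative per-chunk shapes only: chunk length C, key dim K, value dim V.
B, H, C, K, V = 2, 4, 64, 128, 128
k_cumdecay_i = torch.randn(B, H, C, K)   # stands in for k_cumdecay[:, :, i]
qgexp_i = torch.randn(B, H, C, K)        # stands in for q_i * g[:, :, i, :, None].exp()
state = torch.randn(B, H, K, V)          # stands in for last_recurrent_state

# Old loop body: two matmuls against the recurrent state per chunk.
v_prime = k_cumdecay_i @ state
attn_inter = qgexp_i @ state

# New loop body: one matmul on the row-concatenated operand, split afterwards.
fused = torch.cat([k_cumdecay_i, qgexp_i], dim=-2) @ state
assert torch.allclose(fused[..., :C, :], v_prime, atol=1e-5)
assert torch.allclose(fused[..., C:, :], attn_inter, atol=1e-5)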