@@ -359,23 +359,6 @@ def __init__(
359359 AttentionBackendEnum .ROCM_AITER_FA ,
360360 }
361361
362- def split_qkv (self , qkv : torch .Tensor ) -> tuple [torch .Tensor , ...]:
363- # [s, b, 3 * head * head_dim]
364- seq_len , bs , _ = qkv .shape
365-
366- # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
367- q , k , v = qkv .chunk (3 , dim = 2 )
368-
369- # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
370- new_shape = (
371- seq_len ,
372- bs ,
373- self .num_attention_heads_per_partition ,
374- self .hidden_size_per_attention_head ,
375- )
376- q , k , v = (x .view (* new_shape ) for x in (q , k , v ))
377- return q , k , v
378-
379362 def forward (
380363 self ,
381364 x : torch .Tensor ,
@@ -386,17 +369,32 @@ def forward(
386369 ) -> torch .Tensor :
387370 # [s, b, c] --> [s, b, head * 3 * head_dim]
388371 x , _ = self .qkv (x )
372+ seq_len , batch_size , _ = x .shape
389373
390- # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
391- q , k , v = self .split_qkv (x )
392- batch_size = q .shape [1 ]
374+ qkv = einops .rearrange (
375+ x ,
376+ "s b (three head head_dim) -> b s three head head_dim" ,
377+ three = 3 ,
378+ head = self .num_attention_heads_per_partition ,
379+ )
393380
394- q , k , v = (einops .rearrange (x , "s b ... -> b s ..." ) for x in (q , k , v ))
395381 if rotary_pos_emb is not None :
396- # [2 * b, s, heads, head_dim]
397- qk_concat = torch .cat ([q , k ], dim = 0 )
398- qk_rotated = apply_rotary_pos_emb_vision (qk_concat , rotary_pos_emb )
399- q , k = torch .chunk (qk_rotated , 2 , dim = 0 )
382+ qk , v = qkv [:, :, :2 ], qkv [:, :, 2 ]
383+
384+ qk_reshaped = einops .rearrange (
385+ qk , "b s two head head_dim -> (two b) s head head_dim" , two = 2
386+ )
387+ qk_rotated = apply_rotary_pos_emb_vision (qk_reshaped , rotary_pos_emb )
388+ qk_rotated = qk_rotated .view (
389+ 2 ,
390+ batch_size ,
391+ seq_len ,
392+ self .num_attention_heads_per_partition ,
393+ self .hidden_size_per_attention_head ,
394+ )
395+ q , k = qk_rotated .unbind (dim = 0 )
396+ else :
397+ q , k , v = qkv .unbind (dim = 2 )
400398
401399 if self .is_flash_attn_backend :
402400 context_layer = vit_flash_attn_wrapper (
0 commit comments