@@ -2802,6 +2802,105 @@ def __call__(
         return hidden_states
 
 
+class CogView4AttnProcessor:
+    """
+    Processor for implementing scaled dot-product attention for the CogView4 model. It applies a rotary embedding on
+    query and key vectors, but does not include spatial normalization.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("CogView4AttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        text_seq_length = encoder_hidden_states.size(1)
+
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        ##############################################
+        # TODO: compute qkv directly from the fused qkv weight (split out num_heads and head_dim first,
+        # then chunk q/k/v along the head dimension).
+        linear_qkv_weight = torch.load("/home/lhy/code/cogview/linear_qkv_weight.pt")
+        linear_qkv_bias = torch.load("/home/lhy/code/cogview/linear_qkv_bias.pt")
+
+        qkv = torch.matmul(hidden_states, linear_qkv_weight.T) + linear_qkv_bias
+        qkv = qkv.view(batch_size, -1, attn.heads, head_dim * 3)
+        query, key, value = qkv.chunk(3, dim=-1)
+
+        # TODO: verify that RoPE is applied correctly (currently there is ~25% error).
+        ##############################################
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+
+        # Apply RoPE if needed
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb_megatron
+
+            query[:, :, text_seq_length:, :] = apply_rotary_emb_megatron(
+                query[:, :, text_seq_length:, :], image_rotary_emb
+            )
+            key[:, :, text_seq_length:, :] = apply_rotary_emb_megatron(
+                key[:, :, text_seq_length:, :], image_rotary_emb
+            )
+
+        ##########################################
+        # Debugging only: override query/key/value with the reference tensors dumped after RoPE,
+        # adding a batch dimension and keeping the first 16 + 4096 tokens.
+        query = torch.load("/home/lhy/code/cogview/query_after_rope.pt")
+        key = torch.load("/home/lhy/code/cogview/key_after_rope.pt")
+        value = torch.load("/home/lhy/code/cogview/value_after_rope.pt")
+        query = query[None, : 16 + 4096, ...]
+        key = key[None, : 16 + 4096, ...]
+        value = value[None, : 16 + 4096, ...]
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+        ##########################################
+
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+
+        encoder_hidden_states, hidden_states = hidden_states.split(
+            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+        )
+        return hidden_states, encoder_hidden_states
+
+
 class CogVideoXAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
@@ -2824,9 +2923,7 @@ def __call__(
 
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
 
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
+        batch_size, sequence_length, _ = hidden_states.shape
 
         if attention_mask is not None:
             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
@@ -6174,6 +6271,7 @@ def __call__(
     FusedFluxAttnProcessor2_0,
     FusedFluxAttnProcessor2_0_NPU,
     CogVideoXAttnProcessor2_0,
+    CogView4AttnProcessor,
     FusedCogVideoXAttnProcessor2_0,
     XFormersAttnAddedKVProcessor,
     XFormersAttnProcessor,
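
The fused-QKV TODO in the new processor assumes a Megatron-style weight layout in which each head's query, key, and value rows are stored contiguously. As a minimal standalone sketch (all shapes, seeds, and tensor names below are illustrative and not taken from the actual checkpoint), the following shows how a fused weight built that way reproduces separate to_q/to_k/to_v projections once the output is reshaped to (batch, seq, heads, 3 * head_dim) and chunked on the last dimension, mirroring the qkv.view(...).chunk(3, dim=-1) pattern in the hunk above:

import torch

# Illustrative sizes only; the real model's heads / head_dim differ.
batch, seq, heads, head_dim = 2, 8, 4, 16
hidden = heads * head_dim

torch.manual_seed(0)
x = torch.randn(batch, seq, hidden)

# Separate projections, standing in for attn.to_q / attn.to_k / attn.to_v.
wq, wk, wv = (torch.randn(hidden, hidden) for _ in range(3))
bq, bk, bv = (torch.randn(hidden) for _ in range(3))

q_ref = (x @ wq.T + bq).view(batch, seq, heads, head_dim)
k_ref = (x @ wk.T + bk).view(batch, seq, heads, head_dim)
v_ref = (x @ wv.T + bv).view(batch, seq, heads, head_dim)

# Megatron-style fused weight: for each head, that head's q, k, and v rows are stacked together.
fused_w = torch.cat(
    [w.view(heads, head_dim, hidden) for w in (wq, wk, wv)], dim=1
).reshape(heads * 3 * head_dim, hidden)
fused_b = torch.cat(
    [b.view(heads, head_dim) for b in (bq, bk, bv)], dim=1
).reshape(heads * 3 * head_dim)

# Same pattern as the diff: project, reshape to (batch, seq, heads, 3 * head_dim), chunk q/k/v last.
qkv = x @ fused_w.T + fused_b
qkv = qkv.view(batch, seq, heads, head_dim * 3)
query, key, value = qkv.chunk(3, dim=-1)

assert torch.allclose(query, q_ref, atol=1e-4)
assert torch.allclose(key, k_ref, atol=1e-4)
assert torch.allclose(value, v_ref, atol=1e-4)

If the checkpoint instead stores all query rows first, then all key rows, then all value rows, the chunk would have to happen before the per-head reshape; a layout mismatch of that kind is one possible source of disagreement with the reference tensors.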
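
For the second TODO (the ~25% RoPE discrepancy), one way to make the error concrete is a relative L2 comparison between the processor's post-RoPE query/key and the dumped reference tensors, once both are in the same (batch, heads, seq, head_dim) layout. The helper below is only a debugging sketch, not part of the diff; the commented usage reuses the paths that appear in the hunk above.

import torch


def rel_err(candidate: torch.Tensor, reference: torch.Tensor) -> float:
    # Relative L2 error between a computed tensor and a dumped reference;
    # a value of 0.25 matches the ~25% mismatch noted in the TODO.
    return ((candidate.float() - reference.float()).norm() / reference.float().norm()).item()


# Sanity check on synthetic data: a 1% perturbation yields rel_err of about 0.01.
x = torch.randn(1, 32, 16 + 4096, 16)
print(rel_err(x + 0.01 * x, x))  # ~0.01

# Against the dumped tensors (run inside __call__, right after apply_rotary_emb_megatron):
#   ref_q = torch.load("/home/lhy/code/cogview/query_after_rope.pt")
#   ref_q = ref_q[None, : 16 + 4096, ...].transpose(1, 2)
#   print(rel_err(query, ref_q))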