@@ -818,3 +818,152 @@ def forward(self, x, position_ids):
         sin = emb.sin() * self.attention_scaling
 
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class patched_IdeficsEmbedding(torch.nn.Module):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsEmbedding
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        # if seq_len > self.max_seq_len_cached:
+        #     self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
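+        # The commented-out branch above is the original implementation: it grows the
+        # cos/sin cache with Python control flow that depends on seq_len, which
+        # torch.export cannot trace. The torch.cond rewrite below keeps both outcomes
+        # in the exported graph.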
+
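+        # Branch used when the requested length exceeds the cached tables:
+        # recompute the rotary frequencies for seq_len positions from scratch.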
+        def _set_cos_sin_cache_then(x, inv_freq, seq_len, _cos_cached, _sin_cached):
+            t = torch.arange(seq_len, device=x.device, dtype=torch.int64).type_as(inv_freq)
+            freqs = torch.einsum("i,j->ij", t, inv_freq)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            return emb.cos().to(x.dtype), emb.sin().to(x.dtype)
+
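+        # Branch used when the cache is already large enough: slice the cached tables.
+        # torch._check records the bound seq_len <= cache size so the exporter accepts
+        # the data-dependent slice.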
+        def _set_cos_sin_cache_else(_x, _inv_freq, _seq_len, cos_cached, sin_cached):
+            torch._check(seq_len.item() <= cos_cached.shape[0])
+            co = cos_cached[: seq_len.item()].detach().clone()
+            torch._check(seq_len.item() <= sin_cached.shape[0])
+            si = sin_cached[: seq_len.item()].detach().clone()
+            return co.to(dtype=x.dtype), si.to(dtype=x.dtype)
+
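+        # torch.cond evaluates the predicate at run time but traces both branches,
+        # which must accept the same operands and return matching structures.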
+        cos_cached, sin_cached = torch.cond(
+            (seq_len > self.max_seq_len_cached).item(),
+            _set_cos_sin_cache_then,
+            _set_cos_sin_cache_else,
+            [x, self.inv_freq, seq_len, self.cos_cached, self.sin_cached],
+        )
+        return cos_cached, sin_cached
+
+
+class patched_IdeficsAttention(torch.nn.Module):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsAttention
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        is_cross_attention = self.is_cross_attention or key_value_states is not None
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = (
+            self.q_proj(hidden_states)
+            .view(bsz, q_len, self.num_heads, self.head_dim)
+            .transpose(1, 2)
+        )
+        if not is_cross_attention:
+            key_states = (
+                self.k_proj(hidden_states)
+                .view(bsz, q_len, self.num_heads, self.head_dim)
+                .transpose(1, 2)
+            )
+            value_states = (
+                self.v_proj(hidden_states)
+                .view(bsz, q_len, self.num_heads, self.head_dim)
+                .transpose(1, 2)
+            )
+        else:
+            _, kv_len, _ = (
+                key_value_states.size()
+            )  # Note that, in this case, `kv_len` == `kv_seq_len`
+            key_states = (
+                self.k_proj(key_value_states)
+                .view(bsz, kv_len, self.num_heads, self.head_dim)
+                .transpose(1, 2)
+            )
+            value_states = (
+                self.v_proj(key_value_states)
+                .view(bsz, kv_len, self.num_heads, self.head_dim)
+                .transpose(1, 2)
+            )
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
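+            # cache_position[0] is the index of the first token of this step, i.e. the
+            # number of tokens already stored in the cache.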
+            kv_seq_len += cache_position[0]
+
+        if not is_cross_attention:
+            rotary_length = torch.maximum(
+                torch.tensor(kv_seq_len, dtype=torch.int64),
+                torch.tensor(q_len, dtype=torch.int64),
+            )
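+            # The rotary tables must cover both key and query positions; passing the
+            # length as a tensor keeps it symbolic for the patched IdeficsEmbedding above.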
+            cos, sin = self.rotary_emb(value_states, seq_len=rotary_length)
+            query_states, key_states = (
+                transformers.models.idefics.modeling_idefics.apply_rotary_pos_emb(
+                    query_states, key_states, cos, sin, position_ids
+                )
+            )
+        # [bsz, nh, t, hd]
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models;
+            # cache_position needed for the static cache
+            cache_kwargs = {"cache_position": cache_position}
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx, cache_kwargs
+            )
+
+        if self.qk_layer_norms:
+            query_states = self.q_layer_norm(query_states)
+            key_states = self.k_layer_norm(key_states)
+
+        attention_interface: Callable = (
+            transformers.models.idefics.modeling_idefics.eager_attention_forward
+        )
+
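+        # Default to the eager implementation; swap in the configured kernel below,
+        # except that SDPA cannot return attention weights.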
+        if self.config._attn_implementation != "eager":
+            if self.config._attn_implementation == "sdpa" and output_attentions:
+                transformers.models.idefics.modeling_idefics.logger.warning_once(
+                    "`torch.nn.functional.scaled_dot_product_attention` does not support "
+                    "`output_attentions=True`. Falling back to "
+                    "eager attention. This warning can be removed using the argument "
+                    '`attn_implementation="eager"` when loading the model.'
+                )
+            else:
+                attention_interface = transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS[
+                    self.config._attn_implementation
+                ]
+
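+        # Run the selected attention kernel; dropout is only active in training mode.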
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.dropout,
+            scaling=self.scaling,
+            **kwargs,
+        )
+
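+        # Merge the head dimension back into the hidden size before the output projection.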
+        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        if output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
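
The `_PATCHES_` and `_PATCHED_CLASS_` attributes suggest these methods are copied onto the target transformers classes by a helper elsewhere in the repository. A minimal sketch of that idea, assuming a plain monkey-patching helper (`apply_patches` and `restore_patches` below are hypothetical names, not the project's actual API):

def apply_patches(*patch_classes):
    # Copy each method listed in _PATCHES_ onto _PATCHED_CLASS_, keeping the originals.
    originals = {}
    for patch in patch_classes:
        target = patch._PATCHED_CLASS_
        for name in patch._PATCHES_:
            originals[(target, name)] = getattr(target, name)
            setattr(target, name, getattr(patch, name))
    return originals


def restore_patches(originals):
    # Undo apply_patches by putting the saved methods back.
    for (target, name), method in originals.items():
        setattr(target, name, method)


# Hypothetical usage: patch before exporting, restore afterwards.
# saved = apply_patches(patched_IdeficsEmbedding, patched_IdeficsAttention)
# ... run torch.export / ONNX export on the Idefics model ...
# restore_patches(saved)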