

 # TRACING: This function is not traceable
-@torch.fx.wrap
 def _prepare_cross_attention_mask(
     cross_attention_mask: torch.Tensor,
     num_vision_tokens: int,
     dtype: str,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     # reshape so it can be used by attn module
-    batch_size, text_total_length, *_ = cross_attention_mask.shape
+    # TRACING: cannot unpack cross_attention_mask with arbitrary number of args
+    # batch_size, text_total_length, *_ = cross_attention_mask.shape
+    batch_size, text_total_length = cross_attention_mask.shape[:2]
+
     cross_attention_mask = cross_attention_mask.repeat_interleave(num_vision_tokens, dim=3)
     cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1)
     cross_attention_mask = cross_attention_mask.unsqueeze(1)
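
The replacement keeps the same two leading dimensions and only drops the star-unpacking of the trailing ones. A minimal eager-mode sketch (not part of the commit; the mask shape below is an illustrative assumption) confirming the two forms agree:

import torch

# Illustrative mask shape: (batch, text_len, num_images, num_tiles)
mask = torch.ones(2, 10, 1, 4)

# Original form: star-unpacking discards the trailing dimensions
b_old, t_old, *_ = mask.shape

# Form used in the commit: slice the first two dimensions explicitly
b_new, t_new = mask.shape[:2]

assert (b_old, t_old) == (b_new, t_new) == (2, 10)
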
@@ -66,7 +68,7 @@ def _prepare_cross_attention_mask(
     return cross_attention_mask, full_text_row_masked_out_mask


-# TRACING: needs to use wrapped _prepare_cross_attention_mask
+# TRACING: needs to use updated _prepare_cross_attention_mask
 @add_start_docstrings(
     """The Mllama model which consists of a vision encoder and a language model.""",
     MLLAMA_START_DOCSTRING,
@@ -90,7 +92,8 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        num_logits_to_keep: int = 0,
+        num_logits_to_keep: int = 0,  # backwards compatibility
+        logits_to_keep: Union[int, torch.Tensor] = 0,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
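
The signature now accepts both the deprecated num_logits_to_keep and the new logits_to_keep. The commit does not show how the two are reconciled; a hypothetical helper (name and fallback logic are assumptions, not part of the commit) sketching one way to prefer the new kwarg while still honoring the old one:

from typing import Union

import torch


def _resolve_logits_to_keep(num_logits_to_keep: int, logits_to_keep: Union[int, torch.Tensor]):
    # Hypothetical: if the caller only set the deprecated kwarg and left the
    # new one at its default of 0, fall back to the deprecated value.
    if isinstance(logits_to_keep, int) and logits_to_keep == 0 and num_logits_to_keep != 0:
        return num_logits_to_keep
    return logits_to_keep
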
@@ -127,7 +130,7 @@ def forward(
            )

         if cross_attention_mask is not None:
-            # TRACING: use wrapped function
+            # TRACING: use updated function
             cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask(
                 cross_attention_mask,
                 num_vision_tokens=self.vision_model.num_patches,
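
At this call site the mask is expanded by the reshape sequence shown in the first hunk, so that every vision tile contributes num_vision_tokens positions. A toy shape check (all sizes here are illustrative assumptions, not values from the model):

import torch

batch_size, text_len, num_images, num_tiles, num_vision_tokens = 1, 3, 2, 4, 5
mask = torch.ones(batch_size, text_len, num_images, num_tiles)

mask = mask.repeat_interleave(num_vision_tokens, dim=3)  # (1, 3, 2, 20)
mask = mask.view(batch_size, text_len, -1)               # (1, 3, 40)
mask = mask.unsqueeze(1)                                 # (1, 1, 3, 40)

assert mask.shape == (1, 1, 3, 40)
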
@@ -155,7 +158,7 @@ def forward(
             output_attentions=output_attentions,
             return_dict=return_dict,
             cache_position=cache_position,
-            num_logits_to_keep=num_logits_to_keep,
+            logits_to_keep=logits_to_keep,
         )

         return outputs