
Commit c2db397

[VLM] Update mllama traceable definition (#1140)
## Purpose ##

* Ensure compatibility with transformers>=4.50 by replacing `num_logits_to_keep` with `logits_to_keep` (see https://github.com/huggingface/transformers/blob/1fae54c7216e144b426e753400abdc1299d4fc74/src/transformers/models/mllama/modeling_mllama.py#L2025-L2028)

## Changes ##

* Update the traced `forward` definition: accept `logits_to_keep` (keeping `num_logits_to_keep` for backwards compatibility) and pass `logits_to_keep` through to the language model

## Testing ##

* Ran the mllama example

---------

Signed-off-by: Kyle Sayers <[email protected]>
1 parent 5e9edae commit c2db397
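To illustrate the compatibility concern described in the purpose above, here is a minimal caller-side sketch (not code from this commit or from transformers): pick whichever keyword the installed `forward` actually exposes. The `model` and `inputs` names are placeholders.

```python
import inspect

def forward_with_compatible_kwarg(model, keep=0, **inputs):
    """Call model(**inputs), passing `keep` under whichever keyword the signature accepts."""
    params = inspect.signature(model.forward).parameters
    if "logits_to_keep" in params:            # transformers >= 4.50
        return model(**inputs, logits_to_keep=keep)
    return model(**inputs, num_logits_to_keep=keep)  # older releases
```

The traced definition in this commit takes the simpler route of accepting both keywords and forwarding only `logits_to_keep` to the language model, as the diff below shows.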

File tree

1 file changed: +9 -6 lines
  • src/llmcompressor/transformers/tracing


src/llmcompressor/transformers/tracing/mllama.py

Lines changed: 9 additions & 6 deletions
```diff
@@ -37,14 +37,16 @@
 
 
 # TRACING: This function is not traceable
-@torch.fx.wrap
 def _prepare_cross_attention_mask(
     cross_attention_mask: torch.Tensor,
     num_vision_tokens: int,
     dtype: str,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     # reshape so it can be used by attn module
-    batch_size, text_total_length, *_ = cross_attention_mask.shape
+    # TRACING: cannot unpack cross_attention_mask with arbitrary number of args
+    # batch_size, text_total_length, *_ = cross_attention_mask.shape
+    batch_size, text_total_length = cross_attention_mask.shape[:2]
+
     cross_attention_mask = cross_attention_mask.repeat_interleave(num_vision_tokens, dim=3)
     cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1)
     cross_attention_mask = cross_attention_mask.unsqueeze(1)
```
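As a quick sanity check on the hunk above (illustrative only, not part of the commit): in eager mode the fixed-arity unpack is equivalent to the starred unpack it replaces, while avoiding the arbitrary-length unpacking that the TRACING comment flags as untraceable. The 4-D mask shape below is assumed for illustration.

```python
import torch

# Hypothetical 4-D cross-attention mask: (batch, text_len, num_images, num_tiles)
cross_attention_mask = torch.zeros(2, 7, 3, 4)

# Original form: star-unpacks the full shape (arbitrary number of trailing dims)
b1, t1, *_ = cross_attention_mask.shape

# Traceable form from the diff: index a fixed-length prefix of the shape
b2, t2 = cross_attention_mask.shape[:2]

assert (b1, t1) == (b2, t2) == (2, 7)
```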
```diff
@@ -66,7 +68,7 @@ def _prepare_cross_attention_mask(
     return cross_attention_mask, full_text_row_masked_out_mask
 
 
-# TRACING: needs to use wrapped _prepare_cross_attention_mask
+# TRACING: needs to use updated _prepare_cross_attention_mask
 @add_start_docstrings(
     """The Mllama model which consists of a vision encoder and a language model.""",
     MLLAMA_START_DOCSTRING,
@@ -90,7 +92,8 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        num_logits_to_keep: int = 0,
+        num_logits_to_keep: int = 0,  # backwards compatibility
+        logits_to_keep: Union[int, torch.Tensor] = 0,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -127,7 +130,7 @@ def forward(
         )
 
         if cross_attention_mask is not None:
-            # TRACING: use wrapped function
+            # TRACING: use updated function
             cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask(
                 cross_attention_mask,
                 num_vision_tokens=self.vision_model.num_patches,
@@ -155,7 +158,7 @@ def forward(
             output_attentions=output_attentions,
             return_dict=return_dict,
             cache_position=cache_position,
-            num_logits_to_keep=num_logits_to_keep,
+            logits_to_keep=logits_to_keep,
         )
 
         return outputs
```
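For context on what the new keyword controls, here is a minimal sketch of the convention used by recent transformers releases (not code from this repository; sizes and names are illustrative): an integer `logits_to_keep` keeps only the last N positions before the LM head (0 keeps all), while a tensor selects explicit positions.

```python
import torch
import torch.nn as nn

hidden_size, vocab_size = 16, 32
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
hidden_states = torch.randn(2, 5, hidden_size)  # (batch, seq_len, hidden)

def project_logits(hidden_states: torch.Tensor, logits_to_keep) -> torch.Tensor:
    if isinstance(logits_to_keep, int):
        # slice(-0, None) == slice(0, None), so 0 keeps every position
        indices = slice(-logits_to_keep, None)
    else:
        indices = logits_to_keep  # tensor of explicit position indices
    return lm_head(hidden_states[:, indices, :])

print(project_logits(hidden_states, 0).shape)                     # torch.Size([2, 5, 32])
print(project_logits(hidden_states, 1).shape)                     # torch.Size([2, 1, 32])
print(project_logits(hidden_states, torch.tensor([0, 4])).shape)  # torch.Size([2, 2, 32])
```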

0 commit comments
