onnx_diagnostic/torch_export_patches/patches: 1 file changed, +5 -4 lines changed

@@ -1901,10 +1901,7 @@ def get_placeholder_mask(
 try:
     import transformers.modeling_utils

-    # TODO(titaiwang): This is not ready yet.
-    # Using multi-turn conversation to export, we don't need to rewrite the attention
-    # as sequence_length is not restricted to 1.
-    patch_modeling_utils = False
+    patch_modeling_utils = True

     from transformers.integrations.sdpa_attention import use_gqa_in_sdpa, repeat_kv

@@ -1948,6 +1945,10 @@ def patched_sdpa_attention_forward(
     if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
         is_causal = is_causal.item()

+    # From causal_mask generation, attention_mask is 4D, and the last dim
+    # should be the same as key's seq_len
+    torch._check(attention_mask.shape[3] == key.shape[2])
+
     attn_output = torch.nn.functional.scaled_dot_product_attention(
         query,
         key,
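
For context, a minimal sketch (not the repository's patched function) of what the added torch._check buys under torch.export: it records that the mask's last dimension equals the key's sequence length, so the exporter can rely on that relation inside scaled_dot_product_attention instead of leaving the two symbolic sizes unrelated. ToyAttention, the tensor sizes, and the Dim setup below are illustrative assumptions, not part of this patch.

import torch

class ToyAttention(torch.nn.Module):
    def forward(self, query, key, value, attention_mask):
        # attention_mask is 4D (batch, heads, q_len, kv_len); its last dim must
        # match key's sequence length (dim 2) for the masked SDPA call below.
        torch._check(attention_mask.shape[3] == key.shape[2])
        return torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask
        )

batch, heads, q_len, kv_len, head_dim = 2, 4, 3, 7, 16
query = torch.randn(batch, heads, q_len, head_dim)
key = torch.randn(batch, heads, kv_len, head_dim)
value = torch.randn(batch, heads, kv_len, head_dim)
attention_mask = torch.ones(batch, heads, q_len, kv_len, dtype=torch.bool)

# Declare the key/value length and the mask's last dim as the same dynamic
# dimension, mirroring the relation the added check asserts at trace time.
kv = torch.export.Dim("kv_len", min=2, max=4096)
ep = torch.export.export(
    ToyAttention(),
    (query, key, value, attention_mask),
    dynamic_shapes={
        "query": None,
        "key": {2: kv},
        "value": {2: kv},
        "attention_mask": {3: kv},
    },
)
print(ep.graph)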