Commit 72c6999

Commit message: fix
1 parent f65eab9

2 files changed, +44 -20 lines changed

_unittests/ut_torch_export_patches/test_patch_transformers.py

Lines changed: 7 additions & 6 deletions
@@ -387,12 +387,13 @@ def test_patched_qwen2_5_vl_vision_attention_forward(self):
         expected = instance.forward(**inputs)
         got = patched_Qwen2_5_VLVisionAttention.forward(instance, **inputs)
         self.assertEqualArray(expected, got)
-        with fake_torchdynamo_exporting():
-            assert (
-                _is_torchdynamo_exporting()
-            ), f"exporting is not set to true? {torch.compiler.is_exporting_flag}"
-            got = patched_Qwen2_5_VLVisionAttention.forward(instance, **inputs)
-            self.assertEqualArray(expected, got)
+        if 1:  # with torch_export_patches(patch_transformers=False, patch_torch=True):
+            with fake_torchdynamo_exporting():
+                assert (
+                    _is_torchdynamo_exporting()
+                ), f"exporting is not set to true? {torch.compiler.is_exporting_flag}"
+                got = patched_Qwen2_5_VLVisionAttention.forward(instance, **inputs)
+                self.assertEqualArray(expected, got)


 if __name__ == "__main__":
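
Note on the test: it forces the patched forward down its export-only branch by faking the exporting state instead of running a real export. As a rough sketch of what such a helper needs to do (an assumption on my part; the actual fake_torchdynamo_exporting in onnx_diagnostic may be implemented differently), a context manager only has to make torch.compiler.is_exporting() report True:

# Hypothetical sketch, not the real helper: fake_torchdynamo_exporting
# in onnx_diagnostic may be implemented differently.
import contextlib
from unittest import mock

import torch


@contextlib.contextmanager
def fake_exporting():
    # While active, code that guards its export-only path with
    # torch.compiler.is_exporting() takes that path.
    with mock.patch.object(torch.compiler, "is_exporting", return_value=True):
        yield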

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 37 additions & 14 deletions
@@ -1401,6 +1401,18 @@ def patched_sdpa_attention_forward(
     is_causal = attention_mask is None and is_causal

     if not is_causal:
+        torch._check(query.shape[0] > 0)
+        torch._check(query.shape[1] > 0)
+        torch._check(query.shape[2] > 0)
+        torch._check(query.shape[3] > 0)
+        torch._check(key.shape[0] > 0)
+        torch._check(key.shape[1] > 0)
+        torch._check(key.shape[2] > 0)
+        torch._check(key.shape[3] > 0)
+        torch._check(value.shape[0] > 0)
+        torch._check(value.shape[1] > 0)
+        torch._check(value.shape[2] > 0)
+        torch._check(value.shape[3] > 0)
         return (
             torch.nn.functional.scaled_dot_product_attention(
                 query,
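
Each torch._check call records a runtime assertion that also refines the exporter's symbolic range for that dimension, so size-dependent code no longer fails with data-dependent guard errors when the sizes are only known symbolically. A standalone illustration of the mechanism with a made-up module (depending on the PyTorch version, the positivity fact may already be inferred automatically):

# Standalone illustration, not from this commit.
import torch


class TakeOnes(torch.nn.Module):
    def forward(self, x, n):
        k = n.item()         # unbacked value: the exporter knows nothing about it
        torch._check(k > 0)  # recorded fact: k is strictly positive
        return x.sum() + torch.ones(k)


ep = torch.export.export(TakeOnes(), (torch.randn(4), torch.tensor(3)))
print(ep.graph)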
@@ -2342,25 +2354,29 @@ def forward(
                 **kwargs,
             )
         elif _is_torchdynamo_exporting():
+            if (
+                attention_interface
+                is transformers.integrations.sdpa_attention.sdpa_attention_forward
+            ):
+                attention_interface = patched_sdpa_attention_forward

             def _iteration(start_end, query_states, key_states, value_states):
-                a, b = start_end
+                a = start_end[0]
+                b = start_end[1]
                 q = query_states[:, :, a:b, :]
                 k = key_states[:, :, a:b, :]
                 v = value_states[:, :, a:b, :]
-                return (
-                    attention_interface(
-                        self,
-                        q,
-                        k,
-                        v,
-                        attention_mask=None,
-                        scaling=self.scaling,
-                        dropout=0.0 if not self.training else self.attention_dropout,
-                        is_causal=False,
-                        **kwargs,
-                    )[0],
-                )
+                return attention_interface(
+                    self,
+                    q,
+                    k,
+                    v,
+                    attention_mask=None,
+                    scaling=self.scaling,
+                    dropout=0.0 if not self.training else self.attention_dropout,
+                    is_causal=False,
+                    **kwargs,
+                )[0]

             starts = cu_seqlens[:-1]
             ends = cu_seqlens[1:]
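
Two things change in this hunk: when exporting, the stock sdpa_attention_forward is swapped for the patched variant above, and _iteration no longer tuple-unpacks start_end. Unpacking a 1-D tensor (a, b = start_end) implicitly iterates over it, which dynamo/export traces less reliably than two explicit index reads. A plain illustration of the two spellings (not from the commit):

import torch

start_end = torch.tensor([2, 5])

# Implicit: unpacking iterates over the tensor, yielding 0-d tensors;
# this pattern can be brittle under torch.export / torch.compile tracing.
a, b = start_end

# Explicit: two select ops that trace as ordinary graph nodes.
a = start_end[0]
b = start_end[1]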
@@ -2369,6 +2385,13 @@ def _iteration(start_end, query_states, key_states, value_states):
                 _iteration(start_end, query_states, key_states, value_states)
                 for start_end in starts_ends
             ]
+            # attn_outputs = torch._higher_order_ops.while_loop(
+            # attn_outputs = torch.ops.higher_order.while_loop(
+            #     (lambda it, starts_ends, *_args: it < starts_ends.shape[0]),
+            #     _iteration,
+            #     (torch.tensor(0),
+            #      starts_ends, query_states, key_states, value_states), tuple(),
+            # )
             attn_output = torch.cat(attn_outputs, dim=1)
         else:
             # Other implementations: Process each chunk separately
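
The commented-out lines sketch a possible follow-up: replacing the Python list comprehension, which unrolls the chunk loop at export time, with the while_loop higher-order op so the loop is captured as a single graph node. A minimal standalone example of that op (the counter/accumulator is mine, not from this commit; carried values must keep their shape and dtype across iterations):

import torch
from torch._higher_order_ops import while_loop


def cond_fn(i, acc):
    return i < 5  # 0-d bool tensor: keep looping while it is true


def body_fn(i, acc):
    # Must return tensors with the same shapes/dtypes as the carried inputs.
    return i + 1, acc + i


i, acc = while_loop(cond_fn, body_fn, (torch.tensor(0), torch.tensor(0)))
print(int(i), int(acc))  # 5 10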
