Commit 3e92166

mystery
1 parent 4515651 commit 3e92166

2 files changed: +72 -1 lines changed


_unittests/ut_torch_onnx/test_discrepancies.py

Lines changed: 71 additions & 1 deletion
@@ -2,15 +2,53 @@
 import unittest
 import numpy as np
 import onnx
-from onnx_diagnostic.ext_test_case import ExtTestCase
+from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings
+from onnx_diagnostic.reference import OnnxruntimeEvaluator
 
 
 class TestDiscrepancies(ExtTestCase):
+    @ignore_warnings(DeprecationWarning)
     def test_attention_opset15_in_a_loop(self):
+        import torch
+        from onnx_diagnostic.torch_export_patches.patches._patch_transformers_attention import (  # noqa: E501
+            patched_sdpa_attention_forward,
+        )
+
+        def qwen_sdpa_attention(
+            query_states: torch.Tensor,
+            key_states: torch.Tensor,
+            value_states: torch.Tensor,
+            cu_seqlens: torch.Tensor,
+            scaling: float = 0,
+            num_heads: int = 16,
+        ) -> torch.Tensor:
+            lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+            splits = [
+                torch.split(tensor, lengths.tolist(), dim=2)
+                for tensor in (query_states, key_states, value_states)
+            ]
+
+            attn_outputs = [
+                patched_sdpa_attention_forward(
+                    None,
+                    q,
+                    k,
+                    v,
+                    attention_mask=None,
+                    scaling=scaling,
+                    dropout=0.0,
+                    is_causal=False,
+                )[0]
+                for q, k, v in zip(*splits)
+            ]
+            attn_output = torch.cat(attn_outputs, dim=1)
+            return attn_output
+
         model = onnx.load(
             os.path.join(os.path.dirname(__file__), "data", "attention_loopa24.onnx")
         )
         sess = self.check_ort(model)
+
         feeds = dict(
             c_lifted_tensor_0=np.array([0], dtype=np.int64),
             cat_2=np.array(
@@ -48,9 +86,41 @@ def test_attention_opset15_in_a_loop(self):
             unsqueeze_5=np.random.randn(1, 16, 1292, 80).astype(np.float32),
             unsqueeze_6=np.random.randn(1, 16, 1292, 80).astype(np.float32),
         )
+
+        dummy_inputs = os.path.join(
+            os.path.dirname(__file__),
+            "..",
+            "..",
+            "dump_test",
+            "replay",
+            "qwen_sdpa_attention_loopa24",
+            "onnx_inputs.pt",
+        )
+        if os.path.exists(dummy_inputs):
+            print("-- use dummy inputs")
+            feeds = {k: v.detach().cpu().numpy() for k, v in torch.load(dummy_inputs).items()}
+            for k, v in feeds.items():
+                print(f"-- {k}: {self.string_type(v, with_shape=True, with_min_max=True)}")
+
+        # feeds["cat_2"] = np.array([0, 1292], dtype=np.int64)
         got = sess.run(None, feeds)
         self.assertEqual(len(got), 1)
         self.assertEqual((1, 1292, 16, 80), got[0].shape)
+        expected = qwen_sdpa_attention(
+            torch.from_numpy(feeds["unsqueeze_4"]),
+            torch.from_numpy(feeds["unsqueeze_5"]),
+            torch.from_numpy(feeds["unsqueeze_6"]),
+            torch.from_numpy(feeds["cat_2"]),
+            scaling=0.11180339753627777,
+            num_heads=16,
+        )
+        self.assertEqualArray(expected, got[0], atol=1e-5)
+
+        tfeeds = {k: torch.from_numpy(v) for k, v in feeds.items()}
+        ev = OnnxruntimeEvaluator(model)
+        got2 = ev.run(None, tfeeds)
+        self.assertEqual(len(got2), 1)
+        self.assertEqualArray(got[0], got2[0], atol=1e-5)
 
 
 if __name__ == "__main__":
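For context, the new qwen_sdpa_attention helper splits the packed query/key/value tensors along the sequence axis (dim 2) at the cu_seqlens boundaries, runs attention on each segment independently, and concatenates the per-segment outputs on the sequence axis. Below is a minimal standalone sketch of that reference computation, assuming plain torch.nn.functional.scaled_dot_product_attention stands in for patched_sdpa_attention_forward; the function name varlen_sdpa and the cu_seqlens boundaries are illustrative, and the shapes mirror the test feeds.

import torch
import torch.nn.functional as F

def varlen_sdpa(query, key, value, cu_seqlens, scaling):
    # split the packed (batch, heads, total_seq, head_dim) tensors at the
    # cumulative sequence-length boundaries
    lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    splits = [torch.split(t, lengths, dim=2) for t in (query, key, value)]
    outputs = []
    for q, k, v in zip(*splits):
        # per-segment attention, no mask, non-causal (as in the test)
        o = F.scaled_dot_product_attention(q, k, v, scale=scaling, is_causal=False)
        # (batch, heads, seq, dim) -> (batch, seq, heads, dim), matching the
        # (1, 1292, 16, 80) shape the test asserts
        outputs.append(o.transpose(1, 2))
    return torch.cat(outputs, dim=1)

q = torch.randn(1, 16, 1292, 80)
k = torch.randn(1, 16, 1292, 80)
v = torch.randn(1, 16, 1292, 80)
cu = torch.tensor([0, 700, 1292])                    # hypothetical segment boundaries
out = varlen_sdpa(q, k, v, cu, scaling=80 ** -0.5)   # 0.1118... = 1/sqrt(80), as in the test
print(out.shape)                                     # torch.Size([1, 1292, 16, 80])

This is the same quantity the test compares against the outputs of attention_loopa24.onnx run through onnxruntime and through OnnxruntimeEvaluator.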

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py

Lines changed: 1 addition & 0 deletions
@@ -118,6 +118,7 @@ def patched_sdpa_attention_forward(
     torch._check(value.shape[1] > 0)
     torch._check(value.shape[2] > 0)
     torch._check(value.shape[3] > 0)
+
     return (
         torch.nn.functional.scaled_dot_product_attention(
             query,
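The function touched here guards the value tensor's dimensions with torch._check and then forwards to torch's fused SDPA kernel, as the diff context shows. A rough sketch of that call pattern follows; everything beyond the guards and the scaled_dot_product_attention call (the signature, the trailing transpose, the (output, None) return) is an assumption modeled on the standard transformers sdpa attention path, not the exact patched code.

import torch

def sdpa_forward_sketch(module, query, key, value, attention_mask=None,
                        dropout=0.0, scaling=None, is_causal=False, **kwargs):
    # guards keep the exporter from assuming zero-sized dimensions
    torch._check(value.shape[1] > 0)
    torch._check(value.shape[2] > 0)
    torch._check(value.shape[3] > 0)
    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query, key, value,
        attn_mask=attention_mask,
        dropout_p=dropout,
        scale=scaling,
        is_causal=is_causal,
    )
    # (batch, heads, seq, dim) -> (batch, seq, heads, dim), the layout the
    # test's qwen_sdpa_attention concatenates on dim=1
    return attn_output.transpose(1, 2).contiguous(), None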
