Skip to content

Commit 509cfe4

Browse files
committed
Add a unit test about an issue
1 parent 43d1e2e commit 509cfe4

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import unittest
2+
import numpy as np
3+
import torch
4+
from onnx_diagnostic.ext_test_case import ExtTestCase
5+
6+
7+
class TestIssues2025(ExtTestCase):
8+
def test_issue_158786_qwen2vl(self):
9+
# https://github.com/pytorch/pytorch/issues/158786
10+
class Model(torch.nn.Module):
11+
def __init__(self):
12+
super().__init__()
13+
self.spatial_merge_size = 2 # Default
14+
15+
def forward(self, a):
16+
pos_ids = []
17+
for t, h, w in a:
18+
t = t.item()
19+
h = h.item()
20+
w = w.item()
21+
torch._constrain_as_size(t)
22+
torch._constrain_as_size(h)
23+
torch._constrain_as_size(w)
24+
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
25+
hpos_ids = hpos_ids.reshape(
26+
h // self.spatial_merge_size,
27+
self.spatial_merge_size,
28+
w // self.spatial_merge_size,
29+
self.spatial_merge_size,
30+
)
31+
hpos_ids = hpos_ids.permute(0, 2, 1, 3)
32+
hpos_ids = hpos_ids.flatten()
33+
34+
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
35+
wpos_ids = wpos_ids.reshape(
36+
h // self.spatial_merge_size,
37+
self.spatial_merge_size,
38+
w // self.spatial_merge_size,
39+
self.spatial_merge_size,
40+
)
41+
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
42+
wpos_ids = wpos_ids.flatten()
43+
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
44+
pos_ids = torch.cat(pos_ids, dim=0)
45+
return pos_ids
46+
47+
model = Model()
48+
inputs = torch.tensor(np.array([1, 98, 146]).reshape(1, 3))
49+
ep = torch.export.export(model, (inputs,))
50+
self.assertIn("torch.ops.aten.cat.default", str(ep))
51+
52+
53+
if __name__ == "__main__":
54+
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)