|
| 1 | +import unittest |
| 2 | +import numpy as np |
| 3 | +import torch |
| 4 | +from onnx_diagnostic.ext_test_case import ExtTestCase |
| 5 | + |
| 6 | + |
| 7 | +class TestIssues2025(ExtTestCase): |
| 8 | + def test_issue_158786_qwen2vl(self): |
| 9 | + # https://github.com/pytorch/pytorch/issues/158786 |
| 10 | + class Model(torch.nn.Module): |
| 11 | + def __init__(self): |
| 12 | + super().__init__() |
| 13 | + self.spatial_merge_size = 2 # Default |
| 14 | + |
| 15 | + def forward(self, a): |
| 16 | + pos_ids = [] |
| 17 | + for t, h, w in a: |
| 18 | + t = t.item() |
| 19 | + h = h.item() |
| 20 | + w = w.item() |
| 21 | + torch._constrain_as_size(t) |
| 22 | + torch._constrain_as_size(h) |
| 23 | + torch._constrain_as_size(w) |
| 24 | + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) |
| 25 | + hpos_ids = hpos_ids.reshape( |
| 26 | + h // self.spatial_merge_size, |
| 27 | + self.spatial_merge_size, |
| 28 | + w // self.spatial_merge_size, |
| 29 | + self.spatial_merge_size, |
| 30 | + ) |
| 31 | + hpos_ids = hpos_ids.permute(0, 2, 1, 3) |
| 32 | + hpos_ids = hpos_ids.flatten() |
| 33 | + |
| 34 | + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) |
| 35 | + wpos_ids = wpos_ids.reshape( |
| 36 | + h // self.spatial_merge_size, |
| 37 | + self.spatial_merge_size, |
| 38 | + w // self.spatial_merge_size, |
| 39 | + self.spatial_merge_size, |
| 40 | + ) |
| 41 | + wpos_ids = wpos_ids.permute(0, 2, 1, 3) |
| 42 | + wpos_ids = wpos_ids.flatten() |
| 43 | + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) |
| 44 | + pos_ids = torch.cat(pos_ids, dim=0) |
| 45 | + return pos_ids |
| 46 | + |
| 47 | + model = Model() |
| 48 | + inputs = torch.tensor(np.array([1, 98, 146]).reshape(1, 3)) |
| 49 | + ep = torch.export.export(model, (inputs,)) |
| 50 | + self.assertIn("torch.ops.aten.cat.default", str(ep)) |
| 51 | + |
| 52 | + |
| 53 | +if __name__ == "__main__": |
| 54 | + unittest.main(verbosity=2) |
0 commit comments