|
| 1 | +import unittest |
| 2 | +import torch |
| 3 | +from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch |
| 4 | +from onnx_diagnostic.torch_export_patches.patch_helper import py_vmap |
| 5 | + |
| 6 | + |
| 7 | +class TestPatchHelper(ExtTestCase): |
| 8 | + def test_vmap(self): |
| 9 | + f = lambda x, y: x * y + 1 # noqa: E731 |
| 10 | + x = torch.tensor([1.0, 2.0, 3.0]) |
| 11 | + y = torch.tensor([0.1, 0.2, 0.3]) |
| 12 | + expected = torch.vmap(f)(x, y) |
| 13 | + got = py_vmap(f)(x, y) |
| 14 | + self.assertEqualArray(expected, got) |
| 15 | + |
| 16 | + @requires_torch("2.9") |
| 17 | + def test_export_vmap(self): |
| 18 | + class Model(torch.nn.Module): |
| 19 | + def forward(self, x, y): |
| 20 | + f = lambda x, y: x * y + 1 # noqa: E731 |
| 21 | + return torch.vmap(f)(x, y) |
| 22 | + |
| 23 | + x = torch.tensor([1.0, 2.0, 3.0]) |
| 24 | + y = torch.tensor([0.1, 0.2, 0.3]) |
| 25 | + DYN = torch.export.Dim.DYNAMIC |
| 26 | + torch.export.export(Model(), (x, y), ({0: DYN}, {1: DYN})) |
| 27 | + |
| 28 | + def test_export_py_vmap(self): |
| 29 | + class Model(torch.nn.Module): |
| 30 | + def forward(self, x, y): |
| 31 | + f = lambda x, y: x * y + 1 # noqa: E731 |
| 32 | + return py_vmap(f)(x, y) |
| 33 | + |
| 34 | + x = torch.tensor([1.0, 2.0, 3.0]) |
| 35 | + y = torch.tensor([0.1, 0.2, 0.3]) |
| 36 | + torch.export.export(Model(), (x, y)) |
| 37 | + |
| 38 | + def test_vmap_outdim(self): |
| 39 | + f = lambda x: x**2 # noqa: E731 |
| 40 | + x = torch.randn(2, 5) |
| 41 | + expected = torch.vmap(f, out_dims=1)(x) |
| 42 | + got = py_vmap(f, out_dims=1)(x) |
| 43 | + self.assertEqualArray(expected, got) |
| 44 | + |
| 45 | + def test_vmap_dict(self): |
| 46 | + f = lambda d: torch.dot(d["x"], d["y"]) # noqa: E731 |
| 47 | + x, y = torch.randn(2, 5), torch.randn(5) |
| 48 | + input = {"x": x, "y": y} |
| 49 | + _expected = torch.vmap(f, in_dims=({"x": 0, "y": None},))(input) |
| 50 | + self.assertRaise( |
| 51 | + lambda: py_vmap(f, in_dims=({"x": 0, "y": None},))(input), AssertionError |
| 52 | + ) |
| 53 | + # self.assertEqualArray(_expected, got) |
| 54 | + |
| 55 | + def test_vmap_tuple(self): |
| 56 | + x, y = torch.randn(2, 5), torch.randn(5) |
| 57 | + expected = torch.vmap(torch.dot, in_dims=(0, None))(x, y) |
| 58 | + got = py_vmap(torch.dot, in_dims=(0, None))(x, y) |
| 59 | + self.assertEqualArray(expected, got) |
| 60 | + |
| 61 | + |
| 62 | +if __name__ == "__main__": |
| 63 | + unittest.main(verbosity=2) |
0 commit comments