|
| 1 | +import unittest |
| 2 | +import torch |
| 3 | +from onnx_diagnostic.ext_test_case import ExtTestCase |
| 4 | +from onnx_diagnostic.reference import ExtendedReferenceEvaluator |
| 5 | + |
| 6 | + |
| 7 | +class TestDynamicShapes(ExtTestCase): |
| 8 | + def test_getitem_index_put1(self): |
| 9 | + class Model(torch.nn.Module): |
| 10 | + def forward(self, x, value): |
| 11 | + x = x.clone() |
| 12 | + x[:, :, :, : value.shape[-1]] = value |
| 13 | + return x |
| 14 | + |
| 15 | + inputs = (torch.randn(2, 2, 3, 4), torch.randn(2, 2, 3, 3)) |
| 16 | + model = Model() |
| 17 | + expected = model(*inputs) |
| 18 | + |
| 19 | + onx = self.to_onnx(model, inputs, dynamic_shapes=({3: "M"}, {3: "N"})) |
| 20 | + self.dump_onnx("test_getitem_index_put1.onnx", onx) |
| 21 | + feeds = dict(zip(["x", "value"], [x.detach().cpu().numpy() for x in inputs])) |
| 22 | + ref = ExtendedReferenceEvaluator(onx, verbose=0) |
| 23 | + got = ref.run(None, feeds)[0] |
| 24 | + self.assertEqualArray(expected, got, atol=1e-5) |
| 25 | + sess = self.ort().InferenceSession( |
| 26 | + onx.SerializeToString(), providers=["CPUExecutionProvider"] |
| 27 | + ) |
| 28 | + got = sess.run(None, feeds)[0] |
| 29 | + self.assertEqualArray(expected, got, atol=1e-5) |
| 30 | + |
| 31 | + |
| 32 | +if __name__ == "__main__": |
| 33 | + unittest.main(verbosity=2) |
0 commit comments