Skip to content

Commit 83b3b6a

Browse files
committed
more unit test
1 parent 6eb85b7 commit 83b3b6a

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import unittest
2+
import torch
3+
from onnx_diagnostic.ext_test_case import ExtTestCase
4+
from onnx_diagnostic.reference import ExtendedReferenceEvaluator
5+
6+
7+
class TestDynamicShapes(ExtTestCase):
8+
def test_getitem_index_put1(self):
9+
class Model(torch.nn.Module):
10+
def forward(self, x, value):
11+
x = x.clone()
12+
x[:, :, :, : value.shape[-1]] = value
13+
return x
14+
15+
inputs = (torch.randn(2, 2, 3, 4), torch.randn(2, 2, 3, 3))
16+
model = Model()
17+
expected = model(*inputs)
18+
19+
onx = self.to_onnx(model, inputs, dynamic_shapes=({3: "M"}, {3: "N"}))
20+
self.dump_onnx("test_getitem_index_put1.onnx", onx)
21+
feeds = dict(zip(["x", "value"], [x.detach().cpu().numpy() for x in inputs]))
22+
ref = ExtendedReferenceEvaluator(onx, verbose=0)
23+
got = ref.run(None, feeds)[0]
24+
self.assertEqualArray(expected, got, atol=1e-5)
25+
sess = self.ort().InferenceSession(
26+
onx.SerializeToString(), providers=["CPUExecutionProvider"]
27+
)
28+
got = sess.run(None, feeds)[0]
29+
self.assertEqualArray(expected, got, atol=1e-5)
30+
31+
32+
if __name__ == "__main__":
33+
unittest.main(verbosity=2)

onnx_diagnostic/ext_test_case.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -756,6 +756,18 @@ def todo(cls, f: Callable, msg: str):
756756
"Adds a todo printed when all test are run."
757757
cls._todos.append((f, msg))
758758

759+
@classmethod
760+
def ort(cls):
761+
import onnxruntime
762+
763+
return onnxruntime
764+
765+
@classmethod
766+
def to_onnx(self, *args, **kwargs):
767+
from experimental_experiment.torch_interpreter import to_onnx
768+
769+
return to_onnx(*args, **kwargs)
770+
759771
def print_model(self, model: "ModelProto"): # noqa: F821
760772
"Prints a ModelProto"
761773
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx

0 commit comments

Comments
 (0)