Skip to content

Commit c83499c

Browse files
committed
first version of sbs
1 parent c0df801 commit c83499c

File tree

6 files changed

+542
-0
lines changed

6 files changed

+542
-0
lines changed

_doc/api/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ API of onnx_diagnostic
1111
reference/index
1212
torch_export_patches/index
1313
torch_models/index
14+
torch_onnx/index
1415

1516
.. toctree::
1617
:maxdepth: 1

_doc/api/torch_onnx/index.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
onnx_diagnostic.torch_onnx
2+
==========================
3+
4+
.. toctree::
5+
:maxdepth: 1
6+
:caption: submodules
7+
8+
sbs
9+
10+
.. automodule:: onnx_diagnostic.torch_onnx
11+
:members:
12+
:no-undoc-members:

_doc/api/torch_onnx/sbs.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
2+
onnx_diagnostic.torch_onnx.sbs
3+
==============================
4+
5+
.. automodule:: onnx_diagnostic.torch_onnx.sbs
6+
:members:
7+
:no-undoc-members:
8+
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import unittest
2+
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
3+
from onnx_diagnostic.reference import ExtendedReferenceEvaluator
4+
from onnx_diagnostic.torch_onnx.sbs import run_aligned
5+
6+
try:
7+
from experimental_experiment.torch_interpreter import to_onnx
8+
except ImportError:
9+
to_onnx = None
10+
11+
12+
class TestSideBySide(ExtTestCase):
13+
14+
@hide_stdout()
15+
@unittest.skipIf(to_onnx is None, "to_onnx not installed")
16+
def test_ep_onnx_sync_exp(self):
17+
import torch
18+
19+
class Model(torch.nn.Module):
20+
def forward(self, x):
21+
ry = x.abs()
22+
rz = ry.exp()
23+
rw = rz + 1
24+
ru = rw.log() + rw
25+
return ru
26+
27+
x = torch.randn((5, 4))
28+
Model()(x)
29+
ep = torch.export.export(
30+
Model(), (x,), dynamic_shapes=({0: torch.export.Dim("batch")},)
31+
)
32+
onx = to_onnx(ep)
33+
results = list(
34+
run_aligned(
35+
ep,
36+
onx,
37+
(x,),
38+
check_conversion_cls=dict(
39+
cls=ExtendedReferenceEvaluator, atol=1e-5, rtol=1e-5
40+
),
41+
verbose=1,
42+
),
43+
)
44+
self.assertEqual(len(results), 4)
45+
46+
@hide_stdout()
47+
@unittest.skipIf(to_onnx is None, "to_onnx not installed")
48+
def test_ep_onnx_sync(self):
49+
import torch
50+
51+
class Model(torch.nn.Module):
52+
def forward(self, x):
53+
ry = x.abs()
54+
rz = ry.exp()
55+
rw = rz + 1
56+
ru = rw.log() + rw
57+
return ru
58+
59+
x = torch.randn((5, 4))
60+
Model()(x)
61+
ep = torch.export.export(
62+
Model(), (x,), dynamic_shapes=({0: torch.export.Dim("batch")},)
63+
)
64+
onx = torch.onnx.export(
65+
Model(), (x,), dynamic_shapes=({0: torch.export.Dim("batch")},), dynamo=True
66+
).model_proto
67+
results = list(
68+
run_aligned(
69+
ep,
70+
onx,
71+
(x,),
72+
check_conversion_cls=dict(
73+
cls=ExtendedReferenceEvaluator, atol=1e-5, rtol=1e-5
74+
),
75+
verbose=1,
76+
),
77+
)
78+
self.assertEqual(len(results), 4)
79+
80+
81+
if __name__ == "__main__":
82+
unittest.main(verbosity=2)

onnx_diagnostic/torch_onnx/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)