Skip to content

Commit cd3ff9f

Browse files
committed
Adds discrepancies with the exporter program
1 parent 0521c46 commit cd3ff9f

File tree

2 files changed

+55
-1
lines changed

2 files changed

+55
-1
lines changed

_unittests/ut_tasks/try_export.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,11 @@ def _config_reduction(config, task):
127127
optimize=True,
128128
)
129129

130+
pt2_file = f"{fileep}.ep.pt2"
131+
# self.assertExists(pt2_file)
132+
# ep = torch.export.load(pt2_file)
133+
# diff = self.max_diff(ep.module()(**export_inputs), model.visual(**export_inputs))
134+
# print("----------- diff", diff)
130135
self.assert_onnx_disc(
131136
f"test_imagetext2text_qwen_2_5_vl_instruct_visual.{device}.{dtype}.{exporter}",
132137
filename,
@@ -142,6 +147,7 @@ def _config_reduction(config, task):
142147
atol=0.02,
143148
rtol=10,
144149
ort_optimized_graph=False,
150+
ep=pt2_file,
145151
)
146152

147153

onnx_diagnostic/ext_test_case.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1199,6 +1199,7 @@ def assert_onnx_disc(
11991199
expected: Optional[Any] = None,
12001200
use_ort: bool = False,
12011201
ort_optimized_graph: bool = False,
1202+
ep: Optional[Union["torch.export.ExportedProgram", str]] = None, # noqa: F821
12021203
**kwargs,
12031204
):
12041205
"""
@@ -1218,6 +1219,7 @@ def assert_onnx_disc(
12181219
:param copy_inputs: to copy the inputs
12191220
:param use_ort: use :class:`onnxruntime.InferenceSession`
12201221
:param ort_optimized_graph: dumps the optimized onnxruntime graph
1222+
:param ep: exported program (or saved exported program)
12211223
:param kwargs: arguments sent to
12221224
:class:`onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch`
12231225
"""
@@ -1245,6 +1247,7 @@ def assert_onnx_disc(
12451247
print(f"[{vname}] file size {os.stat(name).st_size // 2**10:1.3f} kb")
12461248
if verbose:
12471249
print(f"[{vname}] make feeds {string_type(inputs, **kws)}")
1250+
12481251
if use_ort:
12491252
assert isinstance(
12501253
proto, onnx.ModelProto
@@ -1275,6 +1278,7 @@ def assert_onnx_disc(
12751278
got = sess.run(None, feeds)
12761279
if verbose:
12771280
print(f"[{vname}] compute expected values")
1281+
12781282
if expected is None:
12791283
if copy_inputs:
12801284
expected = (
@@ -1284,9 +1288,45 @@ def assert_onnx_disc(
12841288
)
12851289
else:
12861290
expected = model(*inputs) if isinstance(inputs, tuple) else model(**inputs)
1291+
12871292
if verbose:
12881293
print(f"[{vname}] expected {string_type(expected, **kws)}")
12891294
print(f"[{vname}] obtained {string_type(got, **kws)}")
1295+
1296+
if ep:
1297+
if isinstance(ep, str):
1298+
if verbose:
1299+
print(f"[{vname}] load exported program {ep!r}")
1300+
import torch
1301+
1302+
ep = torch.export.load(ep)
1303+
ep_inputs = copy.deepcopy(inputs) if copy_inputs else inputs
1304+
ep_model = ep.module()
1305+
ep_expected = (
1306+
ep_model(*copy.deepcopy(ep_inputs))
1307+
if isinstance(ep_inputs, tuple)
1308+
else ep_model(**copy.deepcopy(ep_inputs))
1309+
)
1310+
if verbose:
1311+
print(f"[{vname}] ep_expected {string_type(ep_expected, **kws)}")
1312+
ep_diff = max_diff(expected, ep_expected)
1313+
if verbose:
1314+
print(f"[{vname}] ep_diff {string_diff(ep_diff)}")
1315+
assert (
1316+
isinstance(ep_diff["abs"], float)
1317+
and isinstance(ep_diff["rel"], float)
1318+
and not numpy.isnan(ep_diff["abs"])
1319+
and ep_diff["abs"] <= atol
1320+
and not numpy.isnan(ep_diff["rel"])
1321+
and ep_diff["rel"] <= rtol
1322+
), (
1323+
f"discrepancies in {test_name!r} between the model "
1324+
f"and the exported model diff={string_diff(ep_diff)}"
1325+
)
1326+
ep_nx_diff = max_diff(ep_expected, got, flatten=True)
1327+
if verbose:
1328+
print(f"[{vname}] ep_nx_diff {string_diff(ep_nx_diff)}")
1329+
12901330
diff = max_diff(expected, got, flatten=True)
12911331
if verbose:
12921332
print(f"[{vname}] diff {string_diff(diff)}")
@@ -1297,7 +1337,10 @@ def assert_onnx_disc(
12971337
and diff["abs"] <= atol
12981338
and not numpy.isnan(diff["rel"])
12991339
and diff["rel"] <= rtol
1300-
), f"discrepancies in {test_name!r}, diff={string_diff(diff)}"
1340+
), (
1341+
f"discrepancies in {test_name!r} between the model and "
1342+
f"the onnx model diff={string_diff(diff)}"
1343+
)
13011344

13021345
def _debug(self):
13031346
"Tells if DEBUG=1 is set up."
@@ -1308,6 +1351,11 @@ def string_type(self, *args, **kwargs):
13081351

13091352
return string_type(*args, **kwargs)
13101353

1354+
def max_diff(self, *args, **kwargs):
1355+
from .helpers import max_diff
1356+
1357+
return max_diff(*args, **kwargs)
1358+
13111359
def subloop(self, *args, verbose: int = 0):
13121360
"Loops over elements and calls :meth:`unittests.TestCase.subTest`."
13131361
if len(args) == 1:

0 commit comments

Comments
 (0)