
Commit 8ed467e

fix missing import
1 parent 9c9bf00 commit 8ed467e

4 files changed: +79 -4 lines changed

_unittests/ut_export/test_control_flow.py

Lines changed: 2 additions & 2 deletions

@@ -3,14 +3,14 @@
 import torch
 from onnxscript import script, FLOAT, INT64
 from onnxscript import opset18 as op
-from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch
+from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch, never_test
 from onnx_diagnostic.export.control_flow import enable_code_export_control_flow, loop_for
 from onnx_diagnostic.export.control_flow_research import simple_loop_for as loop_for_r
 from onnx_diagnostic.export.api import to_onnx
 
 
 class TestControlFlow(ExtTestCase):
-    @unittest.skip("not working")
+    @never_test()
     def test_loop_one_research(self):
         class Model(torch.nn.Module):
             def forward(self, n_iter, x):
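
The commit title refers to this change: @unittest.skip relies on the unittest module, which this test file apparently never imports, so the decorator itself broke at import time; never_test() comes from onnx_diagnostic.ext_test_case, which is already imported above. As an illustration only (not the library's implementation), a decorator with a similar observable effect could be sketched as:

    import unittest


    def never_test(reason: str = "disabled"):
        # Hypothetical stand-in, not onnx_diagnostic's implementation:
        # behaves like unittest.skip but needs no extra import at the call site
        # beyond the module that defines it.
        def wrapper(func):
            return unittest.skip(reason)(func)

        return wrapper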

onnx_diagnostic/_command_lines_parser.py

Lines changed: 23 additions & 1 deletion

@@ -1,4 +1,5 @@
 import argparse
+import contextlib
 import json
 import os
 import re
@@ -625,6 +626,18 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
         ),
         action=_ParseDict,
     )
+    parser.add_argument(
+        "--save-ep",
+        default="",
+        help=textwrap.dedent(
+            """
+            saves the exported program with torch.export.save
+            and the inputs sets with torch.save,
+            then command line sbs can be used to look for discrepancies.
+            """
+        ),
+    )
+
     return parser
 
 
@@ -691,6 +704,7 @@ def _cmd_validate(argv: List[Any]):
             None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
         ),
         exporter_options=args.expop,
+        save_ep=args.save_ep,
     )
     print("")
     print("-- summary --")
@@ -1140,7 +1154,12 @@ def get_parser_sbs() -> ArgumentParser:
         "--ep",
         type=str,
         required=True,
-        help="exported program saved with torch.export.save",
+        help=textwrap.dedent(
+            """
+            exported program saved with torch.export.save,
+            input sets saved with torch.save,
+            """
+        ),
     )
     parser.add_argument(
         "-m",
@@ -1226,6 +1245,9 @@ def _size(name):
             f"Unable to infer args, kwargs from inputs {string_type(inputs, with_shape=True)}"
         )
 
+    print("-- import transformers.modeling_outputs to register serialization functions")
+    with contextlib.suppress(ImportError):
+        import transformers.modeling_outputs  # noqa: F401
     print(f"-- load ep {args.ep!r}")
     begin = time.perf_counter()
     ep = torch.export.load(args.ep)
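
The new --save-ep option ties the validate and sbs commands together: validate dumps the exported program with torch.export.save and the input sets with torch.save, and sbs later reloads them to look for discrepancies. A minimal sketch of that reload step, assuming a dump folder dump/ and the prefix model were used (the actual file names depend on the run and follow the f"{save_ep}.{name}.pt" / f"{save_ep}.pt2" convention introduced in this commit):

    import torch

    # Hypothetical paths produced by a `validate ... --save-ep model` run.
    ep = torch.export.load("dump/model.pt2")                     # exported program
    inputs = torch.load("dump/model.inputs.pt", weights_only=False)  # one input set

    # Re-running the exported program on the saved inputs is what the
    # side-by-side (sbs) command builds on; this assumes a dict of keyword arguments.
    outputs = ep.module()(**inputs)
    print(type(outputs))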

onnx_diagnostic/torch_models/validate.py

Lines changed: 50 additions & 1 deletion

@@ -671,7 +671,16 @@ def _call_exporter(
     do_run,
     output_names,
     exporter_options,
+    save_ep,
 ):
+    if save_ep and dump_folder:
+        for name in data:
+            if name.startswith("inputs"):
+                if verbose:
+                    print(f"[validate_model] -- dump {name!r}")
+                filename = os.path.join(dump_folder, f"{save_ep}.{name}.pt")
+                torch.save(data[name], filename)
+
     if exporter:
         expop = exporter_options or {}
         if verbose:
@@ -711,6 +720,7 @@ def _call_exporter(
             dump_folder=dump_folder,
             output_names=output_names,
             exporter_options=expop,
+            save_ep=save_ep,
         )
     else:
         data["inputs_export"] = data["inputs"]
@@ -831,6 +841,7 @@ def validate_model(
     output_names: Optional[List[str]] = None,
     ort_logs: bool = False,
     quiet_input_sets: Optional[Set[str]] = None,
+    save_ep: Optional[str] = None,
 ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
     """
     Validates a model.
@@ -889,6 +900,8 @@ def validate_model(
     :param ort_logs: increases onnxruntime verbosity when creating the session
     :param quiet_input_sets: avoid raising an exception if the inputs belongs to that set
         even if quiet is False
+    :param save_ep: if not empty, this can be used to save the input sets and
+        the exported program
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
 
@@ -952,6 +965,7 @@ def validate_model(
         subfolder=subfolder,
         use_pretrained=use_pretrained,
         same_as_pretrained=same_as_pretrained,
+        save_ep=save_ep,
     )
     if dump_folder:
         with open(dump_stats, "w") as f:
@@ -1038,6 +1052,7 @@ def _validate_model_step1(
     subfolder,
     use_pretrained,
     same_as_pretrained,
+    save_ep,
 ):
     assert not do_same or do_run, (
         f"Discrepancies cannot be measured if the model is not run, "
@@ -1153,6 +1168,7 @@ def _validate_model_step1(
         do_run=do_run,
         output_names=output_names,
         exporter_options=exporter_options,
+        save_ep=save_ep,
     )
 
     cont, dump_stats = _dump_onnx_model(
@@ -1426,6 +1442,7 @@ def call_exporter(
     dump_folder: Optional[str] = None,
     output_names: Optional[List[str]] = None,
     exporter_options: Optional[Dict[str, Any]] = None,
+    save_ep: Optional[str] = None,
 ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
     """
     Calls an exporter on a model;
@@ -1440,6 +1457,7 @@ def call_exporter(
     :param dump_folder: to dump additional information
     :param output_names: list of output names to use with the onnx exporter
     :param exporter_options: exporter options
+    :param save_ep: saves the exported program
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
     """
@@ -1456,6 +1474,8 @@ def call_exporter(
             optimization=optimization,
             do_run=do_run,
             exporter_options=exporter_options,
+            save_ep=save_ep,
+            dump_folder=dump_folder,
         )
         _restore_torch_export_export(summary)
         return summary, data
@@ -1469,6 +1489,8 @@ def call_exporter(
             optimization=optimization,
             output_names=output_names,
             exporter_options=exporter_options,
+            dump_folder=dump_folder,
+            save_ep=save_ep,
         )
         _restore_torch_export_export(summary)
         return summary, data
@@ -1483,6 +1505,7 @@ def call_exporter(
             dump_folder=dump_folder,
             output_names=output_names,
             exporter_options=exporter_options,
+            save_ep=save_ep,
         )
         _restore_torch_export_export(summary)
         return summary, data
@@ -1516,6 +1539,8 @@ def call_torch_export_export(
     optimization: Optional[str] = None,
     do_run: bool = False,
     exporter_options: Optional[Dict[str, Any]] = None,
+    dump_folder: Optional[str] = None,
+    save_ep: Optional[str] = None,
 ):
     """
     Exports a model with :func:`torch.export.export`.
@@ -1529,6 +1554,8 @@ def call_torch_export_export(
     :param optimization: optimization to do
     :param do_run: runs and compute discrepancies
     :param exporter_options: additional options given to the exporter
+    :param dump_folder: folder where to dump the exported program
+    :param save_ep: to save the exported program
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
     """
@@ -1604,6 +1631,12 @@ def call_torch_export_export(
         print(ep)
         print("[call_torch_export_export] -- End of ExportedProgram")
 
+    if dump_folder and save_ep:
+        fname = f"{save_ep}.pt2"
+        if verbose:
+            print(f"[call_torch_export_export] -- save the exported program in {fname!r}")
+        torch.export.save(ep, os.path.join(dump_folder, fname))
+
     if do_run:
         # We check for discrepancies.
         if verbose:
@@ -1880,6 +1913,8 @@ def call_torch_export_onnx(
     optimization: Optional[str] = None,
     output_names: Optional[List[str]] = None,
     exporter_options: Optional[Dict[str, Any]] = None,
+    dump_folder: Optional[str] = None,
+    save_ep: Optional[str] = None,
 ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     """
     Exports a model into onnx.
@@ -1893,6 +1928,8 @@ def call_torch_export_onnx(
     :param optimization: optimization to do
     :param output_names: output names to use
     :param exporter_options: additional options to give the exporter
+    :param dump_folder: to know where to dump the exported program
+    :param save_ep: to save the exported program
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
     """
@@ -1986,6 +2023,12 @@ def call_torch_export_onnx(
         return summary, data
 
     assert epo is not None, "no onnx export was found"
+    if dump_folder and save_ep:
+        fname = f"{save_ep}.pt2"
+        if verbose:
+            print(f"[call_torch_export_export] -- save the exported program in {fname!r}")
+        torch.export.save(epo.exported_program, os.path.join(dump_folder, fname))
+
     if verbose:
         print("[call_torch_export_onnx] done (export)")
     data["onnx_program"] = epo
@@ -2219,6 +2262,7 @@ def call_torch_export_custom(
     dump_folder: Optional[str] = None,
     output_names: Optional[List[str]] = None,
     exporter_options: Optional[Dict[str, Any]] = None,
+    save_ep: Optional[str] = None,
 ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     """
     Exports a model into onnx.
@@ -2233,6 +2277,7 @@ def call_torch_export_custom(
     :param dump_folder: to store additional information
     :param output_names: list of output names to use
     :param exporter_options: additional exporter options
+    :param save_ep: to save the exported program
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
     """
@@ -2345,7 +2390,11 @@ def call_torch_export_custom(
     export_options = ExportOptions(
         strict=strict,
         decomposition_table=decomposition_table,
-        save_ep=(os.path.join(dump_folder, f"{exporter}.ep") if dump_folder else None),
+        save_ep=(
+            (os.path.join(dump_folder, f"{exporter}.ep"), 2**35 if save_ep else 2**18)
+            if dump_folder
+            else None
+        ),
         **exporter_options,
     )
     options = OptimizationOptions(patterns=optimization) if optimization else None
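
validate.py threads save_ep from validate_model down to the individual exporters: _call_exporter dumps each input set as f"{save_ep}.{name}.pt" with torch.save, and call_torch_export_export / call_torch_export_onnx save the exported program as f"{save_ep}.pt2" with torch.export.save, both inside dump_folder. A standalone sketch of the same dump pattern, using a toy model instead of a validated one (names are illustrative):

    import os
    import torch


    class Tiny(torch.nn.Module):
        def forward(self, x):
            return x * 2


    dump_folder, save_ep = "dump", "model"
    os.makedirs(dump_folder, exist_ok=True)

    inputs = {"x": torch.randn(2, 3)}
    ep = torch.export.export(Tiny(), (), kwargs=inputs)

    # Same naming convention as the diff above: input sets first, then the program.
    torch.save(inputs, os.path.join(dump_folder, f"{save_ep}.inputs.pt"))
    torch.export.save(ep, os.path.join(dump_folder, f"{save_ep}.pt2"))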

onnx_diagnostic/torch_onnx/sbs.py

Lines changed: 4 additions & 0 deletions

@@ -736,6 +736,8 @@ def _loop_cmp(
         list_node_output = list(node.output)
         node_output = [o for o in list_node_output if o]
         for o, r in zip(node_output, res):
+            if r is None or o is None:
+                continue
             tmp = _loop_cmp(
                 mapping_onnx_to_torch,
                 torch_results,
@@ -787,6 +789,8 @@ def _loop_cmp(
         list_node_output = list(node.output)
         node_output = [o for o in list_node_output if o]
         for o, r in zip(node_output, res):
+            if r is None or o is None:
+                continue
             tmp = _loop_cmp(
                 mapping_onnx_to_torch,
                 torch_results,
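
The sbs.py change guards the comparison loop: the runtime can return None for an output (and empty output names are already filtered out), so such pairs are now skipped instead of being handed to _loop_cmp. A reduced sketch of the guarded loop with placeholder values:

    node_output = ["out0", "", "out1"]   # ONNX node outputs; "" marks an unused slot
    res = [3.14, None]                   # runtime results; None means nothing was produced

    node_output = [o for o in node_output if o]  # drop empty names, as sbs.py already did
    for o, r in zip(node_output, res):
        if r is None or o is None:               # the guard added by this commit
            continue                             # nothing to compare for this pair
        print(f"compare {o!r} against {r!r}")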
