Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ Change Logs
0.8.3
+++++

* :pr:`304`, :pr:`306`: improves side-by-side comparison
* :pr:`308`: add option --save_ep to dump the exported program as well as the torch inputs
* :pr:`304`, :pr:`306`: improves side-by-side comparison, creates command line sbs

0.8.2
+++++
Expand Down
4 changes: 2 additions & 2 deletions _unittests/ut_export/test_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
import torch
from onnxscript import script, FLOAT, INT64
from onnxscript import opset18 as op
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch, never_test
from onnx_diagnostic.export.control_flow import enable_code_export_control_flow, loop_for
from onnx_diagnostic.export.control_flow_research import simple_loop_for as loop_for_r
from onnx_diagnostic.export.api import to_onnx


class TestControlFlow(ExtTestCase):
@unittest.skip("not working")
@never_test()
def test_loop_one_research(self):
class Model(torch.nn.Module):
def forward(self, n_iter, x):
Expand Down
24 changes: 23 additions & 1 deletion onnx_diagnostic/_command_lines_parser.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import contextlib
import json
import os
import re
Expand Down Expand Up @@ -625,6 +626,18 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
),
action=_ParseDict,
)
parser.add_argument(
"--save-ep",
default="",
help=textwrap.dedent(
"""
saves the exported program with torch.export.save
and the inputs sets with torch.save,
then command line sbs can be used to look for discrepancies.
"""
),
)

return parser


Expand Down Expand Up @@ -691,6 +704,7 @@ def _cmd_validate(argv: List[Any]):
None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
),
exporter_options=args.expop,
save_ep=args.save_ep,
)
print("")
print("-- summary --")
Expand Down Expand Up @@ -1140,7 +1154,12 @@ def get_parser_sbs() -> ArgumentParser:
"--ep",
type=str,
required=True,
help="exported program saved with torch.export.save",
help=textwrap.dedent(
"""
exported program saved with torch.export.save,
input sets saved with torch.save,
"""
),
)
parser.add_argument(
"-m",
Expand Down Expand Up @@ -1226,6 +1245,9 @@ def _size(name):
f"Unable to infer args, kwargs from inputs {string_type(inputs, with_shape=True)}"
)

print("-- import transformers.modeling_outputs to register serialization functions")
with contextlib.suppress(ImportError):
import transformers.modeling_outputs # noqa: F401
print(f"-- load ep {args.ep!r}")
begin = time.perf_counter()
ep = torch.export.load(args.ep)
Expand Down
51 changes: 50 additions & 1 deletion onnx_diagnostic/torch_models/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,16 @@ def _call_exporter(
do_run,
output_names,
exporter_options,
save_ep,
):
if save_ep and dump_folder:
for name in data:
if name.startswith("inputs"):
if verbose:
print(f"[validate_model] -- dump {name!r}")
filename = os.path.join(dump_folder, f"{save_ep}.{name}.pt")
torch.save(data[name], filename)

if exporter:
expop = exporter_options or {}
if verbose:
Expand Down Expand Up @@ -711,6 +720,7 @@ def _call_exporter(
dump_folder=dump_folder,
output_names=output_names,
exporter_options=expop,
save_ep=save_ep,
)
else:
data["inputs_export"] = data["inputs"]
Expand Down Expand Up @@ -831,6 +841,7 @@ def validate_model(
output_names: Optional[List[str]] = None,
ort_logs: bool = False,
quiet_input_sets: Optional[Set[str]] = None,
save_ep: Optional[str] = None,
) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
"""
Validates a model.
Expand Down Expand Up @@ -889,6 +900,8 @@ def validate_model(
:param ort_logs: increases onnxruntime verbosity when creating the session
:param quiet_input_sets: avoid raising an exception if the inputs belongs to that set
even if quiet is False
:param save_ep: if not empty, this can be used to save the input sets and
the exported program
:return: two dictionaries, one with some metrics,
another one with whatever the function produces

Expand Down Expand Up @@ -952,6 +965,7 @@ def validate_model(
subfolder=subfolder,
use_pretrained=use_pretrained,
same_as_pretrained=same_as_pretrained,
save_ep=save_ep,
)
if dump_folder:
with open(dump_stats, "w") as f:
Expand Down Expand Up @@ -1038,6 +1052,7 @@ def _validate_model_step1(
subfolder,
use_pretrained,
same_as_pretrained,
save_ep,
):
assert not do_same or do_run, (
f"Discrepancies cannot be measured if the model is not run, "
Expand Down Expand Up @@ -1153,6 +1168,7 @@ def _validate_model_step1(
do_run=do_run,
output_names=output_names,
exporter_options=exporter_options,
save_ep=save_ep,
)

cont, dump_stats = _dump_onnx_model(
Expand Down Expand Up @@ -1426,6 +1442,7 @@ def call_exporter(
dump_folder: Optional[str] = None,
output_names: Optional[List[str]] = None,
exporter_options: Optional[Dict[str, Any]] = None,
save_ep: Optional[str] = None,
) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
"""
Calls an exporter on a model;
Expand All @@ -1440,6 +1457,7 @@ def call_exporter(
:param dump_folder: to dump additional information
:param output_names: list of output names to use with the onnx exporter
:param exporter_options: exporter options
:param save_ep: saves the exported program
:return: two dictionaries, one with some metrics,
another one with whatever the function produces
"""
Expand All @@ -1456,6 +1474,8 @@ def call_exporter(
optimization=optimization,
do_run=do_run,
exporter_options=exporter_options,
save_ep=save_ep,
dump_folder=dump_folder,
)
_restore_torch_export_export(summary)
return summary, data
Expand All @@ -1469,6 +1489,8 @@ def call_exporter(
optimization=optimization,
output_names=output_names,
exporter_options=exporter_options,
dump_folder=dump_folder,
save_ep=save_ep,
)
_restore_torch_export_export(summary)
return summary, data
Expand All @@ -1483,6 +1505,7 @@ def call_exporter(
dump_folder=dump_folder,
output_names=output_names,
exporter_options=exporter_options,
save_ep=save_ep,
)
_restore_torch_export_export(summary)
return summary, data
Expand Down Expand Up @@ -1516,6 +1539,8 @@ def call_torch_export_export(
optimization: Optional[str] = None,
do_run: bool = False,
exporter_options: Optional[Dict[str, Any]] = None,
dump_folder: Optional[str] = None,
save_ep: Optional[str] = None,
):
"""
Exports a model with :func:`torch.export.export`.
Expand All @@ -1529,6 +1554,8 @@ def call_torch_export_export(
:param optimization: optimization to do
:param do_run: runs and compute discrepancies
:param exporter_options: additional options given to the exporter
:param dump_folder: folder where to dump the exported program
:param save_ep: to save the exported program
:return: two dictionaries, one with some metrics,
another one with whatever the function produces
"""
Expand Down Expand Up @@ -1604,6 +1631,12 @@ def call_torch_export_export(
print(ep)
print("[call_torch_export_export] -- End of ExportedProgram")

if dump_folder and save_ep:
fname = f"{save_ep}.pt2"
if verbose:
print(f"[call_torch_export_export] -- save the exported program in {fname!r}")
torch.export.save(ep, os.path.join(dump_folder, fname))

if do_run:
# We check for discrepancies.
if verbose:
Expand Down Expand Up @@ -1880,6 +1913,8 @@ def call_torch_export_onnx(
optimization: Optional[str] = None,
output_names: Optional[List[str]] = None,
exporter_options: Optional[Dict[str, Any]] = None,
dump_folder: Optional[str] = None,
save_ep: Optional[str] = None,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""
Exports a model into onnx.
Expand All @@ -1893,6 +1928,8 @@ def call_torch_export_onnx(
:param optimization: optimization to do
:param output_names: output names to use
:param exporter_options: additional options to give the exporter
:param dump_folder: to know where to dump the exported program
:param save_ep: to save the exported program
:return: two dictionaries, one with some metrics,
another one with whatever the function produces
"""
Expand Down Expand Up @@ -1986,6 +2023,12 @@ def call_torch_export_onnx(
return summary, data

assert epo is not None, "no onnx export was found"
if dump_folder and save_ep:
fname = f"{save_ep}.pt2"
if verbose:
print(f"[call_torch_export_export] -- save the exported program in {fname!r}")
torch.export.save(epo.exported_program, os.path.join(dump_folder, fname))

if verbose:
print("[call_torch_export_onnx] done (export)")
data["onnx_program"] = epo
Expand Down Expand Up @@ -2219,6 +2262,7 @@ def call_torch_export_custom(
dump_folder: Optional[str] = None,
output_names: Optional[List[str]] = None,
exporter_options: Optional[Dict[str, Any]] = None,
save_ep: Optional[str] = None,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""
Exports a model into onnx.
Expand All @@ -2233,6 +2277,7 @@ def call_torch_export_custom(
:param dump_folder: to store additional information
:param output_names: list of output names to use
:param exporter_options: additional exporter options
:param save_ep: to save the exported program
:return: two dictionaries, one with some metrics,
another one with whatever the function produces
"""
Expand Down Expand Up @@ -2345,7 +2390,11 @@ def call_torch_export_custom(
export_options = ExportOptions(
strict=strict,
decomposition_table=decomposition_table,
save_ep=(os.path.join(dump_folder, f"{exporter}.ep") if dump_folder else None),
save_ep=(
(os.path.join(dump_folder, f"{exporter}.ep"), 2**35 if save_ep else 2**18)
if dump_folder
else None
),
**exporter_options,
)
options = OptimizationOptions(patterns=optimization) if optimization else None
Expand Down
4 changes: 4 additions & 0 deletions onnx_diagnostic/torch_onnx/sbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,8 @@ def _loop_cmp(
list_node_output = list(node.output)
node_output = [o for o in list_node_output if o]
for o, r in zip(node_output, res):
if r is None or o is None:
continue
tmp = _loop_cmp(
mapping_onnx_to_torch,
torch_results,
Expand Down Expand Up @@ -787,6 +789,8 @@ def _loop_cmp(
list_node_output = list(node.output)
node_output = [o for o in list_node_output if o]
for o, r in zip(node_output, res):
if r is None or o is None:
continue
tmp = _loop_cmp(
mapping_onnx_to_torch,
torch_results,
Expand Down
Loading