Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Change Logs
* :pr:`311`: use custom and local function to use PackedMultiHeadAttention from onnxruntime
* :pr:`310`: splits patches into multiple files
* :pr:`308`: add option --save_ep to dump the exported program as well as torch input
* :pr:`304`, :pr:`306`, :pr:`316`: improves side-by-side comparison, creates command line sbs
* :pr:`304`, :pr:`306`, :pr:`316`, :pr:`317`: improves side-by-side comparison, creates command line sbs

0.8.2
+++++
Expand Down
48 changes: 48 additions & 0 deletions _unittests/ut_torch_onnx/test_sbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,54 @@ def forward(self, x):
)
self.assertEqual(len(results), 5)

@hide_stdout()
@ignore_warnings((DeprecationWarning, FutureWarning, UserWarning))
def test_sbs_model_with_weights_custom_reset(self):
torch = self.torch

class Model(self.torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
self.fc1 = torch.nn.Linear(10, 3200) # input size 10 → hidden size 32
self.relu = torch.nn.ReLU()
self.fc2 = torch.nn.Linear(3200, 1) # hidden → output
with torch.no_grad():
self.fc2.bias += 1999
self.fc1.bias += 999

def forward(self, x):
x = self.relu(self.fc1(x))
x = self.fc2(x)
return x

inputs = dict(x=self.torch.randn((5, 10), dtype=torch.float16))
ds = dict(x={0: "batch"})
model = Model()
model = model.to(torch.float16)
model(**inputs)
ep = self.torch.export.export(
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
)
filename = self.get_dump_file("test_sbs_model_with_weights_custom_reset.onnx")
to_onnx(ep, exporter="custom", filename=filename)
onx = onnx.load(filename)
results = list(
run_aligned(
ep,
onx,
kwargs=inputs,
run_cls=OnnxruntimeEvaluator,
verbose=11,
use_tensor=True,
reset_names=["linear"],
),
)
df = pandas.DataFrame(list(results))
df.to_excel(self.get_dump_file("test_sbs_model_with_weights_custom_reset.xlsx"))
onnx_op_type = df["onnx_op_type"].tolist()
self.assertEqual(onnx_op_type.count("reset"), 1)
self.clean_dump()


if __name__ == "__main__":
unittest.main(verbosity=2)
13 changes: 13 additions & 0 deletions onnx_diagnostic/_command_lines_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1209,6 +1209,18 @@ def get_parser_sbs() -> ArgumentParser:
default=False,
help="First runs the whole model.",
)
parser.add_argument(
"--reset",
required=False,
default="",
help=textwrap.dedent(
"""
List of result names separated by a comma. For those results,
the side-by-side will take torch results instead of onnx results
to compute the rest of the onnx model.
"""
),
)
parser.add_argument(
"--gemmlinear",
action=BooleanOptionalAction,
Expand Down Expand Up @@ -1308,6 +1320,7 @@ def _size(name):
kwargs=mkwargs,
use_tensor=True,
gemmlinear=args.gemmlinear,
reset_names=args.reset.split(","),
exc=False,
):
data.append(obs)
Expand Down
74 changes: 71 additions & 3 deletions onnx_diagnostic/torch_onnx/sbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,37 @@ def prepare_args_kwargs(

@dataclass
class RunAlignedRecord:
"""
The side-by-side ran by function :func:`run_aligned
<onnx_diagnostic.torch_onnx.sbs.run_aligned>`
yields instances of this type. If both `ep_name`
and `onnx_name` are specified, then both results
appear in the exported program (torch) and the onnx model.

:param ep_id_node: node index in the exported program
:param onnx_id_node: node index in the onnx model, -1 for an initializer
:param ep_name: result name in the exported program
:param onnx_name: result name in the onnx model, usually same as `ep_name`
except for initializer
:param ep_target: target name in the exported program producing the result
:param onnx_op_type: operator type in the onnx model producing the result
:param onnx_id_output: usually 0 unless this node has multiple output,
in that case, it is the output index
:param ep_shape_type: shape and type of the results in the exported program
:param onnx_shape_type: shape and type of the results in the onnx mode,
it should be the same as `ep_shape_type`, anything different probably
means a bug
:param err_abs: maximum absolute error for the considered result
between the exported program and the onnx model
:param err_rel: maximum relative error
:param err_dev: 0 if the device is the same, 1 if not
:param err_nan: number of nan values disagreeing
:param err_h01: number of values for which the discrepancy is above 0.1
:param ep_time_run: execution time for the exported program
:param onnx_time_run: execution time for the onnx model, that includes
the creation of the onnx model so that's probably not very usable
"""

ep_id_node: Optional[int] = None
onnx_id_node: Optional[int] = None
ep_name: Optional[str] = None
Expand Down Expand Up @@ -208,7 +239,14 @@ def set_diff(self, diff: Dict[str, Any]):

@dataclass
class StatusRunAligned:
"Information to display while running the side-by-side"
"""
Information to display while running the side-by-side

:param max_abs: maximum absolute seen so far
:param n_inf: number of infinite values seen so far
:param n_nan: number of nan values seen so for
:param yielded_nodes: number of yielded pair of nodes seen so far
"""

max_abs: float = 0.0
n_inf: int = 0
Expand All @@ -223,6 +261,7 @@ def to_str(self) -> str:
)

def update(self, err_abs: float):
"Updates all attributes with the latest measure."
if np.isinf(err_abs) or np.isnan(err_abs):
self.n_inf += 1
elif err_abs > 1e6:
Expand Down Expand Up @@ -253,6 +292,7 @@ def run_aligned(
gemmlinear: bool = False,
verbose: int = 0,
exc: bool = True,
reset_names: Optional[List[str]] = None,
) -> Iterator[RunAlignedRecord]:
"""
Runs in parallel both the exported program
Expand All @@ -274,6 +314,8 @@ def run_aligned(
``torch.nn.functional.linear(A,X,B)`` on onnx side
:param verbose: verbosity level
:param exc: stops if an exception
:param reset_names: list of names, the onnx execution takes the torch outputs instead
of its own result if the names falls into that set
:return: a list of :class:`RunAlignedRecord`

Example:
Expand Down Expand Up @@ -408,6 +450,7 @@ def forward(self, x):
-v 1 --atol=0.1 --rtol=1
"""
assert callable(run_cls), f"run_cls={run_cls} not a callable"
reset_names = set(reset_names) if reset_names else set() # type: ignore[assignment]
str_kws = dict(with_shape=True, with_device=True)
has_cuda = any(
(isinstance(t, torch.Tensor) and t.is_cuda)
Expand Down Expand Up @@ -618,6 +661,31 @@ def _loop_onnx_node(
if tmp.err_abs is not None:
status.update(tmp.err_abs)
yield tmp
if reset_names and tmp.ep_name in reset_names:
assert (
tmp.ep_name in torch_results
), f"name {tmp.ep_name!r} set to be reset is missing in torch_results."
assert (
tmp.onnx_name in onnx_results
), f"name {tmp.onnx_name!r} set to be reset is missing in onnx_results."
onnx_results[tmp.onnx_name] = torch_results[tmp.ep_name]
tmp = _loop_cmp(
mapping_onnx_to_torch,
torch_results,
onnx_results,
o,
r,
verbose,
atol,
rtol,
i,
i_onnx,
)
if tmp is not None:
tmp.onnx_op_type = "reset"
tmp.onnx_id_output = list_node_output.index(o)
status.yielded_nodes += 1
yield tmp
already_run.add(i_onnx)

def _duplicated_values(d):
Expand Down Expand Up @@ -799,13 +867,13 @@ def _gemm_linear(node, feeds, sess):
t = torch_results[init.name]
torch_names_to_onnx_names[init.name] = init.name
elif init.name not in skip_onnx_name and init.name in rev_init_aliases:
new_names = [
new_names = [ # type: ignore[assignment]
k
for k in rev_init_aliases[init.name]
if k in torch_results and k not in skip_mapping_torch_onnx
]
if new_names and len(new_names) == 1:
new_name = new_names[0]
new_name = new_names[0] # type: ignore[assignment, index]
t = torch_results[new_name]
if (
t.shape == tuple(init.dims)
Expand Down
Loading