Skip to content

Commit 1d7735f

Browse files
authored
Adds replay functionality in side-by-side (#318)
* Adds replay functionality in side-by-side * gemm * add command lines
1 parent 4ca6c9d commit 1d7735f

File tree

7 files changed

+446
-50
lines changed

7 files changed

+446
-50
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ Change Logs
88
* :pr:`311`: use custom and local function to use PackedMultiHeadAttention from onnxruntime
99
* :pr:`310`: splits patches into multiple files
1010
* :pr:`308`: add option --save_ep to dump the exported program as well as torch input
11-
* :pr:`304`, :pr:`306`, :pr:`316`, :pr:`317`: improves side-by-side comparison, creates command line sbs
11+
* :pr:`304`, :pr:`306`, :pr:`316`, :pr:`317`, :pr:`318`: improves side-by-side comparison, creates command line sbs
1212

1313
0.8.2
1414
+++++

_unittests/ut_helpers/test_onnx_helper.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
enumerate_results,
2020
shadowing_names,
2121
onnx_dtype_name,
22+
extract_subset_of_nodes,
23+
make_submodel,
2224
)
2325

2426

@@ -476,6 +478,58 @@ def test_onnx_dtype_name(self):
476478
self.assertRaise(lambda: onnx_dtype_name(1000), ValueError)
477479
self.assertEqual(onnx_dtype_name(1000, exc=False), "UNEXPECTED")
478480

481+
def test_extract_subset_of_nodes(self):
482+
model = oh.make_model(
483+
oh.make_graph(
484+
[
485+
oh.make_node("Unsqueeze", ["X", "zero"], ["xu1"]),
486+
oh.make_node("Unsqueeze", ["xu1", "un"], ["xu2"]),
487+
oh.make_node("Reshape", ["xu2", "shape1"], ["xm1"]),
488+
oh.make_node("Reshape", ["Y", "shape2"], ["xm2c"]),
489+
oh.make_node("Cast", ["xm2c"], ["xm2"], to=1),
490+
oh.make_node("MatMul", ["xm1", "xm2"], ["xm"]),
491+
oh.make_node("Reshape", ["xm", "shape3"], ["Z"]),
492+
],
493+
"dummy",
494+
[oh.make_tensor_value_info("X", TFLOAT, [320, 1280])],
495+
[oh.make_tensor_value_info("Z", TFLOAT, [3, 5, 320, 640])],
496+
[
497+
onh.from_array(
498+
np.random.rand(3, 5, 1280, 640).astype(np.float32), name="Y"
499+
),
500+
onh.from_array(np.array([0], dtype=np.int64), name="zero"),
501+
onh.from_array(np.array([1], dtype=np.int64), name="un"),
502+
onh.from_array(np.array([1, 320, 1280], dtype=np.int64), name="shape1"),
503+
onh.from_array(np.array([15, 1280, 640], dtype=np.int64), name="shape2"),
504+
onh.from_array(np.array([3, 5, 320, 640], dtype=np.int64), name="shape3"),
505+
],
506+
),
507+
opset_imports=[oh.make_opsetid("", 18)],
508+
ir_version=9,
509+
)
510+
submodel = extract_subset_of_nodes(model, "xm", cut_points={"Y", "xu2", "xm1"})
511+
op_types = [n.op_type for n in submodel]
512+
self.assertEqual(["Reshape", "Cast", "MatMul"], op_types)
513+
514+
def _type_rank_fn(name):
515+
if name in {"Y", "xu2"}:
516+
return TensorProto.FLOAT, 4
517+
if name in {"xm1", "xm"}:
518+
return TensorProto.FLOAT, 3
519+
if name == "shape2":
520+
return TensorProto.INT64, 1
521+
raise AssertionError(f"unexpected name={name!r}")
522+
523+
new_model = make_submodel(
524+
submodel,
525+
ir_version=model.ir_version,
526+
opset_imports=model.opset_import,
527+
type_rank_fn=_type_rank_fn,
528+
output_names=["xm"],
529+
)
530+
check_model(new_model)
531+
self.check_ort(new_model)
532+
479533

480534
if __name__ == "__main__":
481535
unittest.main(verbosity=2)

_unittests/ut_torch_onnx/test_sbs.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
)
1111
from onnx_diagnostic.reference import ExtendedReferenceEvaluator, OnnxruntimeEvaluator
1212
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
13-
from onnx_diagnostic.torch_onnx.sbs import run_aligned, RunAlignedRecord
13+
from onnx_diagnostic.torch_onnx.sbs import run_aligned, RunAlignedRecord, ReplayConfiguration
1414
from onnx_diagnostic.export.api import to_onnx
1515

1616

@@ -23,7 +23,7 @@ def setUpClass(cls):
2323

2424
def test_run_aligned_record(self):
2525
r = RunAlignedRecord(
26-
ep_id_node=-1,
26+
ep_id_node=1,
2727
onnx_id_node=-1,
2828
ep_name="A",
2929
onnx_name="B",
@@ -512,6 +512,56 @@ def forward(self, x):
512512
self.assertEqual(onnx_op_type.count("reset"), 1)
513513
self.clean_dump()
514514

515+
@hide_stdout()
516+
@ignore_warnings((DeprecationWarning, FutureWarning, UserWarning))
517+
def test_sbs_replay(self):
518+
torch = self.torch
519+
520+
class Model(self.torch.nn.Module):
521+
def __init__(self):
522+
super(Model, self).__init__()
523+
self.fc1 = torch.nn.Linear(10, 3200) # input size 10 → hidden size 32
524+
self.relu = torch.nn.ReLU()
525+
self.fc2 = torch.nn.Linear(3200, 1) # hidden → output
526+
with torch.no_grad():
527+
self.fc2.bias += 1999
528+
self.fc1.bias += 999
529+
530+
def forward(self, x):
531+
x = self.relu(self.fc1(x))
532+
x = self.fc2(x)
533+
return x
534+
535+
inputs = dict(x=self.torch.randn((5, 10), dtype=torch.float16))
536+
ds = dict(x={0: "batch"})
537+
model = Model()
538+
model = model.to(torch.float16)
539+
model(**inputs)
540+
ep = self.torch.export.export(
541+
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
542+
)
543+
filename = self.get_dump_file("test_sbs_replay.onnx")
544+
dump_folder = self.get_dump_folder("test_sbs_replay_linear")
545+
to_onnx(ep, exporter="custom", filename=filename)
546+
onx = onnx.load(filename)
547+
results = list(
548+
run_aligned(
549+
ep,
550+
onx,
551+
kwargs=inputs,
552+
run_cls=OnnxruntimeEvaluator,
553+
verbose=11,
554+
use_tensor=True,
555+
replay_configuration=ReplayConfiguration(
556+
dump_folder=dump_folder, selected_op_types={"Gemm"}
557+
),
558+
),
559+
)
560+
df = pandas.DataFrame(list(results))
561+
df.to_excel(self.get_dump_file("test_sbs_replay.xlsx"))
562+
print(df)
563+
# self.clean_dump()
564+
515565

516566
if __name__ == "__main__":
517567
unittest.main(verbosity=2)

_unittests/ut_xrun_doc/test_command_lines_exe.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def forward(self, x):
112112
input_file = self.get_dump_file("test_h_parser_sbs.inputs.pt")
113113
ep_file = self.get_dump_file("test_h_parser_sbs.ep")
114114
onnx_file = self.get_dump_file("test_h_parser_sbs.model.onnx")
115+
replay_foler = self.get_dump_folder("test_h_parser_sbs.replay")
115116
torch.save(inputs, input_file)
116117
to_onnx(
117118
Model(),
@@ -139,6 +140,10 @@ def forward(self, x):
139140
output,
140141
"-m",
141142
onnx_file,
143+
"-t",
144+
"Gemm",
145+
"-f",
146+
replay_foler,
142147
]
143148
)
144149
text = st.getvalue()

onnx_diagnostic/_command_lines_parser.py

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,6 +1140,14 @@ def get_parser_sbs() -> ArgumentParser:
11401140
- torch.export.save(ep: torch.export.ExportedProgram)
11411141
- torch.save(**inputs)
11421142
- onnx.save(...)
1143+
1144+
The Replay functionality is just a way to investigates a part of a model.
1145+
It saves torch and onnx inputs, the torch outputs, and the minimal onnx model
1146+
which shares its inputs with the exported program.
1147+
This is used to investigate the discrepancies between the torch
1148+
model (through the exported program) and its onnx conversion.
1149+
This functionality dumps everything it can to disk
1150+
so that it be replayed in a separate process.
11431151
"""
11441152
),
11451153
)
@@ -1222,10 +1230,33 @@ def get_parser_sbs() -> ArgumentParser:
12221230
),
12231231
)
12241232
parser.add_argument(
1225-
"--gemmlinear",
1226-
action=BooleanOptionalAction,
1227-
default=False,
1228-
help="Replaces Gemm(A,X.T,B) by torch...linear(A,X,B) on onnx side",
1233+
"-s",
1234+
"--replay-threshold",
1235+
type=float,
1236+
required=False,
1237+
default=1e6,
1238+
help="Triggers the replay if the discrepancies are higher than this value.",
1239+
)
1240+
parser.add_argument(
1241+
"-n",
1242+
"--replay-names",
1243+
required=False,
1244+
default="",
1245+
help="Triggers the replay if a result name is in this set of values (comma separated)",
1246+
)
1247+
parser.add_argument(
1248+
"-t",
1249+
"--replay-op-types",
1250+
required=False,
1251+
default="",
1252+
help="Triggers the replay if an onnx type is in this set of values (comma separated)",
1253+
)
1254+
parser.add_argument(
1255+
"-f",
1256+
"--replay-folder",
1257+
required=False,
1258+
default="replay",
1259+
help="If the replay is triggered, this defines the folder where everything is dumped.",
12291260
)
12301261

12311262
return parser
@@ -1235,7 +1266,7 @@ def _cmd_sbs(argv: List[Any]):
12351266
import pandas
12361267
import torch
12371268
from .helpers import flatten_object, max_diff, string_diff, string_type
1238-
from .torch_onnx.sbs import run_aligned
1269+
from .torch_onnx.sbs import run_aligned, ReplayConfiguration
12391270
from .reference import OnnxruntimeEvaluator
12401271

12411272
parser = get_parser_sbs()
@@ -1306,6 +1337,17 @@ def _size(name):
13061337
onx = onnx.load(args.onnx)
13071338
print(f"-- done in {time.perf_counter() - begin:1.1f}s")
13081339

1340+
replay_configuration = None
1341+
if args.replay_threshold < 1e6 or args.replay_names or args.replay_op_types:
1342+
replay_configuration = ReplayConfiguration(
1343+
threshold=args.replay_threshold,
1344+
selected_names=set(args.replay_names.split(",")) if args.replay_names else None,
1345+
selected_op_types=(
1346+
set(args.replay_op_types.split(",")) if args.replay_op_types else None
1347+
),
1348+
dump_folder=args.replay_folder,
1349+
)
1350+
13091351
print("-- starts side-by-side")
13101352
ratio = int(args.ratio)
13111353
data = []
@@ -1319,9 +1361,9 @@ def _size(name):
13191361
args=margs,
13201362
kwargs=mkwargs,
13211363
use_tensor=True,
1322-
gemmlinear=args.gemmlinear,
13231364
reset_names=args.reset.split(","),
13241365
exc=False,
1366+
replay_configuration=replay_configuration,
13251367
):
13261368
data.append(obs)
13271369
if (

onnx_diagnostic/helpers/onnx_helper.py

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
import sys
55
import warnings
6-
from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
6+
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
77
import numpy as np
88
import numpy.typing as npt
99
import onnx
@@ -15,6 +15,7 @@
1515
GraphProto,
1616
ModelProto,
1717
NodeProto,
18+
OperatorSetIdProto,
1819
TensorProto,
1920
ValueInfoProto,
2021
load as onnx_load,
@@ -1195,3 +1196,104 @@ def shadowing_names(
11951196
existing |= not_empty
11961197
created |= not_empty
11971198
return shadow, post_shadow, created
1199+
1200+
1201+
def extract_subset_of_nodes(
1202+
model: ModelProto,
1203+
name: str,
1204+
node_index: Optional[int] = None,
1205+
cut_points: Optional[Set[str]] = None,
1206+
) -> List[NodeProto]:
1207+
"""
1208+
Extracts the minimal subgraphs which can produce the output ``name``
1209+
knowing ``cut_points``.
1210+
1211+
:param model: original model
1212+
:param name: result name
1213+
:param node_index: if the node index is known, otherwise searches for it
1214+
:param cut_points: the known results or input name otherwise
1215+
:return: minimal list of nodes
1216+
"""
1217+
if node_index is None:
1218+
for i, node in enumerate(model.graph.node):
1219+
if name in node.output:
1220+
node_index = i
1221+
break
1222+
assert (
1223+
node_index is not None
1224+
and node_index < len(model.graph.node)
1225+
and name in model.graph.node[node_index].output
1226+
), f"node_index is still empty or wrong for result {name!r}"
1227+
if cut_points is None:
1228+
cut_points = {n.name for n in model.graph.input} | {
1229+
n.name for n in model.graph.initializer
1230+
}
1231+
elif model.graph.initializer:
1232+
cut_points = cut_points | {n.name for n in model.graph.initializer}
1233+
1234+
node = model.graph.node[node_index]
1235+
selected = {node_index}
1236+
current_node_index = node_index
1237+
current_input_index = 0
1238+
intermediate = {name}
1239+
inputs = set(k for k in node.input if k)
1240+
while not (inputs <= cut_points) and current_node_index >= 0:
1241+
node = model.graph.node[current_node_index]
1242+
if current_input_index == 0:
1243+
needs = [o for o in node.output if o in intermediate and o not in cut_points]
1244+
if needs:
1245+
selected.add(current_node_index)
1246+
else:
1247+
current_node_index -= 1
1248+
continue
1249+
res = node.input[current_input_index]
1250+
if res not in cut_points:
1251+
intermediate.add(res)
1252+
current_input_index += 1
1253+
if current_input_index >= len(node.input):
1254+
current_node_index -= 1
1255+
current_input_index = 0
1256+
1257+
return [model.graph.node[i] for i in sorted(selected)]
1258+
1259+
1260+
def make_submodel(
1261+
nodes: List[NodeProto],
1262+
ir_version: int,
1263+
opset_imports: List[OperatorSetIdProto],
1264+
output_names: List[str],
1265+
type_rank_fn: Callable[[str], Tuple[int, int]],
1266+
) -> ModelProto:
1267+
"""
1268+
Creates a model with the given list of nodes.
1269+
It computes the minimum list of inputs needed for this model.
1270+
The function assumes the nodes are sorted.
1271+
It does not handle yet subgraphs.
1272+
1273+
:param nodes: list of nodes
1274+
:param ir_version: ir version
1275+
:param opset_imports: opset import
1276+
:param output_names: desired outputs
1277+
:param function: function returning the type and the rank of a result
1278+
:return: model proto
1279+
"""
1280+
1281+
def _mkv_(name, itype, irank):
1282+
return oh.make_tensor_value_info(name, itype, [f"{name}_d{i}" for i in range(irank)])
1283+
1284+
not_known: Set[str] = set()
1285+
for node in nodes[::-1]:
1286+
not_known -= set(node.output)
1287+
not_known |= set(node.input)
1288+
1289+
model = oh.make_model(
1290+
oh.make_graph(
1291+
nodes,
1292+
"submodel",
1293+
[_mkv_(n, *type_rank_fn(n)) for n in sorted(not_known)],
1294+
[_mkv_(n, *type_rank_fn(n)) for n in sorted(output_names)],
1295+
),
1296+
ir_version=ir_version,
1297+
opset_imports=opset_imports,
1298+
)
1299+
return model

0 commit comments

Comments
 (0)