Skip to content

Commit 49ea518

Browse files
committed
a few changes
1 parent 96dfa22 commit 49ea518

File tree

2 files changed

+167
-46
lines changed

2 files changed

+167
-46
lines changed

onnx_diagnostic/_command_lines_parser.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1202,13 +1202,26 @@ def get_parser_sbs() -> ArgumentParser:
12021202
required=False,
12031203
help="Saves the result in an excel file every <ratio> nodes.",
12041204
)
1205+
parser.add_argument(
1206+
"--first",
1207+
action=BooleanOptionalAction,
1208+
default=False,
1209+
help="First runs the whole model.",
1210+
)
1211+
parser.add_argument(
1212+
"--gemmlinear",
1213+
action=BooleanOptionalAction,
1214+
default=False,
1215+
help="Replaces Gemm(A,X.T,B) by torch...linear(A,X,B) on onnx side",
1216+
)
1217+
12051218
return parser
12061219

12071220

12081221
def _cmd_sbs(argv: List[Any]):
12091222
import pandas
12101223
import torch
1211-
from .helpers import string_type
1224+
from .helpers import flatten_object, max_diff, string_diff, string_type
12121225
from .torch_onnx.sbs import run_aligned
12131226
from .reference import OnnxruntimeEvaluator
12141227

@@ -1257,6 +1270,23 @@ def _size(name):
12571270
ep = torch.export.load(args.ep)
12581271
print(f"-- done in {time.perf_counter() - begin:1.1f}s")
12591272

1273+
if args.first:
1274+
print("-- compare first, run ep")
1275+
print(f"-- args: {string_type(margs, with_shape=True, with_device=True)}")
1276+
print(f"-- mkwargs: {string_type(mkwargs, with_shape=True, with_device=True)}")
1277+
expected = ep.module()(*margs, **mkwargs)
1278+
print(f"-- expected: {string_type(expected, with_shape=True, with_device=True)}")
1279+
sess = OnnxruntimeEvaluator(args.onnx, whole=True)
1280+
onx_inputs = flatten_object([margs, mkwargs], drop_keys=True)
1281+
feeds = dict(zip(sess.input_names, onx_inputs))
1282+
print(f"-- feeds: {string_type(feeds, with_shape=True, with_device=True)}")
1283+
got = sess.run(None, feeds)
1284+
print(f"-- got: {string_type(got, with_shape=True, with_device=True)}")
1285+
diff = max_diff(expected, got, hist=[0.1])
1286+
print(f"-- diff: {string_diff(diff)}")
1287+
print("-- done")
1288+
del sess
1289+
12601290
print(f"-- load onnx {args.onnx!r}")
12611291
begin = time.perf_counter()
12621292
onx = onnx.load(args.onnx)
@@ -1275,6 +1305,7 @@ def _size(name):
12751305
args=margs,
12761306
kwargs=mkwargs,
12771307
use_tensor=True,
1308+
gemmlinear=args.gemmlinear,
12781309
exc=False,
12791310
):
12801311
data.append(obs)

0 commit comments

Comments
 (0)