Skip to content

Commit 1da11e2

Browse files
committed
fix
1 parent 8c93c25 commit 1da11e2

File tree

2 files changed

+33
-1
lines changed

2 files changed

+33
-1
lines changed

onnx_diagnostic/torch_onnx/sbs.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,23 @@ def _duplicated_values(d):
113113
return final
114114

115115

116+
def _validation_nn_functional(
    node: onnx.NodeProto, new_feeds: Dict[str, torch.Tensor], expected: List[torch.Tensor]
) -> Optional[str]:
    """
    Re-evaluates a ``Gemm`` node with :func:`torch.nn.functional.linear`
    and reports the discrepancies with the expected result.

    The check only applies to a ``Gemm`` node with three inputs whose
    effective attributes are ``alpha=1, beta=1, transA=0, transB=1``,
    that is ``Y = A @ B.T + C``. Attributes missing from the node take
    the default value defined by the ONNX specification, so a node
    spelling out these defaults is handled like a node omitting them.

    :param node: onnx node to validate
    :param new_feeds: tensors indexed by the input names of the node
    :param expected: expected outputs, only the first one is compared
    :return: ``"function.linear:<diff>"`` where ``<diff>`` is the output of
        ``string_diff`` applied to ``max_diff(result, expected[0])``,
        or None if the node is not an eligible ``Gemm``
    """
    if node.op_type != "Gemm" or len(node.input) != 3:
        return None

    # Default attribute values of Gemm in the ONNX specification.
    atts = {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0}
    for att in node.attribute:
        if att.name in ("alpha", "beta"):
            atts[att.name] = att.f
        elif att.name in ("transA", "transB"):
            atts[att.name] = att.i

    # linear(x, w, b) = x @ w.T + b, the only configuration validated here.
    if atts != {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 1}:
        return None

    res = torch.nn.functional.linear(*[new_feeds[i] for i in node.input])
    diff = max_diff(res, expected[0])
    return f"function.linear:{string_diff(diff)}"
131+
132+
116133
def _loop_onnx_node(
117134
onx: onnx.ModelProto,
118135
ep_graph_nodes: List[torch.fx.Node],
@@ -201,6 +218,7 @@ def _loop_onnx_node(
201218
f"node is {pretty_onnx(node)}"
202219
)
203220

221+
comment = None
204222
cross = None
205223
if run_onnx_with_torch_inputs:
206224
# Let's run the operator with torch results if they are available
@@ -223,8 +241,13 @@ def _loop_onnx_node(
223241
cross = ref.run(None, new_feeds)
224242
if verbose > 1:
225243
print(f"[run_aligned] got for second run={string_type(cross, **str_kws)}")
244+
# Gemm = torch.nn.functional.linear, in that case, we just run it as well
245+
to = mapping_onnx_to_torch.get(node.output[0], node.output[0])
246+
if to in torch_results:
247+
comment = _validation_nn_functional(node, new_feeds, [torch_results[to]])
226248
elif verbose > 1:
227249
print(f"[run_aligned] second run not possible because of missing {removed}")
250+
228251
if cross is None:
229252
cross = [None for _ in res]
230253

@@ -264,6 +287,7 @@ def _loop_onnx_node(
264287
status.yielded_nodes += 1
265288
if tmp.err_abs is not None:
266289
status.update(tmp.err_abs)
290+
tmp.comment = comment
267291
yield tmp
268292

269293
# do we need to dump pieces of the graph the user can replay?
@@ -1055,6 +1079,7 @@ def forward(self, x):
10551079
if r:
10561080
yield r.check(already_yielded)
10571081

1058-
loop.close()
1082+
if loop is not None:
1083+
loop.close()
10591084
if verbose:
10601085
print(f"[run_aligned] done with status={status.to_str()}")

onnx_diagnostic/torch_onnx/sbs_dataclasses.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,12 @@ class RunAlignedRecord:
332332
:param ep_time_run: execution time for the exported program
333333
:param onnx_time_run: execution time for the onnx model, that includes
334334
the creation of the onnx model so that's probably not very usable
335+
:param err_abs2: same as `err_abs` if onnx kernel is run with torch results
336+
:param err_rel2: same as `err_rel` if onnx kernel is run with torch results
337+
:param err_dev2: same as `err_dev` if onnx kernel is run with torch results
338+
:param err_nan2: same as `err_nan` if onnx kernel is run with torch results
339+
:param err_h012: same as `err_h01` if onnx kernel is run with torch results
340+
:param comment: any additional information
335341
"""
336342

337343
ep_id_node: Optional[int] = None
@@ -355,6 +361,7 @@ class RunAlignedRecord:
355361
err_dev2: Optional[float] = None
356362
err_nan2: Optional[float] = None
357363
err_h012: Optional[float] = None
364+
comment: Optional[str] = None
358365

359366
def __post_init__(self):
360367
"Validation."

0 commit comments

Comments
 (0)