Skip to content

Commit 940bae5

Browse files
committed
onx fix
1 parent f0c543c commit 940bae5

File tree

1 file changed

+6
-8
lines changed
  • onnx_diagnostic/torch_onnx

1 file changed

+6
-8
lines changed

onnx_diagnostic/torch_onnx/sbs.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -111,13 +111,11 @@ def run_fx_node(
111111
for a, ea in zip(args, node.args):
112112
if isinstance(a, torch.Tensor) and hasattr(ea, "meta") and "val" in ea.meta:
113113
ta = ea.meta["val"]
114-
# if not isinstance(ta, torch.Tensor):
115-
# print("******", args)
116-
# print("******", node.args)
117-
# print("******", node.kwargs)
118-
# print("******", node.meta)
119-
# print(ta)
120-
assert len(a.shape) == len(ta.shape) and a.dtype == ta.dtype, (
114+
assert (
115+
isinstance(ta, torch.Tensor)
116+
and len(a.shape) == len(ta.shape)
117+
and a.dtype == ta.dtype
118+
), (
121119
f"Unable to run node {node!r}, target={node.target!r}, "
122120
f"node.args={node.args!r}, node.kwargs={node.kwargs!r}, "
123121
f"args={string_type(args, with_shape=True, with_device=True)}, "
@@ -672,7 +670,7 @@ def _loop_cmp(
672670
outputs = [node.name] if isinstance(node.name, str) else list(node.name)
673671
args, kwargs = prepare_args_kwargs(torch_results, node)
674672
new_outputs = run_fx_node(node, args, kwargs)
675-
if isinstance(new_outputs, (torch.Tensor, int, float, list)):
673+
if isinstance(new_outputs, (torch.Tensor, int, float, list, tuple)):
676674
new_outputs = (new_outputs,)
677675

678676
if new_outputs is None:

0 commit comments

Comments
 (0)