Skip to content

Commit 6a0f91f

Browse files
committed
fix
1 parent ed4eaa6 commit 6a0f91f

File tree

1 file changed

+24
-4
lines changed
  • onnx_diagnostic/torch_onnx

1 file changed

+24
-4
lines changed

onnx_diagnostic/torch_onnx/sbs.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,21 @@ def run_fx_node(
108108
return args
109109
if node.op == "call_function":
110110
assert callable(node.target), f"{node.target!r} not callable in node {node!r}"
111+
for a, ea in zip(args, node.args):
112+
if isinstance(a, torch.Tensor) and hasattr(ea, "meta") and "val" in ea.meta:
113+
ta = ea.meta["val"]
114+
# if not isinstance(ta, torch.Tensor):
115+
# print("******", args)
116+
# print("******", node.args)
117+
# print("******", node.kwargs)
118+
# print("******", node.meta)
119+
# print(ta)
120+
assert len(a.shape) == len(ta.shape) and a.dtype == ta.dtype, (
121+
f"Unable to run node {node!r}, target={node.target!r}, "
122+
f"node.args={node.args!r}, node.kwargs={node.kwargs!r}, "
123+
f"args={string_type(args, with_shape=True, with_device=True)}, "
124+
f"kwargs={string_type(kwargs, with_shape=True, with_device=True)}"
125+
)
111126
try:
112127
outputs = node.target(*args, **(kwargs or {}))
113128
except RuntimeError as e:
@@ -381,7 +396,7 @@ def forward(self, x):
381396
-v 1 --atol=0.1 --rtol=1
382397
"""
383398
assert callable(run_cls), f"run_cls={run_cls} not a callable"
384-
str_kws = dict(with_shape=True, with_device=True, with_min_max=True)
399+
str_kws = dict(with_shape=True, with_device=True)
385400
has_cuda = any(
386401
(isinstance(t, torch.Tensor) and t.is_cuda)
387402
for t in flatten_object([args, kwargs], drop_keys=True)
@@ -592,7 +607,12 @@ def _loop_cmp(
592607
print(f"[run_aligned] run ep.graph.nodes[{i}]: {node.op} -> {node.name!r}")
593608

594609
if node.op == "placeholder":
595-
if node.name in onnx_results:
610+
is_input = node.name in placeholders
611+
if node.name in onnx_results and (
612+
is_input
613+
or ep_state_dict[placeholders_to_state_dict[node.name]].shape
614+
== onnx_results[node.name]
615+
):
596616
torch_results[node.name] = (
597617
onnx_results[node.name]
598618
if use_tensor
@@ -602,7 +622,6 @@ def _loop_cmp(
602622
t = torch_results[node.name]
603623
print(f"[run_aligned-ep] =plh: {node.name}={string_type(t, **str_kws)}")
604624
# Otherwise, it is an input.
605-
is_input = node.name in placeholders
606625
yield (
607626
-1,
608627
-1,
@@ -615,7 +634,8 @@ def _loop_cmp(
615634
{}
616635
if is_input
617636
else max_diff(
618-
placeholders_to_state_dict[node.name], onnx_results[node.name]
637+
ep_state_dict[placeholders_to_state_dict[node.name]],
638+
onnx_results[node.name],
619639
)
620640
),
621641
)

0 commit comments

Comments
 (0)