Skip to content

Commit ed4eaa6

Browse files
committed
fix
1 parent 9b63a26 commit ed4eaa6

File tree

3 files changed

+11
-10
lines changed

3 files changed

+11
-10
lines changed

onnx_diagnostic/_command_lines_parser.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1233,9 +1233,11 @@ def _size(name):
12331233
use_tensor=True,
12341234
exc=False,
12351235
):
1236-
data.append(post_process_run_aligned_obs(obs))
1237-
df = pandas.DataFrame(data)
1238-
df.to_excel(args.output)
1236+
pobs = post_process_run_aligned_obs(obs)
1237+
data.append(pobs)
1238+
if "initializer" not in pobs and "placeholder" not in pobs:
1239+
df = pandas.DataFrame(data)
1240+
df.to_excel(args.output)
12391241
print("-- done")
12401242

12411243

onnx_diagnostic/helpers/ort_session.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -502,11 +502,8 @@ def run_dlpack(
502502
if not v.is_contiguous():
503503
v = v.contiguous()
504504
if v.dtype == torch.bool:
505-
# It does not work with dlpack
506-
# unless onnxruntime updates the version it is using.
507-
v = ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
508-
v.detach().numpy(), onnx.TensorProto.BOOL
509-
)
505+
v = v.to(torch.uint8)
506+
v = ORTC.OrtValue.from_dlpack(v.__dlpack__(), True)
510507
else:
511508
v = ORTC.OrtValue.from_dlpack(v.__dlpack__(), False)
512509
input_names.append(k)

onnx_diagnostic/torch_onnx/sbs.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -563,8 +563,10 @@ def _loop_cmp(
563563
print(f"[run_aligned-nx] +inp: {inp.name}: {string_type(v, **str_kws)}")
564564

565565
placeholders = {node.name for node in ep.graph.nodes if node.op == "placeholder"}
566+
ep_state_dict = {**ep.state_dict, **dict(ep.named_buffers())}
566567
placeholders_to_state_dict = {
567-
f"p_{name.replace('.', '_')}": name for name in ep.state_dict
568+
**{f"p_{name.replace('.', '_')}": name for name in ep.state_dict},
569+
**{f"b_{name.replace('.', '_')}": name for name, _ in ep.named_buffers()},
568570
}
569571
for n in onnx_results:
570572
if n not in placeholders:
@@ -622,7 +624,7 @@ def _loop_cmp(
622624
f"Unable to find placeholder {node.name!r} in "
623625
f"{sorted(placeholders_to_state_dict)}"
624626
)
625-
torch_results[node.name] = ep.state_dict[placeholders_to_state_dict[node.name]]
627+
torch_results[node.name] = ep_state_dict[placeholders_to_state_dict[node.name]]
626628
if verbose:
627629
print(f"[run_aligned-ep] +plh: {node.name}={string_type(t, **str_kws)}")
628630
yield (

0 commit comments

Comments
 (0)