Skip to content

Commit 2ceb728

Browse files
committed
save ep
1 parent 8c7302d commit 2ceb728

File tree

2 files changed

+2
-1
lines changed

2 files changed

+2
-1
lines changed

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,7 @@ def validate_model(
504504
print(f"[validate_model] -- dumps exported program in {dump_folder!r}...")
505505
with open(os.path.join(dump_folder, f"{folder_name}.ep"), "w") as f:
506506
f.write(str(ep))
507+
torch.export.save(ep, os.path.join(dump_folder, f"{folder_name}.pt2"))
507508
with open(os.path.join(dump_folder, f"{folder_name}.graph"), "w") as f:
508509
f.write(str(ep.graph))
509510
if verbose:

onnx_diagnostic/torch_onnx/sbs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ def post_process(obs):
296296
)
297297

298298
for inp, v in zip(onx.graph.input, args):
299-
onnx_results[inp.name] = v.numpy()
299+
onnx_results[inp.name] = v.cpu().numpy()
300300
if verbose:
301301
print(
302302
f"[run_aligned] +onnx-input: {inp.name}: "

0 commit comments

Comments
 (0)