diff --git a/_unittests/ut_torch_models/test_test_helpers.py b/_unittests/ut_torch_models/test_test_helpers.py index 90dc1b52..dc8b5ac3 100644 --- a/_unittests/ut_torch_models/test_test_helpers.py +++ b/_unittests/ut_torch_models/test_test_helpers.py @@ -8,6 +8,7 @@ ignore_warnings, requires_torch, requires_experimental, + requires_onnxscript, ) from onnx_diagnostic.torch_models.test_helper import ( get_inputs_for_task, @@ -88,6 +89,7 @@ def test_validate_model_onnx_dynamo_ir(self): ) @requires_torch("2.7") + @requires_onnxscript("0.4") @hide_stdout() @ignore_warnings(FutureWarning) def test_validate_model_onnx_dynamo_os_ort(self): diff --git a/onnx_diagnostic/torch_models/test_helper.py b/onnx_diagnostic/torch_models/test_helper.py index a2b443c5..5d1d472d 100644 --- a/onnx_diagnostic/torch_models/test_helper.py +++ b/onnx_diagnostic/torch_models/test_helper.py @@ -504,6 +504,7 @@ def validate_model( print(f"[validate_model] -- dumps exported program in {dump_folder!r}...") with open(os.path.join(dump_folder, f"{folder_name}.ep"), "w") as f: f.write(str(ep)) + torch.export.save(ep, os.path.join(dump_folder, f"{folder_name}.pt2")) with open(os.path.join(dump_folder, f"{folder_name}.graph"), "w") as f: f.write(str(ep.graph)) if verbose: diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py index aafc431e..a43fd05d 100644 --- a/onnx_diagnostic/torch_onnx/sbs.py +++ b/onnx_diagnostic/torch_onnx/sbs.py @@ -296,7 +296,7 @@ def post_process(obs): ) for inp, v in zip(onx.graph.input, args): - onnx_results[inp.name] = v.numpy() + onnx_results[inp.name] = v.cpu().numpy() if verbose: print( f"[run_aligned] +onnx-input: {inp.name}: "