Skip to content

Commit 6d071bd

Browse files
malfetpytorchmergebot
authored andcommitted
Remove numpy dependency from onnx (pytorch#159177)
One should not expect numpy to be there during onnx import Forward fix for : pytorch#157734 Added regression test to `test_without_numpy` function Test plan: Run `python -c "import sys;sys.path.insert(0, 'fake_numpy');import torch; import torch.onnx"` with/without this fix Pull Request resolved: pytorch#159177 Approved by: https://github.com/atalman, https://github.com/justinchuby, https://github.com/titaiwangms, https://github.com/cyyever, https://github.com/Skylion007, https://github.com/andrewboldi
1 parent d742a28 commit 6d071bd

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

.ci/pytorch/test.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -974,6 +974,8 @@ test_without_numpy() {
974974
if [[ "${TEST_CONFIG}" == *dynamo_wrapped* ]]; then
975975
python -c "import sys;sys.path.insert(0, 'fake_numpy');import torch;torch.compile(lambda x:print(x))('Hello World')"
976976
fi
977+
# Regression test for https://github.com/pytorch/pytorch/pull/157734 (torch.onnx should be importable without numpy)
978+
python -c "import sys;sys.path.insert(0, 'fake_numpy');import torch; import torch.onnx"
977979
popd
978980
}
979981

torch/onnx/_internal/exporter/_onnx_program.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
import warnings
1616
from typing import Any, Callable, TYPE_CHECKING
1717

18-
import numpy as np
19-
2018
import torch
2119
from torch.onnx._internal._lazy_import import onnx, onnxscript_apis, onnxscript_ir as ir
2220
from torch.onnx._internal.exporter import _dynamic_shapes, _ir_passes
@@ -121,6 +119,7 @@ def _create_value_mapping(graph: ir.Graph) -> dict[str, ir.Value]:
121119

122120
def _to_ort_value(input: torch.Tensor | int | float | str | bool) -> ort.OrtValue:
123121
"""Convert a PyTorch tensor to an ONNX Runtime OrtValue."""
122+
import numpy as np
124123
import onnxruntime as ort
125124

126125
from torch.onnx._internal.exporter import _core

0 commit comments

Comments
 (0)