Skip to content

Commit b19f2fd

Browse files
committed
fix issues
1 parent fd5fc82 commit b19f2fd

File tree

4 files changed

+33
-17
lines changed

4 files changed

+33
-17
lines changed

_unittests/ut_light_api/test_backend_export.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919
make_opsetid,
2020
make_tensor_value_info,
2121
)
22-
from onnx.reference.op_run import to_array_extended
22+
23+
try:
24+
from onnx.reference.op_run import to_array_extended
25+
except ImportError:
26+
from onnx.numpy_helper import to_array as to_array_extended
2327
from onnx.numpy_helper import from_array, to_array
2428
from onnx.backend.base import Device, DeviceType
2529
from onnx_array_api.reference import ExtendedReferenceEvaluator

onnx_array_api/profiling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ def add_rows(rows, d):
438438
if verbose and fLOG is not None:
439439
fLOG(
440440
"[pstats] %s=%r"
441-
% ((clean_text(k[0].replace("\\", "/")),) + k[1:], v)
441+
% (clean_text(k[0].replace("\\", "/"), *k[1:]), v)
442442
)
443443
if len(v) < 5:
444444
continue

onnx_array_api/reference/__init__.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,17 @@
22
import numpy as np
33
from onnx import TensorProto
44
from onnx.numpy_helper import from_array as onnx_from_array
5-
from onnx.reference.ops.op_cast import (
6-
bfloat16,
7-
float8e4m3fn,
8-
float8e4m3fnuz,
9-
float8e5m2,
10-
float8e5m2fnuz,
11-
)
5+
6+
try:
7+
from onnx.reference.ops.op_cast import (
8+
bfloat16,
9+
float8e4m3fn,
10+
float8e4m3fnuz,
11+
float8e5m2,
12+
float8e5m2fnuz,
13+
)
14+
except ImportError:
15+
bfloat16 = None
1216
from onnx.reference.op_run import to_array_extended
1317
from .evaluator import ExtendedReferenceEvaluator
1418
from .evaluator_yield import (
@@ -28,6 +32,8 @@ def from_array_extended(tensor: np.array, name: Optional[str] = None) -> TensorP
2832
:param name: name
2933
:return: TensorProto
3034
"""
35+
if bfloat16 is None:
36+
return onnx_from_array(tensor, name)
3137
dt = tensor.dtype
3238
if dt == float8e4m3fn and dt.descr[0][0] == "e4m3fn":
3339
to = TensorProto.FLOAT8E4M3FN

onnx_array_api/reference/ops/op_cast_like.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,23 @@
11
from onnx.helper import np_dtype_to_tensor_dtype
22
from onnx.onnx_pb import TensorProto
33
from onnx.reference.op_run import OpRun
4-
from onnx.reference.ops.op_cast import (
5-
bfloat16,
6-
cast_to,
7-
float8e4m3fn,
8-
float8e4m3fnuz,
9-
float8e5m2,
10-
float8e5m2fnuz,
11-
)
4+
from onnx.reference.ops.op_cast import cast_to
5+
6+
try:
7+
from onnx.reference.ops.op_cast import (
8+
bfloat16,
9+
float8e4m3fn,
10+
float8e4m3fnuz,
11+
float8e5m2,
12+
float8e5m2fnuz,
13+
)
14+
except ImportError:
15+
bfloat16 = None
1216

1317

1418
def _cast_like(x, y, saturate):
19+
if bfloat16 is None:
20+
return (cast_to(x, y.dtype, saturate),)
1521
if y.dtype == bfloat16 and y.dtype.descr[0][0] == "bfloat16":
1622
# np.uint16 == np.uint16 is True as well as np.uint16 == bfloat16
1723
to = TensorProto.BFLOAT16

0 commit comments

Comments
 (0)