Skip to content

Commit 7e19bf5

Browse files
authored
fix import issues with the latest onnx (#103)
* empty set * style * fix code * fix
1 parent 01bd2dc commit 7e19bf5

File tree

10 files changed

+32
-13
lines changed

10 files changed

+32
-13
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.3.2
55
+++++
66

7+
* :pr:`103`: fix import issue with the latest onnx version
78
* :pr:`101`: fix as_tensor in onnx_text_plot_tree
89

910
0.3.1

_unittests/ut_npx/test_npx.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1268,7 +1268,7 @@ def test_numpy_op_bin_reduce(self):
12681268
"xor", lambda x, y: (x.sum() == y.sum()) ^ (((-x).sum()) == y.sum())
12691269
)
12701270

1271-
def common_test_inline(self, fonx, fnp, tcst=0):
1271+
def common_test_inline(self, fonx, fnp, tcst=0, atol=1e-10):
12721272
f = fonx(Input("A"))
12731273
self.assertIsInstance(f, Var)
12741274
onx = f.to_onnx(constraints={0: Float64[None], (0, False): Float64[None]})
@@ -1277,7 +1277,7 @@ def common_test_inline(self, fonx, fnp, tcst=0):
12771277
y = fnp(x)
12781278
ref = ReferenceEvaluator(onx)
12791279
got = ref.run(None, {"A": x})
1280-
self.assertEqualArray(y, got[0], atol=1e-10)
1280+
self.assertEqualArray(y, got[0], atol=atol)
12811281

12821282
def common_test_inline_bin(self, fonx, fnp, tcst=0):
12831283
f = fonx(Input("A"), Input("B"))
@@ -1470,7 +1470,7 @@ def test_equal(self):
14701470

14711471
@unittest.skipIf(scipy is None, reason="scipy is not installed.")
14721472
def test_erf(self):
1473-
self.common_test_inline(erf_inline, scipy.special.erf)
1473+
self.common_test_inline(erf_inline, scipy.special.erf, atol=1e-7)
14741474

14751475
def test_exp(self):
14761476
self.common_test_inline(exp_inline, np.exp)

_unittests/ut_translate_api/test_translate_classic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ def _run(cls, code):
406406
import onnx.helper
407407
import onnx.numpy_helper
408408
import onnx_array_api.translate_api.make_helper
409-
import onnx.reference.custom_element_types
409+
import ml_dtypes
410410

411411
def from_array_extended(tensor, name=None):
412412
dt = tensor.dtype
@@ -433,7 +433,7 @@ def from_array_extended(tensor, name=None):
433433
globs.update(onnx.helper.__dict__)
434434
globs.update(onnx.numpy_helper.__dict__)
435435
globs.update(onnx_array_api.translate_api.make_helper.__dict__)
436-
globs.update(onnx.reference.custom_element_types.__dict__)
436+
globs.update(ml_dtypes.__dict__)
437437
globs["from_array_extended"] = from_array_extended
438438
locs = {}
439439
try:

_unittests/ut_xrun_doc/test_documentation_examples.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def run_test(self, fold: str, name: str, verbose=0) -> int:
4040
cmds = [sys.executable, "-u", os.path.join(fold, name)]
4141
p = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
4242
res = p.communicate()
43-
out, err = res
43+
_out, err = res
4444
st = err.decode("ascii", errors="ignore")
4545
if st and "Traceback" in st:
4646
if '"dot" not found in path.' in st:

onnx_array_api/profiling.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,12 @@ def align_text(text, size):
281281
return text[:h] + "..." + text[-h + 1 :]
282282

283283
dicts = self.as_dict(filter_node=filter_node, sort_key=sort_key)
284-
max_nc = max(max(_["nc1"] for _ in dicts), max(_["nc2"] for _ in dicts))
284+
set1 = [_["nc1"] for _ in dicts]
285+
set2 = [_["nc1"] for _ in dicts]
286+
if set1 or set2:
287+
max_nc = max([*set1, *set2])
288+
else:
289+
max_nc = 1
285290
dg = int(math.log(max_nc) / math.log(10) + 1.5)
286291
line_format = (
287292
"{indent}{fct} -- {nc1: %dd} {nc2: %dd} -- {tin:1.5f} {tall:1.5f}"

onnx_array_api/translate_api/base_emitter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,9 @@ def render_attribute_value(self, value: Any) -> Tuple[List[str], str]:
129129
if value[0].type == AttributeProto.TENSOR:
130130
repl = {"bool": "bool_", "object": "object_", "str": "str_"}
131131
sdtype = repl.get(str(v.dtype), str(str(v.dtype)))
132+
package = "np" if hasattr(np, sdtype) else "ml_dtypes"
132133
return [], (
133-
f"from_array(np.array({v.tolist()}, dtype=np.{sdtype}), "
134+
f"from_array(np.array({v.tolist()}, dtype={package}.{sdtype}), "
134135
f"name={value[0].name!r})"
135136
)
136137
if isinstance(v, (int, float, list)):

onnx_array_api/translate_api/builder_emitter.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Any, Dict, List
2+
import numpy as np
23
from onnx import TensorProto
34
from onnx.numpy_helper import to_array
45
from .base_emitter import BaseEmitter
@@ -135,7 +136,10 @@ def _emit_end_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
135136
val = to_array(init)
136137
stype = str(val.dtype).split(".")[-1]
137138
name = self._clean_result_name(init.name)
138-
rows.append(f" {name} = np.array({val.tolist()}, dtype=np.{stype})")
139+
package = "np" if hasattr(np, stype) else "ml_dtypes"
140+
rows.append(
141+
f" {name} = np.array({val.tolist()}, dtype={package}.{stype})"
142+
)
139143
return rows
140144

141145
def _emit_begin_return(self, **kwargs: Dict[str, Any]) -> List[str]:

onnx_array_api/translate_api/inner_emitter.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Any, Dict, List, Optional, Tuple
2+
import numpy as np
23
from onnx import AttributeProto
34
from ..annotations import ELEMENT_TYPE_NAME
45
from .base_emitter import BaseEmitter
@@ -105,7 +106,7 @@ def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
105106
else:
106107
raise NotImplementedError(f"Unexpected dtype={sdtype}.")
107108
else:
108-
sdtype = f"np.{sdtype}"
109+
sdtype = f"np.{sdtype}" if hasattr(np, sdtype) else f"ml_dtypes.{sdtype}"
109110

110111
return [
111112
"initializers.append(",
@@ -233,7 +234,7 @@ def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
233234
else:
234235
raise NotImplementedError(f"Unexpected dtype={sdtype}.")
235236
else:
236-
sdtype = f"np.{sdtype}"
237+
sdtype = f"np.{sdtype}" if hasattr(np, sdtype) else f"ml_dtypes.{sdtype}"
237238
if value.size <= 16:
238239
return [
239240
"initializers.append(",

onnx_array_api/translate_api/light_emitter.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Any, Dict, List
2+
import numpy as np
23
from ..annotations import ELEMENT_TYPE_NAME
34
from .base_emitter import BaseEmitter
45

@@ -43,8 +44,9 @@ def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
4344
value = kwargs["value"]
4445
repl = {"bool": "bool_", "object": "object_", "str": "str_"}
4546
sdtype = repl.get(str(value.dtype), str(str(value.dtype)))
47+
package = "np" if hasattr(np, sdtype) else "ml_dtypes"
4648
return [
47-
f"cst(np.array({value.tolist()}, dtype=np.{sdtype}))",
49+
f"cst(np.array({value.tolist()}, dtype={package}.{sdtype}))",
4850
f"rename({name!r})",
4951
]
5052

pyproject.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,12 @@ select = [
3939
]
4040

4141
[tool.ruff.lint.per-file-ignores]
42-
"**" = ["B905", "C401", "C408", "C413", "PYI041", "RUF012", "RUF100", "RUF010", "SIM108", "SIM910", "SIM110", "SIM102", "SIM114", "SIM103", "UP015", "UP027", "UP031", "UP034", "UP032", "UP006", "UP035", "UP007", "UP038"]
42+
"**" = [
43+
"B905", "C401", "C408", "C413", "PYI041",
44+
"RUF012", "RUF100", "RUF010",
45+
"SIM108", "SIM910", "SIM110", "SIM102", "SIM114", "SIM103",
46+
"UP015", "UP027", "UP031", "UP034", "UP032", "UP006", "UP035", "UP007", "UP038", "UP045"
47+
]
4348
"**/plot*.py" = ["B018"]
4449
"_doc/examples/plot_first_example.py" = ["E402", "F811"]
4550
"_doc/examples/plot_onnxruntime.py" = ["E402", "F811"]

0 commit comments

Comments
 (0)