Skip to content

Commit c7375ca

Browse files
authored
Refactoring and fixes minor bugs in light API (#62)
* fix minour bugs in light API * refactoring * complete refactoring * fix unit test file * fix wrong import * improve shape handling * move files * fix documentation * doc
1 parent ebafa26 commit c7375ca

24 files changed

+189
-139
lines changed

_doc/api/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ API
99
array_api
1010
graph_api
1111
light_api
12+
translate_api
1213
npx_core_api
1314
npx_functions
1415
npx_jit_eager

_doc/api/light_api.rst

Lines changed: 3 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,10 @@ start
1111

1212
.. autofunction:: onnx_array_api.light_api.start
1313

14-
translate
15-
+++++++++
16-
17-
.. autofunction:: onnx_array_api.light_api.translate
18-
19-
make_helper
20-
+++++++++++
14+
g
15+
+
2116

22-
.. autofunction:: onnx_array_api.light_api.make_helper.make_node_extended
23-
24-
.. autofunction:: onnx_array_api.light_api.make_helper.make_ref_attribute
17+
.. autofunction:: onnx_array_api.light_api.g
2518

2619
Classes for the Light API
2720
=========================
@@ -69,39 +62,6 @@ Vars
6962
:members:
7063
:inherited-members:
7164

72-
Classes for the Translater
73-
==========================
74-
75-
BaseEmitter
76-
+++++++++++
77-
78-
.. autoclass:: onnx_array_api.light_api.base_emitter.BaseEmitter
79-
:members:
80-
81-
EventType
82-
+++++++++
83-
84-
.. autoclass:: onnx_array_api.light_api.base_emitter.EventType
85-
:members:
86-
87-
InnerEmitter
88-
++++++++++++
89-
90-
.. autoclass:: onnx_array_api.light_api.inner_emitter.InnerEmitter
91-
:members:
92-
93-
LightEmitter
94-
++++++++++++
95-
96-
.. autoclass:: onnx_array_api.light_api.light_emitter.LightEmitter
97-
:members:
98-
99-
Translater
100-
++++++++++
101-
102-
.. autoclass:: onnx_array_api.light_api.translate.Translater
103-
:members:
104-
10565
Available operators
10666
===================
10767

_doc/api/translate_api.rst

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
============================
2+
onnx_array_api.translate_api
3+
============================
4+
5+
6+
Main API
7+
========
8+
9+
translate
10+
+++++++++
11+
12+
.. autofunction:: onnx_array_api.translate_api.translate
13+
14+
make_helper
15+
+++++++++++
16+
17+
.. autofunction:: onnx_array_api.translate_api.make_helper.make_node_extended
18+
19+
.. autofunction:: onnx_array_api.translate_api.make_helper.make_ref_attribute
20+
21+
Classes for the Translater
22+
==========================
23+
24+
BaseEmitter
25+
+++++++++++
26+
27+
.. autoclass:: onnx_array_api.translate_api.base_emitter.BaseEmitter
28+
:members:
29+
30+
EventType
31+
+++++++++
32+
33+
.. autoclass:: onnx_array_api.translate_api.base_emitter.EventType
34+
:members:
35+
36+
InnerEmitter
37+
++++++++++++
38+
39+
.. autoclass:: onnx_array_api.translate_api.inner_emitter.InnerEmitter
40+
:members:
41+
42+
LightEmitter
43+
++++++++++++
44+
45+
.. autoclass:: onnx_array_api.translate_api.light_emitter.LightEmitter
46+
:members:
47+
48+
Translater
49+
++++++++++
50+
51+
.. autoclass:: onnx_array_api.translate_api.translate.Translater
52+
:members:

_unittests/ut_light_api/test_backend_export.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
from onnx.numpy_helper import from_array, to_array
2323
from onnx.backend.base import Device, DeviceType
2424
from onnx_array_api.reference import ExtendedReferenceEvaluator
25-
from onnx_array_api.light_api.make_helper import make_node_extended
26-
from onnx_array_api.light_api import translate
25+
from onnx_array_api.translate_api.make_helper import make_node_extended
26+
from onnx_array_api.translate_api import translate
2727
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
2828

2929
verbosity = 10 if "-v" in sys.argv or "--verbose" in sys.argv else 0

_unittests/ut_light_api/test_light_api.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def test_neg(self):
211211
self.assertIsInstance(v, Var)
212212
self.assertEqual(["X"], v.parent.input_names)
213213
s = str(v)
214-
self.assertEqual("X:FLOAT", s)
214+
self.assertEqual("X:FLOAT:[]", s)
215215
onx = start().vin("X").Neg().rename("Y").vout().to_onnx()
216216
self.assertIsInstance(onx, ModelProto)
217217
ref = ReferenceEvaluator(onx)
@@ -510,7 +510,23 @@ def ah(self):
510510
expected = (a > 0).astype(int).astype(np.float32).reshape((-1, 1))
511511
self.assertEqualArray(expected, got)
512512

513+
def test_input_shape(self):
514+
kernel = (np.arange(9) + 1).reshape(3, 3).astype(np.float32)
515+
model = (
516+
start()
517+
.vin("X", shape=[None, None])
518+
.cst(kernel[np.newaxis, np.newaxis, ...])
519+
.rename("W")
520+
.bring("X", "W")
521+
.Conv(pads=[1, 1, 1, 1])
522+
.rename("Y")
523+
.vout(shape=[])
524+
.to_onnx()
525+
)
526+
i = str(model.graph.input[0]).replace("\n", "").replace(" ", "")
527+
self.assertNotIn("shape{}", i)
528+
513529

514530
if __name__ == "__main__":
515-
TestLightApi().test_domain()
531+
TestLightApi().test_add()
516532
unittest.main(verbosity=2)

_unittests/ut_light_api/test_translate.py renamed to _unittests/ut_translate_api/test_translate.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
from onnx.defs import onnx_opset_version
66
from onnx.reference import ReferenceEvaluator
77
from onnx_array_api.ext_test_case import ExtTestCase
8-
from onnx_array_api.light_api import start, translate, g
9-
from onnx_array_api.light_api.base_emitter import EventType
8+
from onnx_array_api.light_api import start, g
9+
from onnx_array_api.translate_api import translate
10+
from onnx_array_api.translate_api.base_emitter import EventType
1011

1112
OPSET_API = min(19, onnx_opset_version() - 1)
1213

_unittests/ut_light_api/test_translate_classic.py renamed to _unittests/ut_translate_api/test_translate_classic.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
)
1616
from onnx.checker import check_model
1717
from onnx_array_api.ext_test_case import ExtTestCase
18-
from onnx_array_api.light_api import start, translate
18+
from onnx_array_api.light_api import start
19+
from onnx_array_api.translate_api import translate
1920

2021
OPSET_API = min(19, onnx_opset_version() - 1)
2122

@@ -335,7 +336,7 @@ def _run(cls, code):
335336
import onnx
336337
import onnx.helper
337338
import onnx.numpy_helper
338-
import onnx_array_api.light_api.make_helper
339+
import onnx_array_api.translate_api.make_helper
339340
import onnx.reference.custom_element_types
340341

341342
def from_array_extended(tensor, name=None):
@@ -362,7 +363,7 @@ def from_array_extended(tensor, name=None):
362363
globs = onnx.__dict__.copy()
363364
globs.update(onnx.helper.__dict__)
364365
globs.update(onnx.numpy_helper.__dict__)
365-
globs.update(onnx_array_api.light_api.make_helper.__dict__)
366+
globs.update(onnx_array_api.translate_api.make_helper.__dict__)
366367
globs.update(onnx.reference.custom_element_types.__dict__)
367368
globs["from_array_extended"] = from_array_extended
368369
locs = {}

onnx_array_api/_command_lines_parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def get_parser_translate() -> ArgumentParser:
5656

5757

5858
def _cmd_translate(argv: List[Any]):
59-
from .light_api import translate
59+
from .translate_api import translate
6060

6161
parser = get_parser_translate()
6262
args = parser.parse_args(argv[1:])

0 commit comments

Comments
 (0)