Skip to content

Commit 0471c35

Browse files
committed
style
1 parent fdf6d63 commit 0471c35

File tree

20 files changed

+84
-192
lines changed

20 files changed

+84
-192
lines changed

_doc/examples/plot_onnx_diff.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from onnx_array_api.reference import compare_onnx_execution
2020
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
2121

22-
2322
data = load_iris()
2423
X_train, X_test = train_test_split(data.data)
2524
model = GaussianMixture()

_doc/examples/plot_optimization.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from onnx_array_api.ext_test_case import measure_time
3030
from onnx_array_api.ort.ort_optimizers import ort_optimized_model
3131

32-
3332
filename = example_path("data/small.onnx")
3433
optimized = filename + ".optimized.onnx"
3534

_doc/examples/plot_profiling.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from onnx_array_api.ort.ort_profile import ort_profile, merge_ort_profile
2626
from onnx_array_api.plotting.stat_plot import plot_ort_profile
2727

28-
2928
suffix = ""
3029
filename = example_path(f"data/small{suffix}.onnx")
3130
optimized = filename + ".optimized.onnx"

_unittests/ut_graph_api/test_graph_builder.py

Lines changed: 12 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,44 +18,38 @@ def call_optimizer(self, onx):
1818
return gr.to_onnx()
1919

2020
def test_remove_unused_nodes(self):
21-
model = onnx.parser.parse_model(
22-
"""
21+
model = onnx.parser.parse_model("""
2322
<ir_version: 8, opset_import: [ "": 18]>
2423
agraph (float[N] x) => (float[N] z) {
2524
two = Constant <value_float=2.0> ()
2625
four = Add(two, two)
2726
z = Mul(x, x)
28-
}"""
29-
)
27+
}""")
3028
onx = self.call_optimizer(model)
3129
self.assertEqual(len(onx.graph.node), 1)
3230
self.assertEqual(onx.graph.node[0].op_type, "Mul")
3331

3432
def test_initializers(self):
35-
model = onnx.parser.parse_model(
36-
"""
33+
model = onnx.parser.parse_model("""
3734
<ir_version: 8, opset_import: [ "": 18]>
3835
agraph (float[N] x) => (float[N] z)
3936
<float two = {2.0}> {
4037
four = Add(two, two)
4138
z = Mul(x, x)
42-
}"""
43-
)
39+
}""")
4440
self.assertEqual(len(model.graph.initializer), 1)
4541
onx = self.call_optimizer(model)
4642
self.assertEqual(len(onx.graph.node), 1)
4743
self.assertEqual(onx.graph.node[0].op_type, "Mul")
4844
self.assertEqual(len(onx.graph.initializer), 0)
4945

5046
def test_keep_unused_outputs(self):
51-
model = onnx.parser.parse_model(
52-
"""
47+
model = onnx.parser.parse_model("""
5348
<ir_version: 8, opset_import: [ "": 18]>
5449
agraph (float[N] x) => (float[M] z) {
5550
w1, w2, w3 = Split (x)
5651
z = Mul(w3, w3)
57-
}"""
58-
)
52+
}""")
5953
onx = self.call_optimizer(model)
6054
self.assertEqual(len(onx.graph.node), 2)
6155
self.assertEqual(onx.graph.node[0].op_type, "Split")
@@ -381,30 +375,26 @@ def test_make_nodes_noprefix(self):
381375
self.assertEqualArray(expected, got[0])
382376

383377
def test_node_pattern(self):
384-
model = onnx.parser.parse_model(
385-
"""
378+
model = onnx.parser.parse_model("""
386379
<ir_version: 8, opset_import: [ "": 18]>
387380
agraph (float[N] x) => (float[N] z) {
388381
two = Constant <value_float=2.0> ()
389382
four = Add(two, two)
390383
z = Mul(x, four)
391-
}"""
392-
)
384+
}""")
393385
gr = GraphBuilder(model)
394386
p = gr.np(index=0)
395387
r = repr(p)
396388
self.assertEqual("NodePattern(index=0, op_type=None, name=None)", r)
397389

398390
def test_update_node_attribute(self):
399-
model = onnx.parser.parse_model(
400-
"""
391+
model = onnx.parser.parse_model("""
401392
<ir_version: 8, opset_import: [ "": 18]>
402393
agraph (float[N] x) => (float[N] z) {
403394
two = Constant <value_float=2.0> ()
404395
four = Add(two, two)
405396
z = Mul(x, four)
406-
}"""
407-
)
397+
}""")
408398
gr = GraphBuilder(model)
409399
self.assertEqual(len(gr.nodes), 3)
410400
m = gr.update_attribute(gr.np(op_type="Constant"), value_float=float(1))
@@ -416,15 +406,13 @@ def test_update_node_attribute(self):
416406
self.assertIn("f: 1", str(node))
417407

418408
def test_delete_node_attribute(self):
419-
model = onnx.parser.parse_model(
420-
"""
409+
model = onnx.parser.parse_model("""
421410
<ir_version: 8, opset_import: [ "": 18]>
422411
agraph (float[N] x) => (float[N] z) {
423412
two = Constant <value_float=2.0> ()
424413
four = Add(two, two)
425414
z = Mul(x, four)
426-
}"""
427-
)
415+
}""")
428416
gr = GraphBuilder(model)
429417
self.assertEqual(len(gr.nodes), 3)
430418
m = gr.update_attribute(

_unittests/ut_npx/test_sklearn_array_api.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from onnx_array_api.ext_test_case import ExtTestCase, ignore_warnings
88
from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor
99

10-
1110
DEFAULT_OPSET = onnx_opset_version()
1211

1312

_unittests/ut_ort/test_ort_optimizer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from onnx_array_api.ext_test_case import ExtTestCase
77
from onnx_array_api.ort.ort_optimizers import ort_optimized_model
88

9-
109
DEFAULT_OPSET = onnx_opset_version()
1110

1211

_unittests/ut_ort/test_sklearn_array_api_ort.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from onnx_array_api.ext_test_case import ExtTestCase, skipif_ci_windows
88
from onnx_array_api.ort.ort_tensors import EagerOrtTensor, OrtTensor
99

10-
1110
DEFAULT_OPSET = onnx_opset_version()
1211

1312

_unittests/ut_plotting/test_graphviz.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,13 @@
1313
class TestGraphviz(ExtTestCase):
1414
@classmethod
1515
def _get_graph(cls):
16-
return onnx.parser.parse_model(
17-
"""
16+
return onnx.parser.parse_model("""
1817
<ir_version: 8, opset_import: [ "": 18]>
1918
agraph (float[N] x) => (float[N] z) {
2019
two = Constant <value_float=2.0> ()
2120
four = Add(two, two)
2221
z = Mul(x, x)
23-
}"""
24-
)
22+
}""")
2523

2624
@skipif_ci_windows("graphviz not installed")
2725
@skipif_ci_apple("graphviz not installed")

_unittests/ut_plotting/test_text_plot.py

Lines changed: 16 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,7 @@ def test_onnx_text_plot_tree_cls_2(self):
6969
model_def = load(f)
7070
res = onnx_text_plot_tree(model_def.graph.node[0])
7171
self.assertIn("n_classes=3", res)
72-
expected = textwrap.dedent(
73-
"""
72+
expected = textwrap.dedent("""
7473
n_classes=3
7574
n_trees=1
7675
----
@@ -92,8 +91,7 @@ def test_onnx_text_plot_tree_cls_2(self):
9291
-f 0:0 1:0 2:1
9392
+f 0:0 1:1 2:0
9493
+f 0:1 1:0 2:0
95-
"""
96-
).strip(" \n\r")
94+
""").strip(" \n\r")
9795
res = res.replace("np.float32(", "").replace(")", "")
9896
self.assertEqual(expected, res.strip(" \n\r"))
9997

@@ -104,39 +102,33 @@ def test_onnx_simple_text_plot_kmeans(self):
104102
model.fit(x)
105103
onx = to_onnx(model, x.astype(numpy.float32), target_opset=15)
106104
text = onnx_simple_text_plot(onx)
107-
expected1 = textwrap.dedent(
108-
"""
105+
expected1 = textwrap.dedent("""
109106
ReduceSumSquare(X, axes=[1], keepdims=1) -> Re_reduced0
110107
Mul(Re_reduced0, Mu_Mulcst) -> Mu_C0
111108
Gemm(X, Ge_Gemmcst, Mu_C0, alpha=-2.00, transB=1) -> Ge_Y0
112109
Add(Re_reduced0, Ge_Y0) -> Ad_C01
113110
Add(Ad_Addcst, Ad_C01) -> Ad_C0
114111
Sqrt(Ad_C0) -> scores
115112
ArgMin(Ad_C0, axis=1, keepdims=0) -> label
116-
"""
117-
).strip(" \n")
118-
expected2 = textwrap.dedent(
119-
"""
113+
""").strip(" \n")
114+
expected2 = textwrap.dedent("""
120115
ReduceSumSquare(X, axes=[1], keepdims=1) -> Re_reduced0
121116
Mul(Re_reduced0, Mu_Mulcst) -> Mu_C0
122117
Gemm(X, Ge_Gemmcst, Mu_C0, alpha=-2.00, transB=1) -> Ge_Y0
123118
Add(Re_reduced0, Ge_Y0) -> Ad_C01
124119
Add(Ad_Addcst, Ad_C01) -> Ad_C0
125120
Sqrt(Ad_C0) -> scores
126121
ArgMin(Ad_C0, axis=1, keepdims=0) -> label
127-
"""
128-
).strip(" \n")
129-
expected3 = textwrap.dedent(
130-
"""
122+
""").strip(" \n")
123+
expected3 = textwrap.dedent("""
131124
ReduceSumSquare(X, axes=[1], keepdims=1) -> Re_reduced0
132125
Mul(Re_reduced0, Mu_Mulcst) -> Mu_C0
133126
Gemm(X, Ge_Gemmcst, Mu_C0, alpha=-2.00, transB=1) -> Ge_Y0
134127
Add(Re_reduced0, Ge_Y0) -> Ad_C01
135128
Add(Ad_Addcst, Ad_C01) -> Ad_C0
136129
ArgMin(Ad_C0, axis=1, keepdims=0) -> label
137130
Sqrt(Ad_C0) -> scores
138-
"""
139-
).strip(" \n")
131+
""").strip(" \n")
140132
if expected1 not in text and expected2 not in text and expected3 not in text:
141133
raise AssertionError(f"Unexpected value:\n{text}")
142134

@@ -165,17 +157,15 @@ def test_onnx_simple_text_plot_toy(self):
165157
{"X": x.astype(numpy.float32)}, outputs={"Y": x}, target_opset=15
166158
)
167159
text = onnx_simple_text_plot(onx, verbose=False)
168-
expected = textwrap.dedent(
169-
"""
160+
expected = textwrap.dedent("""
170161
Add(X, Ad_Addcst) -> Ad_C0
171162
Abs(Ad_C0) -> Ab_Y0
172163
Identity(Ad_Addcst) -> Su_Subcst
173164
Sub(X, Su_Subcst) -> Su_C0
174165
Abs(Su_C0) -> Ab_Y02
175166
Div(Ab_Y0, Ab_Y02) -> Di_C0
176167
Abs(Di_C0) -> Y
177-
"""
178-
).strip(" \n")
168+
""").strip(" \n")
179169
self.assertIn(expected, text)
180170
text2, out, err = self.capture(lambda: onnx_simple_text_plot(onx, verbose=True))
181171
self.assertEqual(text, text2)
@@ -188,11 +178,9 @@ def test_onnx_simple_text_plot_leaky(self):
188178
{"X": FloatTensorType()}, outputs={"Y": FloatTensorType()}, target_opset=15
189179
)
190180
text = onnx_simple_text_plot(onx)
191-
expected = textwrap.dedent(
192-
"""
181+
expected = textwrap.dedent("""
193182
LeakyRelu(X, alpha=0.50) -> Y
194-
"""
195-
).strip(" \n")
183+
""").strip(" \n")
196184
self.assertIn(expected, text)
197185

198186
def test_onnx_text_plot_io(self):
@@ -201,11 +189,9 @@ def test_onnx_text_plot_io(self):
201189
{"X": FloatTensorType()}, outputs={"Y": FloatTensorType()}, target_opset=15
202190
)
203191
text = onnx_text_plot_io(onx)
204-
expected = textwrap.dedent(
205-
"""
192+
expected = textwrap.dedent("""
206193
input:
207-
"""
208-
).strip(" \n")
194+
""").strip(" \n")
209195
self.assertIn(expected, text)
210196

211197
def test_onnx_simple_text_plot_if(self):
@@ -244,11 +230,9 @@ def test_onnx_simple_text_plot_if(self):
244230
{"x1": x1, "x2": x2}, target_opset=opv, outputs=[("y", FloatTensorType())]
245231
)
246232
text = onnx_simple_text_plot(model_def)
247-
expected = textwrap.dedent(
248-
"""
233+
expected = textwrap.dedent("""
249234
input:
250-
"""
251-
).strip(" \n")
235+
""").strip(" \n")
252236
self.assertIn(expected, text)
253237
self.assertIn("If(Gr_C0, else_branch=G1, then_branch=G2)", text)
254238

_unittests/ut_reference/test_evaluator_yield.py

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -431,55 +431,43 @@ def test_distance_sequence_str(self):
431431
005=|RESULTfloat322:2x2CEIOLinearRegressioY1|RESULTfloat322:2x2CEIOLinearRegressioY1
432432
006~|RESULTfloat322:2x2CEIOAbsY|RESULTfloat322:2x3CEIPAbsZ
433433
007~|OUTPUTfloat322:2x2CEIOY|OUTPUTfloat322:2x2CEIPY
434-
""".replace(
435-
" ", ""
436-
).strip(
437-
"\n "
438-
)
434+
""".replace(" ", "").strip("\n ")
439435
self.maxDiff = None
440436
self.assertEqual(expected, text.replace(" ", "").strip("\n"))
441437

442438
def test_compare_execution(self):
443-
m1 = parse_model(
444-
"""
439+
m1 = parse_model("""
445440
<ir_version: 8, opset_import: [ "": 18]>
446441
agraph (float[N] x) => (float[N] z) {
447442
two = Constant <value_float=2.0> ()
448443
four = Add(two, two)
449444
z = Mul(x, x)
450-
}"""
451-
)
452-
m2 = parse_model(
453-
"""
445+
}""")
446+
m2 = parse_model("""
454447
<ir_version: 8, opset_import: [ "": 18]>
455448
agraph (float[N] x) => (float[N] z) {
456449
two = Constant <value_float=2.0> ()
457450
z = Mul(x, x)
458-
}"""
459-
)
451+
}""")
460452
res1, res2, align, dc = compare_onnx_execution(m1, m2)
461453
text = dc.to_str(res1, res2, align)
462454
self.assertIn("CAAA Constant", text)
463455
self.assertEqual(len(align), 5)
464456

465457
def test_compare_execution_discrepancies(self):
466-
m1 = parse_model(
467-
"""
458+
m1 = parse_model("""
468459
<ir_version: 8, opset_import: [ "": 18]>
469460
agraph (float[N] x) => (float[N] z) {
470461
two = Constant <value_float=2.0> ()
471462
four = Add(two, two)
472463
z = Mul(x, x)
473-
}"""
474-
)
475-
m2 = parse_model(
476-
"""
464+
}""")
465+
m2 = parse_model("""
477466
<ir_version: 8, opset_import: [ "": 18]>
478467
agraph (float[N] x) => (float[N] z) {
479468
two = Constant <value_float=2.0> ()
480469
z = Mul(x, x)
481-
}"""
482-
)
470+
}""")
483471
res1, res2, align, dc = compare_onnx_execution(m1, m2, keep_tensor=True)
484472
text = dc.to_str(res1, res2, align)
485473
print(text)

0 commit comments

Comments
 (0)