diff --git a/_doc/examples/plot_onnx_diff.py b/_doc/examples/plot_onnx_diff.py index 7b6ecdf..14e0e72 100644 --- a/_doc/examples/plot_onnx_diff.py +++ b/_doc/examples/plot_onnx_diff.py @@ -19,7 +19,6 @@ from onnx_array_api.reference import compare_onnx_execution from onnx_array_api.plotting.text_plot import onnx_simple_text_plot - data = load_iris() X_train, X_test = train_test_split(data.data) model = GaussianMixture() diff --git a/_doc/examples/plot_optimization.py b/_doc/examples/plot_optimization.py index c78419b..cb935f1 100644 --- a/_doc/examples/plot_optimization.py +++ b/_doc/examples/plot_optimization.py @@ -29,7 +29,6 @@ from onnx_array_api.ext_test_case import measure_time from onnx_array_api.ort.ort_optimizers import ort_optimized_model - filename = example_path("data/small.onnx") optimized = filename + ".optimized.onnx" diff --git a/_doc/examples/plot_profiling.py b/_doc/examples/plot_profiling.py index 201de95..1d93566 100644 --- a/_doc/examples/plot_profiling.py +++ b/_doc/examples/plot_profiling.py @@ -25,7 +25,6 @@ from onnx_array_api.ort.ort_profile import ort_profile, merge_ort_profile from onnx_array_api.plotting.stat_plot import plot_ort_profile - suffix = "" filename = example_path(f"data/small{suffix}.onnx") optimized = filename + ".optimized.onnx" diff --git a/_unittests/ut_graph_api/test_graph_builder.py b/_unittests/ut_graph_api/test_graph_builder.py index 9e6229b..ebf12ca 100644 --- a/_unittests/ut_graph_api/test_graph_builder.py +++ b/_unittests/ut_graph_api/test_graph_builder.py @@ -18,29 +18,25 @@ def call_optimizer(self, onx): return gr.to_onnx() def test_remove_unused_nodes(self): - model = onnx.parser.parse_model( - """ + model = onnx.parser.parse_model(""" agraph (float[N] x) => (float[N] z) { two = Constant () four = Add(two, two) z = Mul(x, x) - }""" - ) + }""") onx = self.call_optimizer(model) self.assertEqual(len(onx.graph.node), 1) self.assertEqual(onx.graph.node[0].op_type, "Mul") def test_initializers(self): - model = onnx.parser.parse_model( - """ + model = onnx.parser.parse_model(""" agraph (float[N] x) => (float[N] z) { four = Add(two, two) z = Mul(x, x) - }""" - ) + }""") self.assertEqual(len(model.graph.initializer), 1) onx = self.call_optimizer(model) self.assertEqual(len(onx.graph.node), 1) @@ -48,14 +44,12 @@ def test_initializers(self): self.assertEqual(len(onx.graph.initializer), 0) def test_keep_unused_outputs(self): - model = onnx.parser.parse_model( - """ + model = onnx.parser.parse_model(""" agraph (float[N] x) => (float[M] z) { w1, w2, w3 = Split (x) z = Mul(w3, w3) - }""" - ) + }""") onx = self.call_optimizer(model) self.assertEqual(len(onx.graph.node), 2) self.assertEqual(onx.graph.node[0].op_type, "Split") @@ -381,30 +375,26 @@ def test_make_nodes_noprefix(self): self.assertEqualArray(expected, got[0]) def test_node_pattern(self): - model = onnx.parser.parse_model( - """ + model = onnx.parser.parse_model(""" agraph (float[N] x) => (float[N] z) { two = Constant () four = Add(two, two) z = Mul(x, four) - }""" - ) + }""") gr = GraphBuilder(model) p = gr.np(index=0) r = repr(p) self.assertEqual("NodePattern(index=0, op_type=None, name=None)", r) def test_update_node_attribute(self): - model = onnx.parser.parse_model( - """ + model = onnx.parser.parse_model(""" agraph (float[N] x) => (float[N] z) { two = Constant () four = Add(two, two) z = Mul(x, four) - }""" - ) + }""") gr = GraphBuilder(model) self.assertEqual(len(gr.nodes), 3) m = gr.update_attribute(gr.np(op_type="Constant"), value_float=float(1)) @@ -416,15 +406,13 @@ def test_update_node_attribute(self): self.assertIn("f: 1", str(node)) def test_delete_node_attribute(self): - model = onnx.parser.parse_model( - """ + model = onnx.parser.parse_model(""" agraph (float[N] x) => (float[N] z) { two = Constant () four = Add(two, two) z = Mul(x, four) - }""" - ) + }""") gr = GraphBuilder(model) self.assertEqual(len(gr.nodes), 3) m = gr.update_attribute( diff --git a/_unittests/ut_npx/test_sklearn_array_api.py b/_unittests/ut_npx/test_sklearn_array_api.py index 9c0d56f..924d819 100644 --- a/_unittests/ut_npx/test_sklearn_array_api.py +++ b/_unittests/ut_npx/test_sklearn_array_api.py @@ -7,7 +7,6 @@ from onnx_array_api.ext_test_case import ExtTestCase, ignore_warnings from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor - DEFAULT_OPSET = onnx_opset_version() diff --git a/_unittests/ut_ort/test_ort_optimizer.py b/_unittests/ut_ort/test_ort_optimizer.py index bf07d0f..f2b7911 100644 --- a/_unittests/ut_ort/test_ort_optimizer.py +++ b/_unittests/ut_ort/test_ort_optimizer.py @@ -6,7 +6,6 @@ from onnx_array_api.ext_test_case import ExtTestCase from onnx_array_api.ort.ort_optimizers import ort_optimized_model - DEFAULT_OPSET = onnx_opset_version() diff --git a/_unittests/ut_ort/test_sklearn_array_api_ort.py b/_unittests/ut_ort/test_sklearn_array_api_ort.py index f50fce1..12a5c93 100644 --- a/_unittests/ut_ort/test_sklearn_array_api_ort.py +++ b/_unittests/ut_ort/test_sklearn_array_api_ort.py @@ -7,7 +7,6 @@ from onnx_array_api.ext_test_case import ExtTestCase, skipif_ci_windows from onnx_array_api.ort.ort_tensors import EagerOrtTensor, OrtTensor - DEFAULT_OPSET = onnx_opset_version() diff --git a/_unittests/ut_plotting/test_graphviz.py b/_unittests/ut_plotting/test_graphviz.py index 420779e..e8ca454 100644 --- a/_unittests/ut_plotting/test_graphviz.py +++ b/_unittests/ut_plotting/test_graphviz.py @@ -13,15 +13,13 @@ class TestGraphviz(ExtTestCase): @classmethod def _get_graph(cls): - return onnx.parser.parse_model( - """ + return onnx.parser.parse_model(""" agraph (float[N] x) => (float[N] z) { two = Constant () four = Add(two, two) z = Mul(x, x) - }""" - ) + }""") @skipif_ci_windows("graphviz not installed") @skipif_ci_apple("graphviz not installed") diff --git a/_unittests/ut_plotting/test_text_plot.py b/_unittests/ut_plotting/test_text_plot.py index 5844ff0..fc2a637 100644 --- a/_unittests/ut_plotting/test_text_plot.py +++ b/_unittests/ut_plotting/test_text_plot.py @@ -69,8 +69,7 @@ def test_onnx_text_plot_tree_cls_2(self): model_def = load(f) res = onnx_text_plot_tree(model_def.graph.node[0]) self.assertIn("n_classes=3", res) - expected = textwrap.dedent( - """ + expected = textwrap.dedent(""" n_classes=3 n_trees=1 ---- @@ -92,8 +91,7 @@ def test_onnx_text_plot_tree_cls_2(self): -f 0:0 1:0 2:1 +f 0:0 1:1 2:0 +f 0:1 1:0 2:0 - """ - ).strip(" \n\r") + """).strip(" \n\r") res = res.replace("np.float32(", "").replace(")", "") self.assertEqual(expected, res.strip(" \n\r")) @@ -104,8 +102,7 @@ def test_onnx_simple_text_plot_kmeans(self): model.fit(x) onx = to_onnx(model, x.astype(numpy.float32), target_opset=15) text = onnx_simple_text_plot(onx) - expected1 = textwrap.dedent( - """ + expected1 = textwrap.dedent(""" ReduceSumSquare(X, axes=[1], keepdims=1) -> Re_reduced0 Mul(Re_reduced0, Mu_Mulcst) -> Mu_C0 Gemm(X, Ge_Gemmcst, Mu_C0, alpha=-2.00, transB=1) -> Ge_Y0 @@ -113,10 +110,8 @@ def test_onnx_simple_text_plot_kmeans(self): Add(Ad_Addcst, Ad_C01) -> Ad_C0 Sqrt(Ad_C0) -> scores ArgMin(Ad_C0, axis=1, keepdims=0) -> label - """ - ).strip(" \n") - expected2 = textwrap.dedent( - """ + """).strip(" \n") + expected2 = textwrap.dedent(""" ReduceSumSquare(X, axes=[1], keepdims=1) -> Re_reduced0 Mul(Re_reduced0, Mu_Mulcst) -> Mu_C0 Gemm(X, Ge_Gemmcst, Mu_C0, alpha=-2.00, transB=1) -> Ge_Y0 @@ -124,10 +119,8 @@ def test_onnx_simple_text_plot_kmeans(self): Add(Ad_Addcst, Ad_C01) -> Ad_C0 Sqrt(Ad_C0) -> scores ArgMin(Ad_C0, axis=1, keepdims=0) -> label - """ - ).strip(" \n") - expected3 = textwrap.dedent( - """ + """).strip(" \n") + expected3 = textwrap.dedent(""" ReduceSumSquare(X, axes=[1], keepdims=1) -> Re_reduced0 Mul(Re_reduced0, Mu_Mulcst) -> Mu_C0 Gemm(X, Ge_Gemmcst, Mu_C0, alpha=-2.00, transB=1) -> Ge_Y0 @@ -135,8 +128,7 @@ def test_onnx_simple_text_plot_kmeans(self): Add(Ad_Addcst, Ad_C01) -> Ad_C0 ArgMin(Ad_C0, axis=1, keepdims=0) -> label Sqrt(Ad_C0) -> scores - """ - ).strip(" \n") + """).strip(" \n") if expected1 not in text and expected2 not in text and expected3 not in text: raise AssertionError(f"Unexpected value:\n{text}") @@ -165,8 +157,7 @@ def test_onnx_simple_text_plot_toy(self): {"X": x.astype(numpy.float32)}, outputs={"Y": x}, target_opset=15 ) text = onnx_simple_text_plot(onx, verbose=False) - expected = textwrap.dedent( - """ + expected = textwrap.dedent(""" Add(X, Ad_Addcst) -> Ad_C0 Abs(Ad_C0) -> Ab_Y0 Identity(Ad_Addcst) -> Su_Subcst @@ -174,8 +165,7 @@ def test_onnx_simple_text_plot_toy(self): Abs(Su_C0) -> Ab_Y02 Div(Ab_Y0, Ab_Y02) -> Di_C0 Abs(Di_C0) -> Y - """ - ).strip(" \n") + """).strip(" \n") self.assertIn(expected, text) text2, out, err = self.capture(lambda: onnx_simple_text_plot(onx, verbose=True)) self.assertEqual(text, text2) @@ -188,11 +178,9 @@ def test_onnx_simple_text_plot_leaky(self): {"X": FloatTensorType()}, outputs={"Y": FloatTensorType()}, target_opset=15 ) text = onnx_simple_text_plot(onx) - expected = textwrap.dedent( - """ + expected = textwrap.dedent(""" LeakyRelu(X, alpha=0.50) -> Y - """ - ).strip(" \n") + """).strip(" \n") self.assertIn(expected, text) def test_onnx_text_plot_io(self): @@ -201,11 +189,9 @@ def test_onnx_text_plot_io(self): {"X": FloatTensorType()}, outputs={"Y": FloatTensorType()}, target_opset=15 ) text = onnx_text_plot_io(onx) - expected = textwrap.dedent( - """ + expected = textwrap.dedent(""" input: - """ - ).strip(" \n") + """).strip(" \n") self.assertIn(expected, text) def test_onnx_simple_text_plot_if(self): @@ -244,11 +230,9 @@ def test_onnx_simple_text_plot_if(self): {"x1": x1, "x2": x2}, target_opset=opv, outputs=[("y", FloatTensorType())] ) text = onnx_simple_text_plot(model_def) - expected = textwrap.dedent( - """ + expected = textwrap.dedent(""" input: - """ - ).strip(" \n") + """).strip(" \n") self.assertIn(expected, text) self.assertIn("If(Gr_C0, else_branch=G1, then_branch=G2)", text) diff --git a/_unittests/ut_reference/test_evaluator_yield.py b/_unittests/ut_reference/test_evaluator_yield.py index 605c1f8..a954b84 100644 --- a/_unittests/ut_reference/test_evaluator_yield.py +++ b/_unittests/ut_reference/test_evaluator_yield.py @@ -431,55 +431,43 @@ def test_distance_sequence_str(self): 005=|RESULTfloat322:2x2CEIOLinearRegressioY1|RESULTfloat322:2x2CEIOLinearRegressioY1 006~|RESULTfloat322:2x2CEIOAbsY|RESULTfloat322:2x3CEIPAbsZ 007~|OUTPUTfloat322:2x2CEIOY|OUTPUTfloat322:2x2CEIPY - """.replace( - " ", "" - ).strip( - "\n " - ) + """.replace(" ", "").strip("\n ") self.maxDiff = None self.assertEqual(expected, text.replace(" ", "").strip("\n")) def test_compare_execution(self): - m1 = parse_model( - """ + m1 = parse_model(""" agraph (float[N] x) => (float[N] z) { two = Constant () four = Add(two, two) z = Mul(x, x) - }""" - ) - m2 = parse_model( - """ + }""") + m2 = parse_model(""" agraph (float[N] x) => (float[N] z) { two = Constant () z = Mul(x, x) - }""" - ) + }""") res1, res2, align, dc = compare_onnx_execution(m1, m2) text = dc.to_str(res1, res2, align) self.assertIn("CAAA Constant", text) self.assertEqual(len(align), 5) def test_compare_execution_discrepancies(self): - m1 = parse_model( - """ + m1 = parse_model(""" agraph (float[N] x) => (float[N] z) { two = Constant () four = Add(two, two) z = Mul(x, x) - }""" - ) - m2 = parse_model( - """ + }""") + m2 = parse_model(""" agraph (float[N] x) => (float[N] z) { two = Constant () z = Mul(x, x) - }""" - ) + }""") res1, res2, align, dc = compare_onnx_execution(m1, m2, keep_tensor=True) text = dc.to_str(res1, res2, align) print(text) diff --git a/_unittests/ut_translate_api/test_translate.py b/_unittests/ut_translate_api/test_translate.py index d80f5e3..45dc896 100644 --- a/_unittests/ut_translate_api/test_translate.py +++ b/_unittests/ut_translate_api/test_translate.py @@ -33,8 +33,7 @@ def test_exp(self): self.assertEqualArray(np.exp(a), got) code = translate(onx) - expected = dedent( - """ + expected = dedent(""" ( start(opset=19) .vin('X', elem_type=onnx.TensorProto.FLOAT) @@ -44,8 +43,7 @@ def test_exp(self): .bring('Y') .vout(elem_type=onnx.TensorProto.FLOAT) .to_onnx() - )""" - ).strip("\n") + )""").strip("\n") self.assertEqual(expected, code) onx2 = ( @@ -81,8 +79,7 @@ def test_transpose(self): self.assertEqualArray(a.reshape((-1, 1)).T, got) code = translate(onx) - expected = dedent( - """ + expected = dedent(""" ( start(opset=19) .cst(np.array([-1, 1], dtype=np.int64)) @@ -97,8 +94,7 @@ def test_transpose(self): .bring('Y') .vout(elem_type=onnx.TensorProto.FLOAT) .to_onnx() - )""" - ).strip("\n") + )""").strip("\n") self.assertEqual(expected, code) def test_topk_reverse(self): @@ -121,8 +117,7 @@ def test_topk_reverse(self): self.assertEqualArray(np.array([[0, 1], [3, 2]], dtype=np.int64), got[1]) code = translate(onx) - expected = dedent( - """ + expected = dedent(""" ( start(opset=19) .vin('X', elem_type=onnx.TensorProto.FLOAT) @@ -135,8 +130,7 @@ def test_topk_reverse(self): .bring('Indices') .vout(elem_type=onnx.TensorProto.FLOAT) .to_onnx() - )""" - ).strip("\n") + )""").strip("\n") self.assertEqual(expected, code) def test_export_if(self): @@ -173,8 +167,7 @@ def test_export_if(self): "g().cst(np.array([1], dtype=np.int64)).rename('Z')." "bring('Z').vout(elem_type=onnx.TensorProto.FLOAT)" ) - expected = dedent( - f""" + expected = dedent(f""" ( start(opset=19) .cst(np.array([0.0], dtype=np.float32)) @@ -192,8 +185,7 @@ def test_export_if(self): .bring('W') .vout(elem_type=onnx.TensorProto.FLOAT) .to_onnx() - )""" - ).strip("\n") + )""").strip("\n") self.maxDiff = None self.assertEqual(expected, code) @@ -209,8 +201,7 @@ def test_aionnxml(self): .to_onnx() ) code = translate(onx) - expected = dedent( - """ + expected = dedent(""" ( start(opset=19, opsets={'ai.onnx.ml': 3}) .cst(np.array([-1, 1], dtype=np.int64)) @@ -225,8 +216,7 @@ def test_aionnxml(self): .bring('Y') .vout(elem_type=onnx.TensorProto.FLOAT) .to_onnx() - )""" - ).strip("\n") + )""").strip("\n") self.maxDiff = None self.assertEqual(expected, code) diff --git a/_unittests/ut_translate_api/test_translate_builder.py b/_unittests/ut_translate_api/test_translate_builder.py index 184708b..e2a9235 100644 --- a/_unittests/ut_translate_api/test_translate_builder.py +++ b/_unittests/ut_translate_api/test_translate_builder.py @@ -12,7 +12,6 @@ from onnx_array_api.translate_api import translate, Translater from onnx_array_api.translate_api.builder_emitter import BuilderEmitter - OPSET_API = min(19, onnx_opset_version() - 1) @@ -30,9 +29,7 @@ def test_exp(self): self.assertEqualArray(np.exp(a), got) code = translate(onx, api="builder") - expected = ( - dedent( - """ + expected = dedent(""" def light_api( op: "GraphBuilder", X: "FLOAT[]", @@ -46,11 +43,7 @@ def light_api( light_api(g.op, "X") g.make_tensor_output("Y", onnx.TensorProto.FLOAT, ()__SUFFIX__) model = g.to_onnx() - """ - ) - .strip("\n") - .replace("__SUFFIX__", ", is_dimension=False, indexed=False") - ) + """).strip("\n").replace("__SUFFIX__", ", is_dimension=False, indexed=False") self.assertEqual(expected, code.strip("\n")) def light_api( @@ -86,8 +79,7 @@ def test_zdoc(self): ) code = translate(onx, api="builder") expected = ( - dedent( - """ + dedent(""" def light_api( op: "GraphBuilder", X: "FLOAT[]", @@ -103,8 +95,7 @@ def light_api( light_api(g.op, "X") g.make_tensor_output("Y", onnx.TensorProto.FLOAT, ()__SUFFIX__) model = g.to_onnx() - """ - ) + """) .strip("\n") .replace("__SUFFIX__", ", is_dimension=False, indexed=False") ) @@ -141,9 +132,7 @@ def test_exp_f(self): tr = Translater(onx, emitter=BuilderEmitter("mm")) code = tr.export(as_str=True) - expected = ( - dedent( - """ + expected = dedent(""" def light_api( op: "GraphBuilder", X: "FLOAT[]", @@ -163,11 +152,7 @@ def mm() -> "ModelProto": model = mm() - """ - ) - .strip("\n") - .replace("__SUFFIX__", ", is_dimension=False, indexed=False") - ) + """).strip("\n").replace("__SUFFIX__", ", is_dimension=False, indexed=False") self.assertEqual(expected, code.strip("\n")) def light_api( @@ -232,9 +217,7 @@ def test_local_function(self): tr = Translater(onnx_model, emitter=BuilderEmitter("mm")) code = tr.export(as_str=True) - expected = ( - dedent( - """ + expected = dedent(""" def example( op: "GraphBuilder", X: "FLOAT[, ]", @@ -273,11 +256,7 @@ def mm() -> "ModelProto": model = mm() - """ - ) - .strip("\n") - .replace("__SUFFIX__", ", is_dimension=False, indexed=False") - ) + """).strip("\n").replace("__SUFFIX__", ", is_dimension=False, indexed=False") self.assertEqual(expected, code.strip("\n")) diff --git a/_unittests/ut_translate_api/test_translate_classic.py b/_unittests/ut_translate_api/test_translate_classic.py index c0d7954..8b152cb 100644 --- a/_unittests/ut_translate_api/test_translate_classic.py +++ b/_unittests/ut_translate_api/test_translate_classic.py @@ -51,8 +51,7 @@ def test_exp(self): code = translate(onx, api="onnx") - expected = dedent( - """ + expected = dedent(""" opset_imports = [ oh.make_opsetid('', 19), ] @@ -83,8 +82,7 @@ def test_exp(self): graph, functions=functions, opset_imports=opset_imports - )""" - ).strip("\n") + )""").strip("\n") self.maxDiff = None self.assertEqual(expected, code) @@ -121,8 +119,7 @@ def test_transpose(self): self.assertEqualArray(a.reshape((-1, 1)).T, got) code = translate(onx, api="onnx") - expected = dedent( - """ + expected = dedent(""" opset_imports = [ oh.make_opsetid('', 19), ] @@ -167,8 +164,7 @@ def test_transpose(self): graph, functions=functions, opset_imports=opset_imports - )""" - ).strip("\n") + )""").strip("\n") self.maxDiff = None self.assertEqual(expected, code) @@ -190,8 +186,7 @@ def test_transpose_short(self): self.assertEqualArray(a.reshape((-1, 1)).T, got) code = translate(onx, api="onnx-short") - expected = dedent( - """ + expected = dedent(""" opset_imports = [ oh.make_opsetid('', 19), ] @@ -236,8 +231,7 @@ def test_transpose_short(self): graph, functions=functions, opset_imports=opset_imports - )""" - ).strip("\n") + )""").strip("\n") self.maxDiff = None self.assertEqual(expected, code) @@ -261,8 +255,7 @@ def test_topk_reverse(self): self.assertEqualArray(np.array([[0, 1], [3, 2]], dtype=np.int64), got[1]) code = translate(onx, api="onnx") - expected = dedent( - """ + expected = dedent(""" opset_imports = [ oh.make_opsetid('', 19), ] @@ -298,8 +291,7 @@ def test_topk_reverse(self): graph, functions=functions, opset_imports=opset_imports - )""" - ).strip("\n") + )""").strip("\n") self.maxDiff = None self.assertEqual(expected, code) @@ -329,8 +321,7 @@ def test_aionnxml(self): .to_onnx() ) code = translate(onx, api="onnx") - expected = dedent( - """ + expected = dedent(""" opset_imports = [ oh.make_opsetid('', 19), oh.make_opsetid('ai.onnx.ml', 3), @@ -377,8 +368,7 @@ def test_aionnxml(self): graph, functions=functions, opset_imports=opset_imports - )""" - ).strip("\n") + )""").strip("\n") self.maxDiff = None self.assertEqual(expected, code) diff --git a/onnx_array_api/_command_lines_parser.py b/onnx_array_api/_command_lines_parser.py index d1eac62..6ba70e7 100644 --- a/onnx_array_api/_command_lines_parser.py +++ b/onnx_array_api/_command_lines_parser.py @@ -15,16 +15,14 @@ def get_main_parser() -> ArgumentParser: parser.add_argument( "cmd", choices=["translate", "compare", "replace"], - help=dedent( - """ + help=dedent(""" Selects a command. 'translate' exports an onnx graph into a piece of code replicating it, 'compare' compares the execution of two onnx models, 'replace' replaces constant and initliazers by ConstantOfShape to make the model lighter - """ - ), + """), ) return parser @@ -32,12 +30,10 @@ def get_main_parser() -> ArgumentParser: def get_parser_translate() -> ArgumentParser: parser = ArgumentParser( prog="translate", - description=dedent( - """ + description=dedent(""" Translates an onnx model into a piece of code to replicate it. The result is printed on the standard output. - """ - ), + """), epilog="This is mostly used to write unit tests without adding " "an onnx file to the repository.", ) @@ -71,11 +67,9 @@ def _cmd_translate(argv: List[Any]): def get_parser_compare() -> ArgumentParser: parser = ArgumentParser( prog="compare", - description=dedent( - """ + description=dedent(""" Compares the execution of two onnx models. - """ - ), + """), epilog="This is used when two models are different but " "should produce the same results.", ) @@ -148,12 +142,10 @@ def _cmd_compare(argv: List[Any]): def get_parser_replace() -> ArgumentParser: parser = ArgumentParser( prog="translate", - description=dedent( - """ + description=dedent(""" Replaces constants and initializes by ConstOfShape or any other nodes to make the model smaller. - """ - ), + """), epilog="This is mostly used to write unit tests without adding " "a big file to the repository.", ) diff --git a/onnx_array_api/array_api/__init__.py b/onnx_array_api/array_api/__init__.py index 9b67b4b..1a3b0ec 100644 --- a/onnx_array_api/array_api/__init__.py +++ b/onnx_array_api/array_api/__init__.py @@ -6,7 +6,6 @@ from ..npx.npx_types import DType from ..npx import npx_functions - supported_functions = [ "abs", "absolute", diff --git a/onnx_array_api/array_api/_onnx_common.py b/onnx_array_api/array_api/_onnx_common.py index 3e92472..f62ad28 100644 --- a/onnx_array_api/array_api/_onnx_common.py +++ b/onnx_array_api/array_api/_onnx_common.py @@ -30,7 +30,6 @@ zeros as generic_zeros, ) - Array = type(array_api_strict.ones((1,))) if array_api_strict else None diff --git a/onnx_array_api/npx/npx_graph_builder.py b/onnx_array_api/npx/npx_graph_builder.py index 91034f7..90f98af 100644 --- a/onnx_array_api/npx/npx_graph_builder.py +++ b/onnx_array_api/npx/npx_graph_builder.py @@ -773,7 +773,7 @@ def to_onnx( packed = self._function_to_onnx( var.onnx_op[1], len(var.inputs), var.n_var_outputs ) - (onx_fn, in_types, out_types, att_types) = packed + onx_fn, in_types, out_types, att_types = packed domop = (onx_fn.domain, onx_fn.name) for inp, index, dt in zip(var.inputs, var.input_indices, in_types): diff --git a/onnx_array_api/reference/evaluator.py b/onnx_array_api/reference/evaluator.py index 89b5a84..e326f12 100644 --- a/onnx_array_api/reference/evaluator.py +++ b/onnx_array_api/reference/evaluator.py @@ -12,7 +12,6 @@ from .ops.op_quick_gelu import QuickGelu from .ops.op_scatter_elements import ScatterElements - logger = getLogger("onnx-array-api-eval") diff --git a/onnx_array_api/translate_api/__init__.py b/onnx_array_api/translate_api/__init__.py index fd60de1..64322d3 100644 --- a/onnx_array_api/translate_api/__init__.py +++ b/onnx_array_api/translate_api/__init__.py @@ -8,34 +8,28 @@ def translate_header(api: str = "light"): """Returns the necessary header for each api.""" if api == "light": - return textwrap.dedent( - """ + return textwrap.dedent(""" import numpy as np import ml_dtypes from onnx_array_api.light_api import start from onnx_array_api.translate_api import translate - """ - ) + """) if api == "onnx": - return textwrap.dedent( - """ + return textwrap.dedent(""" import numpy as np import ml_dtypes import onnx import onnx.helper as oh import onnx.numpy_helper as onh from onnx_array_api.translate_api.make_helper import make_node_extended - """ - ) + """) if api == "builder": - return textwrap.dedent( - """ + return textwrap.dedent(""" import numpy as np import ml_dtypes import onnx from onnx_array_api.graph_api import GraphBuilder - """ - ) + """) raise ValueError(f"Unexpected value {api!r} for api.") diff --git a/onnx_array_api/validation/diff.py b/onnx_array_api/validation/diff.py index aa73078..c17f570 100644 --- a/onnx_array_api/validation/diff.py +++ b/onnx_array_api/validation/diff.py @@ -8,8 +8,7 @@ def _get_diff_template(): import jinja2 - tpl = textwrap.dedent( - """ + tpl = textwrap.dedent("""
- """ - ) + """) path = os.path.abspath(os.path.dirname(__file__)) path = path.replace("\\", "/") path = f"file://{path}"