Skip to content

Commit 2ab6859

Browse files
committed
fix plugs
1 parent 51b8db5 commit 2ab6859

File tree

3 files changed

+11
-3
lines changed

3 files changed

+11
-3
lines changed

CHANGELOGS.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ Change Logs
44
0.8.4
55
+++++
66

7+
* :pr:`336`: implements versioned onnx plugs
8+
79
0.8.3
810
+++++
911

onnx_diagnostic/export/api.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ def find_method(self, name: Any):
149149

150150
if exporter in ("dynamo", "onnx-dynamo"):
151151
import os
152+
from ..helpers import flatten_object
152153
import onnxscript.rewriter.ort_fusions as ort_fusions
153154

154155
assert (
@@ -180,7 +181,12 @@ def find_method(self, name: Any):
180181
import onnx_ir as ir
181182
import onnx_ir.passes.common as common_passes
182183

183-
irfunctions = [ir.from_proto(plug.function_proto) for plug in onnx_plugs]
184+
irfunctions = [
185+
ir.from_proto(
186+
plug.get_function_proto(*flatten_object((args, kwargs), drop_keys=True))
187+
)
188+
for plug in onnx_plugs
189+
]
184190
for func in irfunctions:
185191
epo.model.functions[func.identifier()] = func
186192
if inline:

onnx_diagnostic/export/onnx_plug.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,10 +210,10 @@ def _check_protos(self, params):
210210
# multiple protos
211211
assert all(
212212
self.n_inputs == len(v.input) for v in self._function_proto_versioned.values()
213-
), f"Output mismatch n_inputs={self.n_inputs} but one verion is wrong"
213+
), f"Output mismatch n_inputs={self.n_inputs} but one version is wrong"
214214
assert all(
215215
self.n_outputs == len(v.output) for v in self._function_proto_versioned.values()
216-
), f"Output mismatch n_outputs={self.n_outputs} but one verion is wrong"
216+
), f"Output mismatch n_outputs={self.n_outputs} but one version is wrong"
217217
assert all(
218218
v.domain == self.domain for v in self._function_proto_versioned.values()
219219
), f"Function domain must be {self.domain!r} but it is different in one version"

0 commit comments

Comments
 (0)