Skip to content

Commit 621e81e

Browse files
committed
fix
1 parent b4be4df commit 621e81e

File tree

2 files changed

+22
-7
lines changed

2 files changed

+22
-7
lines changed

_unittests/ut_export/test_onnx_plug.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import unittest
22
import onnx.helper as oh
33
import torch
4-
from onnx_diagnostic.ext_test_case import ExtTestCase, has_torch
4+
from onnx_diagnostic.ext_test_case import ExtTestCase, has_torch, hide_stdout, ignore_warnings
55
from onnx_diagnostic.export.onnx_plug import EagerDirectReplacementWithOnnx
66
from onnx_diagnostic.export.api import to_onnx
77

@@ -36,6 +36,8 @@ def make_function_proto():
3636
self.assertEqual(len(res.diffs), 1)
3737
self.assertEqual(res.diffs[0]["abs"], 0)
3838

39+
@hide_stdout()
40+
@ignore_warnings(FutureWarning)
3941
def test_onnx_plug_export(self):
4042
def _test_customsub(x, y):
4143
return x - y
@@ -61,7 +63,7 @@ def forward(self, x):
6163

6264
replacements = [
6365
EagerDirectReplacementWithOnnx(
64-
_test_customsub, _test_customsub_shape, make_function_proto(), 2, 1
66+
_test_customsub, _test_customsub_shape, make_function_proto(), 2, 1, verbose=1
6567
)
6668
]
6769

onnx_diagnostic/export/onnx_plug.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ class EagerDirectReplacementWithOnnx:
5050
only tensors must be counted
5151
:param name: the name of the custom op, the function name if not specified
5252
:param kwargs: constants
53+
:param verbose: verbose level
5354
5455
Here is an example:
5556
@@ -143,6 +144,7 @@ def __init__(
143144
n_outputs: Optional[int] = None,
144145
name: Optional[str] = None,
145146
kwargs: Optional[Dict[str, Union[int, float]]] = None,
147+
verbose: int = 0,
146148
):
147149
assert isinstance(
148150
function_proto, onnx.FunctionProto
@@ -154,9 +156,13 @@ def __init__(
154156
self.function_proto = function_proto
155157
self.n_inputs = n_inputs
156158
self.n_outputs = n_outputs
157-
self.name = name or eager_fn.__qualname__.replace("<locals>", "L").replace(
158-
"<lambda>", "l"
159-
).replace(".", "_")
159+
self.name = name or (
160+
eager_fn.__name__
161+
if "<" not in eager_fn.__name__
162+
else eager_fn.__qualname__.replace("<locals>", "L")
163+
.replace("<lambda>", "l")
164+
.replace(".", "_")
165+
)
160166
self.kwargs = kwargs
161167
assert kwargs is None or all(isinstance(v, (int, float)) for v in kwargs.values()), (
162168
f"Only int or floats are allowed for kwargs={kwargs}, one of them "
@@ -179,7 +185,8 @@ def __init__(
179185
function_proto.domain == self.domain
180186
), f"Function domain must be {self.domain!r} but it is {function_proto.domain!r}"
181187
self.arg_names = params
182-
self.custom_op = self._registers()
188+
self.verbose = verbose
189+
self.custom_op = self._register()
183190

184191
@property
185192
def domain(self) -> str:
@@ -202,12 +209,18 @@ def __call__(self, *args):
202209
return self.torch_op(*args)
203210
return self.eager_fn(*args)
204211

205-
def _registers(self):
212+
def _register(self):
206213
"""Registers the custom op."""
207214
inputs = ", ".join([f"Tensor {p}" for p in self.arg_names])
208215
schema = f"({inputs}) -> Tensor"
209216
if self.n_outputs > 1:
210217
schema += "[]"
218+
if self.verbose:
219+
print(
220+
f"[EagerDirectReplacementWithOnnx._register] "
221+
f"'torch.ops.{self.domain}.{self.name}"
222+
)
223+
print(f"[EagerDirectReplacementWithOnnx._register] schema={schema}")
211224
custom_def = torch.library.CustomOpDef(self.domain, self.name, schema, self.eager_fn)
212225
custom_def.register_kernel(None)(self.eager_fn)
213226
custom_def._abstract_fn = self.shape_fn

0 commit comments

Comments
 (0)