We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent b5aa71b commit 7747de3Copy full SHA for 7747de3
onnx_diagnostic/export/onnx_plug.py
@@ -204,11 +204,11 @@ def torch_op(self) -> Callable:
204
"Returns ``torch.ops.onny_plug.<name>"
205
return getattr(getattr(torch.ops, self.domain), self.name).default
206
207
- def __call__(self, *args):
+ def __call__(self, *args, **kwargs):
208
"""Calls eager_fn or shape_fn if the model is being exported."""
209
if is_exporting():
210
return self.torch_op(*args)
211
- return self.eager_fn(*args)
+ return self.eager_fn(*args, **kwargs)
212
213
def _register(self):
214
"""Registers the custom op."""
0 commit comments