# * Using custom ONNX operators
# * Supporting a custom PyTorch operator
#
+ # What you will learn:
+ #
+ # - How to override or add support for PyTorch operators in ONNX.
+ # - How to integrate custom ONNX operators for specialized runtimes.
+ # - How to implement and translate custom PyTorch operators to ONNX.
+ #
+ # Prerequisites
+ # ~~~~~~~~~~~~~
+ #
+ # Before starting this tutorial, make sure you have completed the following prerequisites:
+ #
+ # * ``torch >= 2.6``
+ # * The target PyTorch operator.
+ # * Completed the
+ #   `ONNX Script tutorial <https://github.com/microsoft/onnxscript/blob/main/docs/tutorial/index.md>`_
+ #   before proceeding.
+ # * The implementation of the operator using `ONNX Script <https://github.com/microsoft/onnxscript>`__.
+ #
# Overriding the implementation of an existing PyTorch operator
# -------------------------------------------------------------
#
# unsupported PyTorch operators to the ONNX Registry.
#
# .. note::
- #     The steps to implement unsupported PyTorch operators are the same to replace the implementation of an existing
- #     PyTorch operator with a custom implementation.
+ #     The steps to implement unsupported PyTorch operators are the same as those for replacing the implementation of an existing
+ #     PyTorch operator with a custom one.
# Because we don't actually have an unsupported PyTorch operator to use in this tutorial, we are going to leverage
# this and replace the implementation of ``torch.ops.aten.add.Tensor`` with a custom implementation the same way we would
# if the operator was not implemented by the ONNX exporter.
# The error message indicates that the unsupported PyTorch operator is ``torch.ops.aten.add.Tensor``.
# The operator is of type ``<class 'torch._ops.OpOverload'>``, and this operator is what we will use as the
# target to register our custom implementation.
- #
- # To add support for an unsupported PyTorch operator or to replace the implementation for an existing one, we need:
- #
- # * The target PyTorch operator.
- # * The implementation of the operator using `ONNX Script <https://github.com/microsoft/onnxscript>`__.
- #   ONNX Script is a prerequisite for this tutorial. Please make sure you have read the
- #   `ONNX Script tutorial <https://github.com/microsoft/onnxscript/blob/main/docs/tutorial/index.md>`_
- #   before proceeding.

import torch
import onnxscript
@@ -73,7 +83,7 @@ def forward(self, input_x, input_y):
# NOTE: The function signature (including param names) must match the signature of the unsupported PyTorch operator.
# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
- # NOTE: All attributes must be annotated with type hints.
+ # All attributes must be annotated with type hints.
def custom_aten_add(self, other, alpha: float = 1.0):
    if alpha != 1.0:
        alpha = op.CastLike(alpha, other)
@@ -118,7 +128,7 @@ def custom_aten_add(self, other, alpha: float = 1.0):
# ---------------------------
#
# In this case, we create a model with standard PyTorch operators, but the runtime
- # (e.g. Microsoft's ONNX Runtime) can provide a custom implementation for that kernel, effectively replacing the
+ # (such as Microsoft's ONNX Runtime) can provide a custom implementation for that kernel, effectively replacing the
# existing implementation.
#
# In the following example, we use the ``com.microsoft.Gelu`` operator provided by ONNX Runtime,