Commit 3027ebe

address reviews

1 parent 159d6b0 commit 3027ebe
File tree

2 files changed: +34 −12 lines changed

beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py

Lines changed: 12 additions & 0 deletions

@@ -23,6 +23,18 @@
 # Conditional logic cannot be exported into ONNX unless it is refactored
 # to use :func:`torch.cond`. Let's start with a simple model
 # implementing a test.
+#
+# What you will learn:
+#
+# - How to refactor the model to use :func:`torch.cond` for exporting.
+# - How to export a model with control flow logic to ONNX.
+# - How to optimize the exported model using the ONNX optimizer.
+#
+# Prerequisites
+# ~~~~~~~~~~~~~
+#
+# * ``torch >= 2.6``
+

 import torch

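The refactor these additions point to can be sketched without torch at all: :func:`torch.cond` takes a predicate, two branch functions, and a tuple of operands. The plain-Python `cond` helper below is a hypothetical stand-in for illustration only, not the torch API; it shows the functional shape a data-dependent `if` must take before export.

```python
# Hedged sketch: the functional shape torch.cond expects, in plain Python.
# `cond` here is a hypothetical stand-in, not the torch API.
def cond(pred, true_fn, false_fn, operands):
    # torch.cond traces both branches for export; this stand-in just picks one.
    return true_fn(*operands) if pred else false_fn(*operands)

def forward(xs):
    # Data-dependent control flow expressed as a predicate plus two branch
    # functions, instead of a Python `if` statement the exporter cannot trace.
    return cond(
        sum(xs) > 0,
        lambda xs: [v + 1 for v in xs],
        lambda xs: [v - 1 for v in xs],
        (xs,),
    )

print(forward([1.0, 2.0]))   # positive sum: both elements incremented
print(forward([-3.0, 1.0]))  # non-positive sum: both elements decremented
```

The key constraint mirrored here is that both branches must be callables receiving the same operands, which is what makes the control flow representable as a single ONNX `If` node.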
beginner_source/onnx/onnx_registry_tutorial.py

Lines changed: 22 additions & 12 deletions

@@ -25,6 +25,24 @@
 # * Using custom ONNX operators
 # * Supporting a custom PyTorch operator
 #
+# What you will learn:
+#
+# - How to override or add support for PyTorch operators in ONNX.
+# - How to integrate custom ONNX operators for specialized runtimes.
+# - How to implement and translate custom PyTorch operators to ONNX.
+#
+# Prerequisites
+# ~~~~~~~~~~~~~
+#
+# Before starting this tutorial, make sure you have completed the following prerequisites:
+#
+# * ``torch >= 2.6``
+# * The target PyTorch operator.
+# * Completed the
+#   `ONNX Script tutorial <https://github.com/microsoft/onnxscript/blob/main/docs/tutorial/index.md>`_
+#   before proceeding.
+# * The implementation of the operator using `ONNX Script <https://github.com/microsoft/onnxscript>`__.
+#
 # Overriding the implementation of an existing PyTorch operator
 # -------------------------------------------------------------
 #
@@ -33,8 +51,8 @@
 # unsupported PyTorch operators to the ONNX Registry.
 #
 # .. note::
-#    The steps to implement unsupported PyTorch operators are the same to replace the implementation of an existing
-#    PyTorch operator with a custom implementation.
+#    The steps to implement unsupported PyTorch operators are the same as those for replacing the implementation of an existing
+#    PyTorch operator with a custom one.
 # Because we don't actually have an unsupported PyTorch operator to use in this tutorial, we are going to leverage
 # this and replace the implementation of ``torch.ops.aten.add.Tensor`` with a custom implementation the same way we would
 # if the operator was not implemented by the ONNX exporter.
@@ -49,14 +67,6 @@
 # The error message indicates that the unsupported PyTorch operator is ``torch.ops.aten.add.Tensor``.
 # The operator is of type ``<class 'torch._ops.OpOverload'>``, and this operator is what we will use as the
 # target to register our custom implementation.
-#
-# To add support for an unsupported PyTorch operator or to replace the implementation for an existing one, we need:
-#
-# * The target PyTorch operator.
-# * The implementation of the operator using `ONNX Script <https://github.com/microsoft/onnxscript>`__.
-#   ONNX Script is a prerequisite for this tutorial. Please make sure you have read the
-#   `ONNX Script tutorial <https://github.com/microsoft/onnxscript/blob/main/docs/tutorial/index.md>`_
-#   before proceeding.

 import torch
 import onnxscript
@@ -73,7 +83,7 @@ def forward(self, input_x, input_y):

 # NOTE: The function signature (including param names) must match the signature of the unsupported PyTorch operator.
 # https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
-# NOTE: All attributes must be annotated with type hints.
+# All attributes must be annotated with type hints.
 def custom_aten_add(self, other, alpha: float = 1.0):
     if alpha != 1.0:
         alpha = op.CastLike(alpha, other)
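The registration pattern behind ``custom_aten_add`` can be sketched in plain Python: a mapping from an operator target (here a plain string) to its translation function, where a later registration for the same target overrides the default. This dict-based `registry` and `register` decorator are illustrations only, not the torch.onnx registry API.

```python
# Hedged sketch of the registry idea: last registration for a target wins,
# which is how a custom implementation replaces an existing one.
registry = {}

def register(target):
    def decorator(fn):
        registry[target] = fn  # later registrations override earlier ones
        return fn
    return decorator

@register("aten::add.Tensor")
def default_add(x, y, alpha=1.0):
    return x + alpha * y

@register("aten::add.Tensor")  # overrides default_add, as custom_aten_add does
def custom_add(x, y, alpha=1.0):
    # A real custom implementation would emit ONNX Script ops
    # (e.g. CastLike followed by Add) instead of plain arithmetic.
    return x + alpha * y

print(registry["aten::add.Tensor"] is custom_add)  # the override is what gets looked up
```

When the exporter encounters the target operator, it looks up the registered translation, so whichever implementation was registered last is the one that lands in the exported graph.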
@@ -118,7 +128,7 @@ def custom_aten_add(self, other, alpha: float = 1.0):
 # ---------------------------
 #
 # In this case, we create a model with standard PyTorch operators, but the runtime
-# (e.g. Microsoft's ONNX Runtime) can provide a custom implementation for that kernel, effectively replacing the
+# (such as Microsoft's ONNX Runtime) can provide a custom implementation for that kernel, effectively replacing the
 # existing implementation.
 #
 # In the following example, we use the ``com.microsoft.Gelu`` operator provided by ONNX Runtime,
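For reference, the activation that a Gelu kernel computes can be sanity-checked against its textbook erf-based definition using only the standard library (a checking sketch, not the ONNX Runtime kernel itself):

```python
import math

def gelu(x: float) -> float:
    # Exact (erf-based) GELU: 0.5 * x * (1 + erf(x / sqrt(2))).
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

print(gelu(0.0))  # GELU is zero at the origin
print(gelu(1.0))  # equals x * Phi(x), the standard normal CDF scaled by x
```

A reference like this is handy when verifying that a runtime-provided custom kernel produces numerically equivalent outputs to the operator it replaces.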
