
Commit 2bbf6bd

Update torch_export_tutorial.py
1 parent 08b1d19 commit 2bbf6bd

File tree

1 file changed: +10 -11 lines changed


intermediate_source/torch_export_tutorial.py

Lines changed: 10 additions & 11 deletions
@@ -357,7 +357,7 @@ def forward(
 ######################################################################
 # Before we look at the program that's produced, let's understand what specifying ``dynamic_shapes`` entails,
 # and how that interacts with export. For every input dimension where a ``Dim`` object is specified, a symbol is
-# allocated, taking on a range of ``[2, inf]`` (why not ``[0, inf]`` or [1, inf]``? we'll explain later in the
+# allocated, taking on a range of ``[2, inf]`` (why not ``[0, inf]`` or ``[1, inf]``? we'll explain later in the
 # 0/1 specialization section).
 #
 # Export then runs model tracing, looking at each operation that's performed by the model. Each individual operation can emit
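
For context on what the paragraph above describes, here is a minimal sketch of passing a ``Dim`` object through ``dynamic_shapes`` to ``torch.export.export``; the module and dimension names are illustrative assumptions, not the tutorial's exact code.

    import torch
    from torch.export import Dim, export

    class ScaleModule(torch.nn.Module):  # hypothetical example module
        def forward(self, x):
            return x * 2

    batch = Dim("batch")  # a symbol is allocated for this dimension
    # mark dim 0 of input ``x`` as dynamic; dim 1 stays static
    ep = export(ScaleModule(), (torch.randn(4, 8),), dynamic_shapes={"x": {0: batch}})
    print(ep)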
@@ -390,21 +390,18 @@ def forward(
 ######################################################################
 # Let's understand each of the operations and the emitted guards:
 #
-# - ``x0 = x + y``: This is an element-wise add with broadcasting, since ``x`` is a 1-d tensor and ``y`` a 2-d tensor.
-# ``x`` is broadcasted along the last dimension of ``y``, emitting the guard ``s2 == s4``.
-# - ``x1 = self.l(w)``: Calling ``nn.Linear()`` performs a matrix multiplication with model parameters. In export,
-# parameters, buffers, and constants are considered program state, which is considered static, and so this is
-# a matmul between a dynamic input (``w: [s0, s1]``), and a statically-shaped tensor. This emits the guard ``s1 == 5``.
+# - ``x0 = x + y``: This is an element-wise add with broadcasting, since ``x`` is a 1-d tensor and ``y`` a 2-d tensor. ``x`` is broadcasted along the last dimension of ``y``, emitting the guard ``s2 == s4``.
+# - ``x1 = self.l(w)``: Calling ``nn.Linear()`` performs a matrix multiplication with model parameters. In export, parameters, buffers, and constants are considered program state, which is considered static, and so this is a matmul between a dynamic input (``w: [s0, s1]``), and a statically-shaped tensor. This emits the guard ``s1 == 5``.
 # - ``x2 = x0.flatten()``: This call actually doesn't emit any guards! (at least none relevant to input shapes)
 # - ``x3 = x2 + z``: ``x2`` has shape ``[s3*s4]`` after flattening, and this element-wise add emits ``s3 * s4 == s5``.
 #
 # Writing all of these guards down and summarizing is almost like a mathematical proof, which is what the symbolic shapes
 # subsystem tries to do! In summary, we can conclude that the program must have the following input shapes to be valid:
-#
-# ``w: [s0, 5]``
-# ``x: [s2]``
-# ``y: [s3, s2]``
-# ``z: [s2*s3]``
+#
+# - ``w: [s0, 5]``
+# - ``x: [s2]``
+# - ``y: [s3, s2]``
+# - ``z: [s2*s3]``
 #
 # And when we do finally print out the exported program to see our result, those shapes are what we see annotated on the
 # corresponding inputs:
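
To make the four guards above concrete, here is a hedged reconstruction of the kind of forward pass being traced; the exact tutorial model may differ (in particular, the ``nn.Linear`` output size of 3 is an assumption).

    import torch

    class GuardExample(torch.nn.Module):  # illustrative stand-in for the traced model
        def __init__(self):
            super().__init__()
            self.l = torch.nn.Linear(5, 3)  # in_features=5 is what forces the guard s1 == 5

        def forward(self, w, x, y, z):
            x0 = x + y         # broadcasting add: emits s2 == s4 (x's only dim vs. y's last dim)
            x1 = self.l(w)     # matmul against static parameters: emits s1 == 5
            x2 = x0.flatten()  # emits no guards relevant to input shapes
            x3 = x2 + z        # x2 has shape [s3*s4], so this emits s3 * s4 == s5
            return x1, x3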
@@ -557,6 +554,7 @@ def forward(self, w, x, y, z):
 # This style of dynamic shapes allows the user to specify what symbols are allocated for input dimensions, min/max bounds on those symbols, and places restrictions on the
 # dynamic behavior of the ``ExportedProgram`` produced; ``ConstraintViolation`` errors will be raised if model tracing emits guards that conflict with the relations or static/dynamic
 # specifications given. For example, in the above specification, the following is asserted:
+#
 # - ``x.shape[0]`` is to have range ``[4, 256]``, and related to ``y.shape[0]`` by ``y.shape[0] == 2 * x.shape[0]``.
 # - ``x.shape[1]`` is static.
 # - ``y.shape[1]`` has range ``[2, 512]``, and is unrelated to any other dimension.
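
One way the assertions listed above could be written as a ``dynamic_shapes`` specification is sketched below; the symbol names are placeholders, and the derived-``Dim`` expression assumes the model takes ``x`` and ``y`` as its inputs.

    from torch.export import Dim

    dx = Dim("dx", min=4, max=256)    # x.shape[0] ranges over [4, 256]
    dy1 = Dim("dy1", min=2, max=512)  # y.shape[1] ranges over [2, 512], unrelated to dx

    dynamic_shapes = {
        "x": {0: dx},                 # x.shape[1] is left out, so it remains static
        "y": {0: 2 * dx, 1: dy1},     # y.shape[0] == 2 * x.shape[0] via a derived Dim
    }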
@@ -593,6 +591,7 @@ def forward(self, x, y):
 # The expectation with suggested fixes is that the user can interactively copy-paste the changes into their dynamic shapes specification, and successfully export afterwards.
 #
 # Lastly, there's couple nice-to-knows about the options for specification:
+#
 # - ``None`` is a good option for static behavior:
 #   - ``dynamic_shapes=None`` (default) exports with the entire model being static.
 #   - specifying ``None`` at an input-level exports with all tensor dimensions static, and alternatively is also required for non-tensor inputs.
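
As a short illustration of the two ``None`` behaviors described above, the sketch below uses a toy module whose inputs do not interact shape-wise; the module and input shapes are assumptions for demonstration only.

    import torch
    from torch.export import Dim, export

    class ToyModule(torch.nn.Module):  # hypothetical module with shape-independent inputs
        def forward(self, x, y):
            return x * 2, y + 1

    example_inputs = (torch.randn(4, 8), torch.randn(3, 5))

    # dynamic_shapes=None (the default): the entire program is exported static.
    ep_static = export(ToyModule(), example_inputs, dynamic_shapes=None)

    # None at the input level: every dimension of ``y`` is static, while dim 0 of ``x``
    # is still exported as dynamic.
    batch = Dim("batch")
    ep_mixed = export(ToyModule(), example_inputs, dynamic_shapes={"x": {0: batch}, "y": None})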
