Commit 51fbbac

export tutorial

1 parent 6dc1122 commit 51fbbac

File tree

1 file changed (+80, −231 lines)

intermediate_source/torch_export_tutorial.py

Lines changed: 80 additions & 231 deletions
@@ -510,246 +510,95 @@ def forward(self, w, x, y, z):
 # alive. For that, refer to rewriting your model code following the ``Control Flow Ops`` section above.
 #
 # Since we're talking about guards and specializations, it's a good time to talk about the 0/1 specialization issue we brought up earlier.
-#
-
-
-
-
-
-
-# Ops can have different specializations/behaviors for different tensor shapes, so by default,
-# ``torch.export`` requires inputs to ``ExportedProgram`` to have the same shape as the respective
-# example inputs given to the initial ``torch.export.export()`` call.
-# If we try to run the ``ExportedProgram`` in the example below with a tensor
-# with a different shape, we get an error:
-
-class MyModule2(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.lin = torch.nn.Linear(100, 10)
-
-    def forward(self, x, y):
-        return torch.nn.functional.relu(self.lin(x + y), inplace=True)
-
-mod2 = MyModule2()
-exported_mod2 = export(mod2, (torch.randn(8, 100), torch.randn(8, 100)))
-
-try:
-    exported_mod2.module()(torch.randn(10, 100), torch.randn(10, 100))
-except Exception:
-    tb.print_exc()
-
-######################################################################
-# We can relax this constraint using the ``dynamic_shapes`` argument of
-# ``torch.export.export()``, which allows us to specify, using ``torch.export.Dim``
-# (`documentation <https://pytorch.org/docs/main/export.html#torch.export.Dim>`__),
-# which dimensions of the input tensors are dynamic.
-#
-# For each tensor argument of the input callable, we can specify a mapping from the dimension
-# to a ``torch.export.Dim``.
-# A ``torch.export.Dim`` is essentially a named symbolic integer with optional
-# minimum and maximum bounds.
-#
-# Then, the format of ``torch.export.export()``'s ``dynamic_shapes`` argument is a mapping
-# from the input callable's tensor argument names, to dimension --> dim mappings as described above.
-# If there is no ``torch.export.Dim`` given to a tensor argument's dimension, then that dimension is
-# assumed to be static.
-#
-# The first argument of ``torch.export.Dim`` is the name for the symbolic integer, used for debugging.
-# Then we can specify an optional minimum and maximum bound (inclusive). Below, we show a usage example.
-#
-# In the example below, our input
-# ``inp1`` has an unconstrained first dimension, but the size of the second
-# dimension must be in the interval [4, 18].
-
-from torch.export import Dim
-
-inp1 = torch.randn(10, 10, 2)
-
-class DynamicShapesExample1(torch.nn.Module):
-    def forward(self, x):
-        x = x[:, 2:]
-        return torch.relu(x)
-
-inp1_dim0 = Dim("inp1_dim0")
-inp1_dim1 = Dim("inp1_dim1", min=4, max=18)
-dynamic_shapes1 = {
-    "x": {0: inp1_dim0, 1: inp1_dim1},
-}
-
-exported_dynamic_shapes_example1 = export(DynamicShapesExample1(), (inp1,), dynamic_shapes=dynamic_shapes1)
-
-print(exported_dynamic_shapes_example1.module()(torch.randn(5, 5, 2)))
-
-try:
-    exported_dynamic_shapes_example1.module()(torch.randn(8, 1, 2))
-except Exception:
-    tb.print_exc()
-
-try:
-    exported_dynamic_shapes_example1.module()(torch.randn(8, 20, 2))
-except Exception:
-    tb.print_exc()
-
-try:
-    exported_dynamic_shapes_example1.module()(torch.randn(8, 8, 3))
-except Exception:
-    tb.print_exc()
-
-######################################################################
-# Note that if our example inputs to ``torch.export`` do not satisfy the constraints
-# given by ``dynamic_shapes``, then we get an error.
-
-inp1_dim1_bad = Dim("inp1_dim1_bad", min=11, max=18)
-dynamic_shapes1_bad = {
-    "x": {0: inp1_dim0, 1: inp1_dim1_bad},
+# The bottom line is that export will specialize on sample input dimensions with value 0 or 1, because these shapes have trace-time properties that
+# don't generalize to other shapes. For example, size 1 tensors can broadcast while other sizes fail, and size 0 tensors often take dedicated empty-tensor code paths.
+# This just means that you should specify 0/1 sample inputs when you'd like your program to hardcode them, and non-0/1 sample inputs when dynamic behavior is
+# desirable. See what happens at runtime when we export this linear layer:
+
+ep = export(
+    torch.nn.Linear(4, 3),
+    (torch.randn(1, 4),),
+    dynamic_shapes={
+        "input": (Dim.AUTO, Dim.STATIC),
+    },
+)
+ep.module()(torch.randn(2, 4))
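+# This call fails: because the sample batch size was 1, ``Dim.AUTO`` specialized dimension 0 to 1
+# (size-1 inputs carry broadcasting semantics that don't generalize). As a minimal contrasting sketch
+# (``ep_dynamic`` is an illustrative name, not part of the original tutorial), re-exporting with a
+# non-0/1 sample batch size lets dimension 0 be discovered as dynamic:
+
+ep_dynamic = export(
+    torch.nn.Linear(4, 3),
+    (torch.randn(2, 4),),  # sample batch size 2, so no 0/1 specialization
+    dynamic_shapes={
+        "input": (Dim.AUTO, Dim.STATIC),
+    },
+)
+ep_dynamic.module()(torch.randn(8, 4))  # other batch sizes now run fine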
+
+# So far we've only been talking about three ways to specify dynamic shapes: ``Dim.AUTO``, ``Dim.DYNAMIC``, and ``Dim.STATIC``. The attraction of these is the
+# low-friction user experience; all the guards emitted during model tracing are adhered to, and dynamic behavior like min/max ranges, relations, and static/dynamic
+# dimensions are automatically figured out underneath export. The dynamic shapes subsystem essentially acts as a "discovery" process, summarizing these guards
+# and presenting what export believes is the overall dynamic behavior of the program. The drawback of this design appears once the user has stronger expectations or
+# beliefs about the dynamic behavior of these models - maybe there is a strong desire for dynamism, and specializations on particular dimensions are to be avoided at
+# all costs; or maybe we just want to catch changes in dynamic behavior caused by changes to the original model code, or to underlying decompositions or meta-kernels.
+# With the automatic options, these changes won't be detected, and the ``export()`` call will most likely succeed, unless tests are in place that check the
+# resulting ExportedProgram representation.
+#
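+# To make the "discovery" behavior concrete, here is a minimal sketch (the ``M`` module below is
+# illustrative, not from the original tutorial). Tracing ``x + y`` guards the two inputs to have equal
+# shapes, and ``Dim.AUTO`` quietly summarizes that guard as a relation instead of raising an error:
+
+class M(torch.nn.Module):
+    def forward(self, x, y):
+        return x + y  # tracing guards x.shape == y.shape
+
+ep_auto = export(
+    M(),
+    (torch.randn(6, 4), torch.randn(6, 4)),
+    dynamic_shapes={
+        "x": (Dim.AUTO, Dim.AUTO),
+        "y": (Dim.AUTO, Dim.AUTO),
+    },
+)
+ep_auto.module()(torch.randn(8, 5), torch.randn(8, 5))  # discovered: dynamic, but equal
+# ep_auto.module()(torch.randn(8, 5), torch.randn(7, 5))  # would fail the discovered relation
+#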
+# For such cases, our stance is to recommend the "traditional" way of specifying dynamic shapes, which longer-term users of export might be familiar with: named ``Dims``:
+
+dx = Dim("dx", min=4, max=256)
+dh = Dim("dh", max=512)
+dynamic_shapes = {
+    "x": (dx, None),
+    "y": (2 * dx, dh),
 }

-try:
-    export(DynamicShapesExample1(), (inp1,), dynamic_shapes=dynamic_shapes1_bad)
-except Exception:
-    tb.print_exc()
-
-######################################################################
-# We can enforce equalities between dimensions of different tensors
-# by using the same ``torch.export.Dim`` object, for example, in matrix multiplication:
-
-inp2 = torch.randn(4, 8)
-inp3 = torch.randn(8, 2)
-
-class DynamicShapesExample2(torch.nn.Module):
-    def forward(self, x, y):
-        return x @ y
-
-inp2_dim0 = Dim("inp2_dim0")
-inner_dim = Dim("inner_dim")
-inp3_dim1 = Dim("inp3_dim1")
+# This style of dynamic shapes lets the user specify which symbols are allocated to input dimensions and the min/max bounds on those symbols, and it places restrictions
+# on the dynamic behavior of the ExportedProgram produced; ConstraintViolation errors will be raised if model tracing emits guards that conflict with the relations or
+# static/dynamic specifications given. For example, the above specification asserts the following:
+# - ``x.shape[0]`` has range ``[4, 256]``, and is related to ``y.shape[0]`` by ``y.shape[0] == 2 * x.shape[0]``.
+# - ``x.shape[1]`` is static.
+# - ``y.shape[1]`` has range ``[2, 512]``, and is unrelated to any other dimension.
+#
+# In this design, relations between dimensions are specified as univariate linear expressions: ``A * dim + B`` can be given for any dimension. This allows users
+# to specify more complex constraints, like integer divisibility, for dynamic dimensions:
-dynamic_shapes2 = {
-    "x": {0: inp2_dim0, 1: inner_dim},
-    "y": {0: inner_dim, 1: inp3_dim1},
+dx = Dim("dx", min=4, max=512)
+dynamic_shapes = {
+    "x": (4 * dx, None),  # x.shape[0] has range [16, 2048], and is divisible by 4.
 }

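+# As a usage sketch under this spec (the ``Quarters`` module below is illustrative, not part of the
+# original tutorial), a model whose tracing actually guards on divisibility by 4 exports cleanly:
+
+class Quarters(torch.nn.Module):
+    def forward(self, x):
+        # the -1 in reshape guards x.shape[0] % 4 == 0
+        return x.reshape(-1, 4, x.shape[1]).sum(dim=1)
+
+ep_div = export(Quarters(), (torch.randn(16, 32),), dynamic_shapes=dynamic_shapes)
+ep_div.module()(torch.randn(64, 32))  # any multiple of 4 in [16, 2048] works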
-exported_dynamic_shapes_example2 = export(DynamicShapesExample2(), (inp2, inp3), dynamic_shapes=dynamic_shapes2)
-
-print(exported_dynamic_shapes_example2.module()(torch.randn(2, 16), torch.randn(16, 4)))
-
-try:
-    exported_dynamic_shapes_example2.module()(torch.randn(4, 8), torch.randn(4, 2))
-except Exception:
-    tb.print_exc()
-
-######################################################################
-# We can also describe one dimension in terms of another. There are some
-# restrictions on how we can specify one dimension in terms of another,
-# but generally, those in the form of ``A * Dim + B`` should work.
+# One common issue with this specification style (before ``Dim.AUTO`` was introduced) is that the specification would often be mismatched with what model tracing produced.
+# That would lead to ConstraintViolation errors and export-suggested fixes - see, for example, this model & specification, where the model inherently requires equality
+# between dimensions 0 of ``x`` and ``y``, and requires dimension 1 to be static.
-class DerivedDimExample1(torch.nn.Module):
+class Foo(torch.nn.Module):
     def forward(self, x, y):
-        return x + y[1:]
-
-foo = DerivedDimExample1()
-
-x, y = torch.randn(5), torch.randn(6)
-dimx = torch.export.Dim("dimx", min=3, max=6)
-dimy = dimx + 1
-derived_dynamic_shapes1 = ({0: dimx}, {0: dimy})
-
-derived_dim_example1 = export(foo, (x, y), dynamic_shapes=derived_dynamic_shapes1)
-
-print(derived_dim_example1.module()(torch.randn(4), torch.randn(5)))
-
-try:
-    derived_dim_example1.module()(torch.randn(4), torch.randn(6))
-except Exception:
-    tb.print_exc()
-
-
-class DerivedDimExample2(torch.nn.Module):
-    def forward(self, z, y):
-        return z[1:] + y[1::3]
-
-foo = DerivedDimExample2()
-
-z, y = torch.randn(4), torch.randn(10)
-dx = torch.export.Dim("dx", min=3, max=6)
-dz = dx + 1
-dy = dx * 3 + 1
-derived_dynamic_shapes2 = ({0: dz}, {0: dy})
-
-derived_dim_example2 = export(foo, (z, y), dynamic_shapes=derived_dynamic_shapes2)
-print(derived_dim_example2.module()(torch.randn(7), torch.randn(19)))
-
-######################################################################
-# We can actually use ``torch.export`` to guide us as to which ``dynamic_shapes`` constraints
-# are necessary. We can do this by relaxing all constraints (recall that if we
-# do not provide constraints for a dimension, the default behavior is to constrain
-# to the exact shape value of the example input) and letting ``torch.export``
-# error out.
-
-inp4 = torch.randn(8, 16)
-inp5 = torch.randn(16, 32)
-
-class DynamicShapesExample3(torch.nn.Module):
-    def forward(self, x, y):
-        if x.shape[0] <= 16:
-            return x @ y[:, :16]
-        return y
-
-dynamic_shapes3 = {
-    "x": {i: Dim(f"inp4_dim{i}") for i in range(inp4.dim())},
-    "y": {i: Dim(f"inp5_dim{i}") for i in range(inp5.dim())},
+        w = x + y
+        return w + torch.ones(4)
+
+dx, dy, d1 = torch.export.dims("dx", "dy", "d1")
+ep = export(
+    Foo(),
+    (torch.randn(6, 4), torch.randn(6, 4)),
+    dynamic_shapes={
+        "x": (dx, d1),
+        "y": (dy, d1),
+    },
+)
+
+# The expectation with suggested fixes is that the user can interactively copy-paste the changes into their dynamic shapes specification, and successfully export afterwards.
+#
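+# As a sketch of that workflow, the suggested fixes for this model amount to roughly the spec below
+# (``d_shared`` and ``ep_fixed`` are illustrative names): dimensions 0 of ``x`` and ``y`` share one
+# symbol, and dimension 1 is static, after which the export succeeds:
+
+d_shared = Dim("d_shared")
+ep_fixed = export(
+    Foo(),
+    (torch.randn(6, 4), torch.randn(6, 4)),
+    dynamic_shapes={
+        "x": (d_shared, Dim.STATIC),
+        "y": (d_shared, Dim.STATIC),
+    },
+)
+#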
+# Lastly, there are a couple of nice-to-knows about the options for specification:
+#
+# - ``None`` is a good option for static behavior:
+#
+#   - ``dynamic_shapes=None`` (default) exports with the entire model static.
+#   - specifying ``None`` at the input level exports with all of that tensor's dimensions static; ``None`` is also what's required for non-tensor inputs.
+#   - specifying ``None`` at the dimension level specializes that dimension, though this is deprecated in favor of ``Dim.STATIC``.
+#
+# - specifying per-dimension integer values also produces static behavior, and will additionally check that the provided sample input matches the specification.
+#
+# These options are combined in the inputs & dynamic shapes spec below:
+
+inputs = (
+    torch.randn(4, 4),
+    torch.randn(3, 3),
+    16,
+    False,
+)
+dynamic_shapes = {
+    "tensor_0": (Dim.AUTO, None),
+    "tensor_1": None,
+    "int_val": None,
+    "bool_val": None,
 }

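+# A hypothetical module matching the argument names above (``MixedInputs`` is illustrative, not part
+# of the original tutorial): the int and bool inputs are treated as constants, ``tensor_1`` stays
+# fully static, and only ``tensor_0``'s first dimension is exported as dynamic:
+
+class MixedInputs(torch.nn.Module):
+    def forward(self, tensor_0, tensor_1, int_val, bool_val):
+        out = tensor_0 * int_val
+        if bool_val:  # static: export specializes on the sample value False
+            out = -out
+        return out, tensor_1 + 1
+
+ep_mixed = export(MixedInputs(), inputs, dynamic_shapes=dynamic_shapes)
+ep_mixed.module()(torch.randn(7, 4), torch.randn(3, 3), 16, False)  # batch dim is free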
-try:
-    export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3)
-except Exception:
-    tb.print_exc()
-
-######################################################################
-# We can see that the error message gives us suggested fixes to our
-# dynamic shape constraints. Let us follow those suggestions (exact
-# suggestions may differ slightly):
-
-def suggested_fixes():
-    inp4_dim1 = Dim('shared_dim')
-    # suggested fixes below
-    inp4_dim0 = Dim('inp4_dim0', max=16)
-    inp5_dim1 = Dim('inp5_dim1', min=17)
-    inp5_dim0 = inp4_dim1
-    # end of suggested fixes
-    return {
-        "x": {0: inp4_dim0, 1: inp4_dim1},
-        "y": {0: inp5_dim0, 1: inp5_dim1},
-    }
-
-dynamic_shapes3_fixed = suggested_fixes()
-exported_dynamic_shapes_example3 = export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed)
-print(exported_dynamic_shapes_example3.module()(torch.randn(4, 32), torch.randn(32, 64)))
-
-######################################################################
-# Note that in the example above, because we constrained the value of ``x.shape[0]`` in
-# ``dynamic_shapes3_fixed``, the exported program is sound even though there is a
-# raw ``if`` statement.
-#
-# If you want to see why ``torch.export`` generated these constraints, you can
-# re-run the script with the environment variable ``TORCH_LOGS=dynamic,dynamo``,
-# or use ``torch._logging.set_logs``.
-
-import logging
-torch._logging.set_logs(dynamic=logging.INFO, dynamo=logging.INFO)
-exported_dynamic_shapes_example3 = export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed)
-
-# reset to previous values
-torch._logging.set_logs(dynamic=logging.WARNING, dynamo=logging.WARNING)
-
-######################################################################
-# We can view an ``ExportedProgram``'s symbolic shape ranges using the
-# ``range_constraints`` field.
-
-print(exported_dynamic_shapes_example3.range_constraints)
-
 ######################################################################
 # Custom Ops
 # ----------