Commit 6dc1122 ("init")
1 parent a91f631

1 file changed: +215 -1 lines changed

intermediate_source/torch_export_tutorial.py
Lines changed: 215 additions & 1 deletion
@@ -3,7 +3,7 @@
"""
torch.export Tutorial
===================================================
-**Author:** William Wen, Zhengxu Chen, Angela Yi
+**Author:** William Wen, Zhengxu Chen, Angela Yi, Pian Pawakapan
"""

######################################################################
@@ -304,6 +304,220 @@ def false_fn(x):
# Constraints/Dynamic Shapes
# --------------------------
#
# This section covers dynamic behavior and the representation of exported programs. Dynamic behavior is
# highly specific to the particular model being exported, so for the purposes of this tutorial, we'll focus
# on the following toy model (with the sample input shapes annotated):

class DynamicModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(5, 3)

    def forward(
        self,
        w: torch.Tensor,  # [6, 5]
        x: torch.Tensor,  # [4]
        y: torch.Tensor,  # [8, 4]
        z: torch.Tensor,  # [32]
    ):
        x0 = x + y  # output shape: [8, 4]
        x1 = self.l(w)  # [6, 3]
        x2 = x0.flatten()  # [32]
        x3 = x2 + z  # [32]
        return x1, x3

# By default, ``torch.export`` produces a static program. One clear consequence of this is that at runtime,
# the program won't work on inputs with different shapes, even if they're valid in eager mode.

w = torch.randn(6, 5)
x = torch.randn(4)
y = torch.randn(8, 4)
z = torch.randn(32)
model = DynamicModel()
ep = export(model, (w, x, y, z))
model(w, x, torch.randn(3, 4), torch.randn(12))
ep.module()(w, x, torch.randn(3, 4), torch.randn(12))

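# The eager call above succeeds because broadcasting still works out, but the exported program was traced with
# the original input shapes and guards on them, so the same call through ``ep.module()`` raises an error. As a
# minimal sketch (the exact exception type and message may vary across PyTorch versions), the failure can be
# observed without stopping the script by catching it:

try:
    ep.module()(w, x, torch.randn(3, 4), torch.randn(12))
except Exception as e:
    # the statically-exported program rejects the new shapes for ``y`` and ``z``
    print(f"Static export rejected the new input shapes: {e}")
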
# To enable dynamism, ``export()`` provides a ``dynamic_shapes`` argument. The easiest way to work with
# dynamic shapes is to use ``Dim.AUTO`` and look at the program that's returned. Dynamic behavior is specified
# at an input dimension level; for each input we can specify a tuple of values, one entry per dimension:

from torch.export.dynamic_shapes import Dim

dynamic_shapes = {
    "w": (Dim.AUTO, Dim.AUTO),
    "x": (Dim.AUTO,),
    "y": (Dim.AUTO, Dim.AUTO),
    "z": (Dim.AUTO,),
}
ep = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)

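# As a quick check (a sketch; these shapes also satisfy the shape constraints we derive below), the dynamic
# program now accepts the alternative input shapes that the static export rejected above:

ep.module()(w, x, torch.randn(3, 4), torch.randn(12))
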
# Before we look at the program that's produced, let's understand what specifying ``dynamic_shapes`` entails,
# and how that interacts with export. For every input dimension where a ``Dim`` object is specified, a symbol is
# allocated, taking on a range of ``[2, inf]`` (why not ``[0, inf]`` or ``[1, inf]``? we'll explain later in the
# 0/1 specialization section).
#
# Export then runs model tracing, looking at each operation that's performed by the model. Each individual operation can emit
# what's called a "guard": a boolean condition that is required to be true for this program to be valid.
# When these guards involve the symbols allocated for the input dimensions, our program now contains restrictions on
# what input shapes are valid; i.e. the program's dynamic behavior. The symbolic shapes subsystem is the part responsible
# for taking in all the emitted guards and producing a final program representation that adheres to all of these guards.
# Before we see this "final representation", let's look at the guards emitted by the toy model we're tracing.
#
# Here, each input tensor is annotated with the symbol allocated at the start of tracing:

class DynamicModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(5, 3)

    def forward(
        self,
        w: torch.Tensor,  # [s0, s1]
        x: torch.Tensor,  # [s2]
        y: torch.Tensor,  # [s3, s4]
        z: torch.Tensor,  # [s5]
    ):
        x0 = x + y  # guard: s2 == s4
        x1 = self.l(w)  # guard: s1 == 5
        x2 = x0.flatten()
        x3 = x2 + z  # guard: s3 * s4 == s5
        return x1, x3

# Let's understand each of the operations and the emitted guards:
#
# - ``x0 = x + y``: This is an elementwise add with broadcasting, since ``x`` is a 1-d tensor and ``y`` a 2-d tensor.
#   ``x`` is broadcast to match the last dimension of ``y``, emitting the guard ``s2 == s4``.
# - ``x1 = self.l(w)``: Calling ``nn.Linear()`` performs a matrix multiplication with model parameters. In export,
#   parameters, buffers, and constants are considered program state, which we require to be static, and therefore this is
#   a matmul between a dynamic input (``w: [s0, s1]``) and a statically-shaped tensor. This emits the guard ``s1 == 5``.
# - ``x2 = x0.flatten()``: This call actually doesn't emit any guards! (at least none relevant to input shapes)
# - ``x3 = x2 + z``: ``x2`` has shape ``[s3*s4]`` after flattening, and this elementwise add emits ``s3 * s4 == s5``.
#
# Writing all of these guards down and summarizing them is almost like a mathematical proof, which is what the symbolic shapes
# subsystem tries to do! In summary, we can conclude that the program must have the following input shapes to be valid:
#
# ``w: [s0, 5]``
# ``x: [s2]``
# ``y: [s3, s2]``
# ``z: [s2*s3]``
#
# And when we finally print out the exported program to see our result, those shapes are what we see annotated on the
# corresponding inputs!

print(ep)

# Another feature to notice is the ``range_constraints`` field above, which contains a valid range for each symbol. This isn't
# so interesting currently, since this export call doesn't emit any guards related to symbol bounds and each base symbol has
# a generic bound, but it will come up later.
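
# If you'd like to inspect these bounds directly, the exported program also exposes them as a
# ``range_constraints`` attribute (a minimal sketch; the exact repr of this mapping can vary across versions):

print(ep.range_constraints)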

# So far, because we've been exporting this toy model, this experience hasn't been representative of how hard
# it typically is to debug dynamic shapes guards & issues. In most cases it isn't obvious what guards are being emitted,
# or which operations and lines of user code are responsible. For this toy model we can pinpoint the exact lines, and the guards
# are rather intuitive.
#
# In more complicated cases, a helpful first step is always to enable verbose logging. This can be done either with the environment
# variable ``TORCH_LOGS="+dynamic"``, or interactively with ``torch._logging.set_logs(dynamic=10)``:

torch._logging.set_logs(dynamic=10)
ep = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)

# This spits out quite a lot of output, even for this simple toy model. But looking through the logs we can see the lines relevant
# to what we described above; e.g. the allocation of symbols:

"""
I1210 16:20:19.720000 3417744 torch/fx/experimental/symbolic_shapes.py:4404] [1/0] create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1210 16:20:19.722000 3417744 torch/fx/experimental/symbolic_shapes.py:4404] [1/0] create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s1" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V1210 16:20:19.722000 3417744 torch/fx/experimental/symbolic_shapes.py:6535] [1/0] runtime_assert True == True [statically known]
I1210 16:20:19.727000 3417744 torch/fx/experimental/symbolic_shapes.py:4404] [1/0] create_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s2" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1210 16:20:19.729000 3417744 torch/fx/experimental/symbolic_shapes.py:4404] [1/0] create_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s3" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1210 16:20:19.731000 3417744 torch/fx/experimental/symbolic_shapes.py:4404] [1/0] create_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s4" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1210 16:20:19.734000 3417744 torch/fx/experimental/symbolic_shapes.py:4404] [1/0] create_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s5" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
"""

# Or the guards emitted:

"""
I1210 16:20:19.743000 3417744 torch/fx/experimental/symbolic_shapes.py:6234] [1/0] runtime_assert Eq(s2, s4) [guard added] x0 = x + y # output shape: [8, 4] # dynamic_shapes_tutorial.py:16 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2, s4)"
I1210 16:20:19.754000 3417744 torch/fx/experimental/symbolic_shapes.py:6234] [1/0] runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w) # [6, 3] # dynamic_shapes_tutorial.py:17 in forward (_meta_registrations.py:2127 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)"
I1210 16:20:19.775000 3417744 torch/fx/experimental/symbolic_shapes.py:6234] [1/0] runtime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z # [32] # dynamic_shapes_tutorial.py:19 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2*s3, s5)"
"""
# Next to the ``[guard added]`` messages, we also see the responsible user lines of code - luckily here the model is simple enough.
# In many real-world cases it's not so straightforward: high-level torch operations can have complicated fake-kernel implementations
# or operator decompositions that complicate where and what guards are emitted. In such cases the best way to dig deeper and investigate
# is to follow the logs' suggestion, and re-run with the environment variable ``TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="..."``, to further
# attribute the guard of interest.
#
# ``Dim.AUTO`` is just one of the available options for interacting with ``dynamic_shapes``; as of writing, two other options are available:
# ``Dim.DYNAMIC`` and ``Dim.STATIC``. ``Dim.STATIC`` simply marks a dimension static, while ``Dim.DYNAMIC`` is similar to ``Dim.AUTO`` in all
# ways except one: it raises an error when specializing to a constant; this is designed to maintain dynamism. See for example what happens when a
# static guard is emitted on a dynamically-marked dimension:

dynamic_shapes["w"] = (Dim.AUTO, Dim.DYNAMIC)
export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)

# Static guards also aren't always inherent to the model; they can also come from user specifications. In fact, a common pitfall leading to shape
# specializations is when the user specifies conflicting markers for equivalent dimensions: one dynamic and another static. The same error type is
# raised when this is the case for ``x.shape[0]`` and ``y.shape[1]``:

dynamic_shapes["w"] = (Dim.AUTO, Dim.AUTO)
dynamic_shapes["x"] = (Dim.STATIC,)
dynamic_shapes["y"] = (Dim.AUTO, Dim.DYNAMIC)
export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)

# Here you might ask why export "specializes", i.e. why we resolve this static/dynamic conflict by going with the static route. The answer is because
# of the symbolic shapes system described above, of symbols and guards. When ``x.shape[0]`` is marked static, we don't allocate a symbol, and compile
# treating this shape as a concrete integer 4. A symbol is allocated for ``y.shape[1]``, and so we finally emit the guard ``s3 == 4``, leading to
# specialization.
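
# One way to avoid this pitfall is to give equivalent dimensions consistent markers. As a minimal sketch,
# marking both of the related dimensions with ``Dim.AUTO`` again lets export rediscover their equality instead
# of erroring out:

dynamic_shapes["x"] = (Dim.AUTO,)  # sketch: make x.shape[0] dynamic again
dynamic_shapes["y"] = (Dim.AUTO, Dim.AUTO)  # sketch: ...and y.shape[1] as well
ep = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)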

# One feature of export is that during tracing, statements like asserts, ``torch._check()``, and if/else conditions will also emit guards.
# See what happens when we augment the existing model with such statements:

class DynamicModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(5, 3)

    def forward(self, w, x, y, z):
        assert w.shape[0] <= 512
        torch._check(x.shape[0] >= 16)
        if w.shape[0] == x.shape[0] + 2:
            x0 = x + y
            x1 = self.l(w)
            x2 = x0.flatten()
            x3 = x2 + z
            return x1, x3
        else:
            return w

dynamic_shapes = {
    "w": (Dim.AUTO, Dim.AUTO),
    "x": (Dim.AUTO,),
    "y": (Dim.AUTO, Dim.AUTO),
    "z": (Dim.AUTO,),
}
ep = export(DynamicModel(), (w, x, y, z), dynamic_shapes=dynamic_shapes)
print(ep)

# Each of these statements emits an additional guard, and the exported program shows the changes; ``s0`` is eliminated in favor of ``s2 + 2``,
# and ``s2`` now contains lower and upper bounds, reflected in ``range_constraints``.
#
# For the if/else condition, you might ask why the True branch was taken, and why it wasn't the ``w.shape[0] != x.shape[0] + 2`` guard that
# got emitted from tracing. The answer is that export is guided by the sample inputs provided at tracing time, and specializes on the branches taken.
# If different sample input shapes were provided that fail the if condition, export would trace and emit guards corresponding to the else branch.
# Additionally, you might ask why we traced only the if branch, and whether it's possible to maintain control flow in your program and keep both branches
# alive. For that, refer to rewriting your model code following the ``Control Flow Ops`` section above.
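
# As a quick sanity check (a minimal sketch, assuming the export above succeeded with these sample inputs),
# the exported program accepts inputs that follow the traced branch, i.e. shapes satisfying
# ``w.shape[0] == x.shape[0] + 2`` and ``x.shape[0] >= 16`` along with the guards from before:

ep.module()(torch.randn(18, 5), torch.randn(16), torch.randn(3, 16), torch.randn(48))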

# Since we're talking about guards and specializations, it's a good time to talk about the 0/1 specialization issue we brought up earlier.
#

# Ops can have different specializations/behaviors for different tensor shapes, so by default,
# ``torch.export`` requires inputs to ``ExportedProgram`` to have the same shape as the respective
# example inputs given to the initial ``torch.export.export()`` call.
