Commit 0543076

lint
1 parent d732d69 commit 0543076

1 file changed

intermediate_source/torch_export_tutorial.py

Lines changed: 33 additions & 18 deletions
@@ -305,7 +305,7 @@ def false_fn(x):
 # --------------------------
 #
 # This section covers dynamic behavior and representation of exported programs. Dynamic behavior is
-# very subjective to the particular model being exported, so for purposes of this tutorial, we'll focus
+# subjective to the particular model being exported, so for most of this tutorial, we'll focus
 # on this particular toy model (with the sample input shapes annotated):

 class DynamicModel(torch.nn.Module):
@@ -326,6 +326,7 @@ def forward(
         x3 = x2 + z # [32]
         return x1, x3

+######################################################################
 # By default, ``torch.export`` produces a static program. One clear consequence of this is that at runtime,
 # the program won't work on inputs with different shapes, even if they're valid in eager mode.

@@ -338,6 +339,7 @@ def forward(
 model(w, x, torch.randn(3, 4), torch.randn(12))
 ep.module()(w, x, torch.randn(3, 4), torch.randn(12))

+######################################################################
 # To enable dynamism, ``export()`` provides a ``dynamic_shapes`` argument. The easiest way to work with
 # dynamic shapes is to use ``Dim.AUTO`` and look at the program that's returned. Dynamic behavior is specified
 # at an input dimension-level; for each input we can specify a tuple of values:
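The body of the ``dynamic_shapes`` dict is elided from this hunk. As a rough sketch (the entries below are assumed for illustration, not copied from the file), a ``Dim.AUTO``-based specification for the toy model's four inputs could look like:

    from torch.export import Dim, export

    # one tuple entry per dimension of each input; Dim.AUTO asks export to discover the dynamic behavior
    dynamic_shapes = {
        "w": (Dim.AUTO, Dim.AUTO),
        "x": (Dim.AUTO,),
        "y": (Dim.AUTO, Dim.AUTO),
        "z": (Dim.AUTO,),
    }
    ep = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)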
@@ -352,19 +354,20 @@ def forward(
 }
 ep = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)

+######################################################################
 # Before we look at the program that's produced, let's understand what specifying ``dynamic_shapes`` entails,
 # and how that interacts with export. For every input dimension where a ``Dim`` object is specified, a symbol is
 # allocated, taking on a range of ``[2, inf]`` (why not ``[0, inf]`` or ``[1, inf]``? we'll explain later in the
 # 0/1 specialization section).
 #
-# Export then runs model tracing, looking at each operation that's performed by the model. Each individual can emit
-# what's called a "guard"; basically a boolean condition that's required to be true for this program to be valid.
-# When these guards involve the symbols allocated for the input dimensions, our program now contains restrictions on
-# what input shapes are valid; i.e. the program's dynamic behavior. The symbolic shapes subsystem is the part responsible
-# for taking in all the emitted guards and producing a final program representation that adheres to all of these guards.
-# Before we see this "final representation", let's look at the guards emitted by the toy model we're tracing.
+# Export then runs model tracing, looking at each operation that's performed by the model. Each individual operation can emit
+# what's called "guards"; basically, boolean conditions that are required to be true for the program to be valid.
+# When guards involve symbols allocated for input dimensions, the program contains restrictions on what input shapes are valid;
+# i.e. the program's dynamic behavior. The symbolic shapes subsystem is the part responsible for taking in all the emitted guards
+# and producing a final program representation that adheres to all of these guards. Before we see this "final representation" in
+# an ExportedProgram, let's look at the guards emitted by the toy model we're tracing.
 #
-# Here, each input tensor is annotated with the symbol allocated at the start of tracing:
+# Here, each forward input tensor is annotated with the symbol allocated at the start of tracing:

 class DynamicModel(torch.nn.Module):
     def __init__(self):
@@ -384,15 +387,16 @@ def forward(
         x3 = x2 + z # guard: s3 * s4 == s5
         return x1, x3

+######################################################################
 # Let's understand each of the operations and the emitted guards:
 #
-# - ``x0 = x + y``: This is an elementwise-add with broadcasting, since ``x`` is a 1-d tensor and ``y`` a 2-d tensor.
-#   ``x`` is broadcasted to match the last dimension ``y``, emitting the guard ``s2 == s4``.
+# - ``x0 = x + y``: This is an element-wise add with broadcasting, since ``x`` is a 1-d tensor and ``y`` a 2-d tensor.
+#   ``x`` is broadcasted along the last dimension of ``y``, emitting the guard ``s2 == s4``.
 # - ``x1 = self.l(w)``: Calling ``nn.Linear()`` performs a matrix multiplication with model parameters. In export,
-#   parameters, buffers, and constants are considered program state, which we require to be static, and therefore this is
+#   parameters, buffers, and constants are considered program state, which is required to be static, and so this is
 #   a matmul between a dynamic input (``w: [s0, s1]``), and a statically-shaped tensor. This emits the guard ``s1 == 5``.
 # - ``x2 = x0.flatten()``: This call actually doesn't emit any guards! (at least none relevant to input shapes)
-# - ``x3 = x2 + z``: ``x2`` has shape ``[s3*s4]`` after flattening, and this elementwise-add emits ``s3 * s4 == s5``.
+# - ``x3 = x2 + z``: ``x2`` has shape ``[s3*s4]`` after flattening, and this element-wise add emits ``s3 * s4 == s5``.
 #
 # Writing all of these guards down and summarizing is almost like a mathematical proof, which is what the symbolic shapes
 # subsystem tries to do! In summary, we can conclude that the program must have the following input shapes to be valid:
@@ -403,17 +407,18 @@ def forward(
 # ``z: [s2*s3]``
 #
 # And when we do finally print out the exported program to see our result, those shapes are what we see annotated on the
-# corresponding inputs!
+# corresponding inputs:

 print(ep)
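As a quick sanity check of the relations summarized above (``s1 == 5``, ``s2 == s4``, ``s3 * s4 == s5``), the exported module can be called on fresh inputs; the shapes below are illustrative and not part of the tutorial:

    # all relations hold: w is [7, 5], x is [6], y is [8, 6], z is [48] == [8 * 6]
    ep.module()(torch.randn(7, 5), torch.randn(6), torch.randn(8, 6), torch.randn(48))

    # breaking s3 * s4 == s5 (z has 50 elements instead of 48) should fail the program's runtime shape checks
    try:
        ep.module()(torch.randn(7, 5), torch.randn(6), torch.randn(8, 6), torch.randn(50))
    except Exception as e:
        print(type(e).__name__)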

+######################################################################
 # Another feature to notice is the range_constraints field above, which contains a valid range for each symbol. This isn't
 # so interesting currently, since this export call doesn't emit any guards related to symbol bounds and each base symbol has
 # a generic bound, but this will come up later.
 #
 # So far, because we've been exporting this toy model, this experience has been misrepresentative of how hard
 # it typically is to debug dynamic shapes guards & issues. In most cases it isn't obvious what guards are being emitted,
-# and which operations and lines of user code are responsible. For this toy model we pinpoint the exact lines, and the guards
+# and which operations and parts of user code are responsible. For this toy model we pinpoint the exact lines, and the guards
 # are rather intuitive.
 #
 # In more complicated cases, a helpful first step is always to enable verbose logging. This can be done either with the environment
@@ -422,6 +427,7 @@ def forward(
 torch._logging.set_logs(dynamic=10)
 ep = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)

+######################################################################
 # This spits out quite a handful, even with this simple toy model. But looking through the logs we can see the lines relevant
 # to what we described above; e.g. the allocation of symbols:

@@ -435,6 +441,7 @@ def forward(
 I1210 16:20:19.734000 3417744 torch/fx/experimental/symbolic_shapes.py:4404] [1/0] create_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s5" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
 """

+######################################################################
 # Or the guards emitted:

 """
@@ -443,6 +450,7 @@ def forward(
 I1210 16:20:19.775000 3417744 torch/fx/experimental/symbolic_shapes.py:6234] [1/0] runtime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z # [32] # dynamic_shapes_tutorial.py:19 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2*s3, s5)"
 """

+######################################################################
 # Next to the ``[guard added]`` messages, we also see the responsible user lines of code - luckily here the model is simple enough.
 # In many real-world cases it's not so straightforward: high-level torch operations can have complicated fake-kernel implementations
 # or operator decompositions that complicate where and what guards are emitted. In such cases the best way to dig deeper and investigate
@@ -457,6 +465,7 @@ def forward(
 dynamic_shapes["w"] = (Dim.AUTO, Dim.DYNAMIC)
 export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)

+######################################################################
 # Static guards also aren't always inherent to the model; they can also come from user-specifications. In fact, a common pitfall leading to shape
 # specializations is when the user specifies conflicting markers for equivalent dimensions; one dynamic and another static. The same error type is
 # raised when this is the case for ``x.shape[0]`` and ``y.shape[1]``:
@@ -466,12 +475,13 @@ def forward(
 dynamic_shapes["y"] = (Dim.AUTO, Dim.DYNAMIC)
 export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)

+######################################################################
 # Here you might ask why export "specializes"; why we resolve this static/dynamic conflict by going with the static route. The answer is because
 # of the symbolic shapes system described above, of symbols and guards. When ``x.shape[0]`` is marked static, we don't allocate a symbol, and compile
 # treating this shape as a concrete integer 4. A symbol is allocated for ``y.shape[1]``, and so we finally emit the guard ``s3 == 4``, leading to
 # specialization.
 #
-# One feature of export is that during tracing, statements like asserts, ``torch._checks()``, and if/else conditions will also emit guards.
+# One feature of export is that during tracing, statements like asserts, ``torch._check()``, and ``if/else`` conditions will also emit guards.
 # See what happens when we augment the existing model with such statements:

 class DynamicModel(torch.nn.Module):
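The rest of the augmented class is elided from this hunk. A hypothetical body consistent with the guards discussed below (the bounds, the branch condition, and the layer sizes are assumptions, not the tutorial's code) might look like:

    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(5, 3)

    def forward(self, w, x, y, z):
        assert x.shape[0] >= 4             # emits a lower-bound guard on s2 (value assumed)
        torch._check(x.shape[0] <= 1024)   # emits an upper-bound guard on s2 (value assumed)
        if w.shape[0] == x.shape[0] + 2:   # tracing takes this branch, eliminating s0 in favor of s2 + 2
            x0 = x + y
            x1 = self.l(w)
            x2 = x0.flatten()
            x3 = x2 + z
            return x1, x3
        else:
            return w, z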
@@ -500,13 +510,14 @@ def forward(self, w, x, y, z):
 ep = export(DynamicModel(), (w, x, y, z), dynamic_shapes=dynamic_shapes)
 print(ep)

+######################################################################
 # Each of these statements emits an additional guard, and the exported program shows the changes; ``s0`` is eliminated in favor of ``s2 + 2``,
 # and ``s2`` now contains lower and upper bounds, reflected in ``range_constraints``.
 #
 # For the if/else condition, you might ask why the True branch was taken, and why it wasn't the ``w.shape[0] != x.shape[0] + 2`` guard that
-# got emitted from tracing. The answer is export is guided by the sample inputs provided by tracing, and specializes on the branches taken.
-# If different sample input shapes were provided that fail the if condition, export would trace and emit guards corresponding to the else branch.
-# Additionally, you might ask why we traced only the if branch, and if it's possible to maintain control-flow in your program and keep both branches
+# got emitted from tracing. The answer is that export is guided by the sample inputs provided for tracing, and specializes on the branches taken.
+# If different sample input shapes were provided that fail the ``if`` condition, export would trace and emit guards corresponding to the ``else`` branch.
+# Additionally, you might ask why we traced only the ``if`` branch, and if it's possible to maintain control-flow in your program and keep both branches
 # alive. For that, refer to rewriting your model code following the ``Control Flow Ops`` section above.
 #
 # Since we're talking about guards and specializations, it's a good time to talk about the 0/1 specialization issue we brought up earlier.
@@ -524,6 +535,7 @@ def forward(self, w, x, y, z):
 )
 ep.module()(torch.randn(2, 4))
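Most of the export call above is elided from this hunk. As a self-contained illustration of the 0/1 specialization it demonstrates (the module and shapes below are assumptions, not the tutorial's code): a dimension whose sample size is 1 is specialized to exactly 1, even when marked dynamic, so the resulting program rejects larger sizes at runtime.

    from torch.export import Dim, export

    ep_01 = export(
        torch.nn.Linear(4, 3),
        (torch.randn(1, 4),),                    # sample batch dimension is 1
        dynamic_shapes=((Dim.AUTO, Dim.STATIC),),
    )
    try:
        ep_01.module()(torch.randn(2, 4))        # fails: the batch dimension was specialized to 1
    except Exception as e:
        print(e)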

+######################################################################
 # So far we've only been talking about 3 ways to specify dynamic shapes: ``Dim.AUTO``, ``Dim.DYNAMIC``, and ``Dim.STATIC``. The attraction of these is the
 # low-friction user experience; all the guards emitted during model tracing are adhered to, and dynamic behavior like min/max ranges, relations, and static/dynamic
 # dimensions are automatically figured out underneath export. The dynamic shapes subsystem essentially acts as a "discovery" process, summarizing these guards
@@ -541,6 +553,7 @@ def forward(self, w, x, y, z):
     "y": (2 * dx, dh),
 }
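The first half of this specification is elided from the diff. A sketch of a complete named-``Dim`` version of this style (the names, bounds, and the ``x`` entry are assumed):

    from torch.export import Dim

    dx = Dim("dx", min=4, max=256)   # bounds assumed for illustration
    dh = Dim("dh", max=512)          # bound assumed for illustration
    dynamic_shapes = {
        "x": (dx,),
        "y": (2 * dx, dh),
    }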

+######################################################################
 # This style of dynamic shapes allows the user to specify what symbols are allocated for input dimensions, min/max bounds on those symbols, and places restrictions on the
 # dynamic behavior of the ExportedProgram produced; ConstraintViolation errors will be raised if model tracing emits guards that conflict with the relations or static/dynamic
 # specifications given. For example, in the above specification, the following is asserted:
@@ -556,6 +569,7 @@ def forward(self, w, x, y, z):
     "x": (4 * dx, None) # x.shape[0] has range [16, 2048], and is divisible by 4.
 }

+######################################################################
 # One common issue with this specification style (before ``Dim.AUTO`` was introduced) is that the specification would often be mismatched with what was produced by model tracing.
 # That would lead to ConstraintViolation errors and export-suggested fixes - see for example with this model & specification, where the model inherently requires equality between
 # dimensions 0 of ``x`` and ``y``, and requires dimension 1 to be static.
@@ -575,6 +589,7 @@ def forward(self, x, y):
     },
 )
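The model and most of the specification in this example are elided from the diff. A hypothetical pair that triggers the same kind of ConstraintViolation error (every name, shape, and bound below is assumed) could be:

    from torch.export import Dim, export

    class MismatchedModel(torch.nn.Module):       # hypothetical stand-in for the elided model
        def __init__(self):
            super().__init__()
            self.l = torch.nn.Linear(4, 4)

        def forward(self, x, y):
            return self.l(x + y)                  # x + y requires x and y to share a shape; the Linear pins dim 1 to 4

    x, y = torch.randn(8, 4), torch.randn(8, 4)
    try:
        export(
            MismatchedModel(),
            (x, y),
            dynamic_shapes={
                "x": (Dim("dx0"), Dim("dx1")),    # dim 1 marked dynamic, but the model needs it static
                "y": (Dim("dy0"), Dim("dy1")),    # independent symbol for dim 0, but it must equal x's
            },
        )
    except Exception as e:
        print(e)                                  # ConstraintViolation-style error with suggested fixes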

+######################################################################
 # The expectation with suggested fixes is that the user can interactively copy-paste the changes into their dynamic shapes specification, and successfully export afterwards.
 #
 # Lastly, there are a couple of nice-to-knows about the options for specification:
