"""
torch.export Tutorial
===================================================
**Author:** William Wen, Zhengxu Chen, Angela Yi, Pian Pawakapan
"""


######################################################################
# Constraints/Dynamic Shapes
# --------------------------
#
# This section covers dynamic behavior and representation of exported programs. Dynamic behavior is
# highly dependent on the particular model being exported, so for the purposes of this tutorial, we'll
# focus on this toy model (with the sample input shapes annotated):

class DynamicModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(5, 3)

    def forward(
        self,
        w: torch.Tensor,  # [6, 5]
        x: torch.Tensor,  # [4]
        y: torch.Tensor,  # [8, 4]
        z: torch.Tensor,  # [32]
    ):
        x0 = x + y  # output shape: [8, 4]
        x1 = self.l(w)  # [6, 3]
        x2 = x0.flatten()  # [32]
        x3 = x2 + z  # [32]
        return x1, x3

# By default, ``torch.export`` produces a static program. One consequence of this is that at runtime,
# the program won't work on inputs with different shapes, even if they're valid in eager mode.

w = torch.randn(6, 5)
x = torch.randn(4)
y = torch.randn(8, 4)
z = torch.randn(32)
model = DynamicModel()
ep = export(model, (w, x, y, z))
model(w, x, torch.randn(3, 4), torch.randn(12))  # eager mode handles these shapes fine
try:
    # the exported program was traced with static shapes, so this call fails with a shape guard error
    ep.module()(w, x, torch.randn(3, 4), torch.randn(12))
except Exception as e:
    print(e)

# To enable dynamism, ``export()`` provides a ``dynamic_shapes`` argument. The easiest way to work with
# dynamic shapes is to use ``Dim.AUTO`` and look at the program that's returned. Dynamic behavior is specified
# at the level of individual input dimensions; for each input we specify a tuple of values, one entry per dimension:

from torch.export.dynamic_shapes import Dim

dynamic_shapes = {
    "w": (Dim.AUTO, Dim.AUTO),
    "x": (Dim.AUTO,),
    "y": (Dim.AUTO, Dim.AUTO),
    "z": (Dim.AUTO,),
}
ep = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)

# Before we look at the program that's produced, let's understand what specifying ``dynamic_shapes`` entails,
# and how it interacts with export. For every input dimension where a ``Dim`` object is specified, a symbol is
# allocated, taking on a range of ``[2, inf]`` (why not ``[0, inf]`` or ``[1, inf]``? we'll explain later in the
# 0/1 specialization section).
#
# Export then runs model tracing, looking at each operation performed by the model. Each individual operation can emit
# what's called a "guard": a boolean condition that's required to be true for the program to be valid.
# When these guards involve the symbols allocated for the input dimensions, the program contains restrictions on
# what input shapes are valid; i.e. the program's dynamic behavior. The symbolic shapes subsystem is responsible
# for taking in all the emitted guards and producing a final program representation that adheres to all of them.
# Before we see this "final representation", let's look at the guards emitted while tracing the toy model.
#
# Here, each input tensor is annotated with the symbol allocated at the start of tracing:

class DynamicModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(5, 3)

    def forward(
        self,
        w: torch.Tensor,  # [s0, s1]
        x: torch.Tensor,  # [s2]
        y: torch.Tensor,  # [s3, s4]
        z: torch.Tensor,  # [s5]
    ):
        x0 = x + y  # guard: s2 == s4
        x1 = self.l(w)  # guard: s1 == 5
        x2 = x0.flatten()
        x3 = x2 + z  # guard: s3 * s4 == s5
        return x1, x3

# Let's understand each of the operations and the emitted guards:
#
# - ``x0 = x + y``: This is an elementwise-add with broadcasting, since ``x`` is a 1-d tensor and ``y`` a 2-d tensor.
#   ``x`` is broadcast to match the last dimension of ``y``, emitting the guard ``s2 == s4``.
# - ``x1 = self.l(w)``: Calling ``nn.Linear()`` performs a matrix multiplication with model parameters. In export,
#   parameters, buffers, and constants are considered program state, which we require to be static, so this is
#   a matmul between a dynamic input (``w: [s0, s1]``) and a statically-shaped tensor. This emits the guard ``s1 == 5``.
# - ``x2 = x0.flatten()``: This call actually doesn't emit any guards! (at least none relevant to input shapes)
# - ``x3 = x2 + z``: ``x2`` has shape ``[s3*s4]`` after flattening, and this elementwise-add emits ``s3 * s4 == s5``.
#
# Writing all of these guards down and summarizing them is almost like a mathematical proof, which is what the symbolic
# shapes subsystem tries to do! In summary, we can conclude that the program must have the following input shapes to be valid:
#
# - ``w: [s0, 5]``
# - ``x: [s2]``
# - ``y: [s3, s2]``
# - ``z: [s2*s3]``
#
# And when we finally print out the exported program, those shapes are what we see annotated on the
# corresponding inputs:

print(ep)
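
# As a quick sanity check (the sizes below are arbitrary picks for illustration, not from the original text),
# the exported module should now accept any inputs satisfying the derived shapes ``w: [s0, 5]``, ``x: [s2]``,
# ``y: [s3, s2]``, ``z: [s2*s3]``:

ep.module()(
    torch.randn(10, 5),  # w: s0 = 10, second dim must stay 5
    torch.randn(6),      # x: s2 = 6
    torch.randn(7, 6),   # y: s3 = 7, s2 = 6
    torch.randn(42),     # z: s2 * s3 = 42
)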

# Another feature to notice is the ``range_constraints`` field above, which contains a valid range for each symbol. This isn't
# so interesting currently, since this export call doesn't emit any guards related to symbol bounds and each base symbol has
# a generic bound, but this will come up later.
#
# So far, because we've been exporting this toy model, the experience hasn't been representative of how hard
# it typically is to debug dynamic-shapes guards and issues. In most cases it isn't obvious what guards are being emitted,
# and which operations and lines of user code are responsible. For this toy model we can pinpoint the exact lines, and the
# guards are rather intuitive.
#
# In more complicated cases, a helpful first step is always to enable verbose logging. This can be done either with the
# environment variable ``TORCH_LOGS="+dynamic"``, or interactively with ``torch._logging.set_logs(dynamic=10)``:

torch._logging.set_logs(dynamic=10)
ep = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)

# This spits out quite a bit of output, even for this simple toy model. Looking through the logs, we can see the lines
# relevant to what we described above; e.g. the allocation of symbols:

"""
I1210 16:20:19.720000 3417744 torch/fx/experimental/symbolic_shapes.py:4404] [1/0] create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1210 16:20:19.722000 3417744 torch/fx/experimental/symbolic_shapes.py:4404] [1/0] create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s1" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V1210 16:20:19.722000 3417744 torch/fx/experimental/symbolic_shapes.py:6535] [1/0] runtime_assert True == True [statically known]
I1210 16:20:19.727000 3417744 torch/fx/experimental/symbolic_shapes.py:4404] [1/0] create_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s2" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1210 16:20:19.729000 3417744 torch/fx/experimental/symbolic_shapes.py:4404] [1/0] create_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s3" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1210 16:20:19.731000 3417744 torch/fx/experimental/symbolic_shapes.py:4404] [1/0] create_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s4" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1210 16:20:19.734000 3417744 torch/fx/experimental/symbolic_shapes.py:4404] [1/0] create_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s5" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
"""

# Or the guards emitted:

"""
I1210 16:20:19.743000 3417744 torch/fx/experimental/symbolic_shapes.py:6234] [1/0] runtime_assert Eq(s2, s4) [guard added] x0 = x + y # output shape: [8, 4] # dynamic_shapes_tutorial.py:16 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2, s4)"
I1210 16:20:19.754000 3417744 torch/fx/experimental/symbolic_shapes.py:6234] [1/0] runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w) # [6, 3] # dynamic_shapes_tutorial.py:17 in forward (_meta_registrations.py:2127 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)"
I1210 16:20:19.775000 3417744 torch/fx/experimental/symbolic_shapes.py:6234] [1/0] runtime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z # [32] # dynamic_shapes_tutorial.py:19 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2*s3, s5)"
"""

# Next to the ``[guard added]`` messages, we also see the responsible lines of user code; luckily the model here is simple enough.
# In many real-world cases it's not so straightforward: high-level torch operations can have complicated fake-kernel implementations
# or operator decompositions that complicate where and what guards are emitted. In such cases the best way to dig deeper and investigate
# is to follow the logs' suggestion, and re-run with the environment variable ``TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="..."`` to further
# attribute the guard of interest.
#
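
# Verbose dynamic logging is global state; once you're done inspecting the output, it can be turned back off so
# the remaining export calls aren't flooded with logs. (This reset is a convenience added here, not part of the
# original flow.)

torch._logging.set_logs(dynamic=0)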

# ``Dim.AUTO`` is just one of the available options for interacting with ``dynamic_shapes``; as of writing, two other options
# are available: ``Dim.DYNAMIC`` and ``Dim.STATIC``. ``Dim.STATIC`` simply marks a dimension static, while ``Dim.DYNAMIC`` is
# similar to ``Dim.AUTO`` in all ways except one: it raises an error when it specializes to a constant; this is designed to
# maintain dynamism. See, for example, what happens when a static guard is emitted on a dynamically-marked dimension:

dynamic_shapes["w"] = (Dim.AUTO, Dim.DYNAMIC)
try:
    # errors out, since w.shape[1] is marked dynamic but specializes to 5
    export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
except Exception as e:
    print(e)
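
# By contrast (an illustrative addition, not from the original text), explicitly marking that dimension with
# ``Dim.STATIC`` agrees with the specialization to 5, so the same export goes through without complaint:

dynamic_shapes["w"] = (Dim.AUTO, Dim.STATIC)
ep_static = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)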

# Static guards also aren't always inherent to the model; they can come from user specifications as well. In fact, a common
# pitfall leading to shape specializations is when the user specifies conflicting markers for equivalent dimensions: one
# dynamic and another static. The same error type is raised when this is the case for ``x.shape[0]`` and ``y.shape[1]``:

dynamic_shapes["w"] = (Dim.AUTO, Dim.AUTO)
dynamic_shapes["x"] = (Dim.STATIC,)
dynamic_shapes["y"] = (Dim.AUTO, Dim.DYNAMIC)
try:
    # errors out: y.shape[1] is marked dynamic, but it must equal x.shape[0], which is marked static
    export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
except Exception as e:
    print(e)

# Here you might ask why export "specializes", i.e. why we resolve this static/dynamic conflict by going with the static route.
# The answer lies in the symbolic shapes system of symbols and guards described above. When ``x.shape[0]`` is marked static, we
# don't allocate a symbol, and we trace treating this shape as the concrete integer 4. A symbol is allocated for ``y.shape[1]``,
# and so we finally emit the guard ``s3 == 4``, leading to specialization.
#
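# For completeness (another illustrative addition), specifying consistent markers for the equivalent dimensions, e.g.
# ``Dim.AUTO`` for both, removes the conflict: the export succeeds and both dimensions stay dynamic, unified under one symbol.

dynamic_shapes["x"] = (Dim.AUTO,)
dynamic_shapes["y"] = (Dim.AUTO, Dim.AUTO)
ep_consistent = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)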

# One feature of export is that during tracing, statements like asserts, ``torch._check()`` calls, and if/else conditions will
# also emit guards. See what happens when we augment the existing model with such statements:

class DynamicModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(5, 3)

    def forward(self, w, x, y, z):
        assert w.shape[0] <= 512
        torch._check(x.shape[0] >= 4)  # the sample input (x: [4]) must satisfy this check for tracing to succeed
        if w.shape[0] == x.shape[0] + 2:
            x0 = x + y
            x1 = self.l(w)
            x2 = x0.flatten()
            x3 = x2 + z
            return x1, x3
        else:
            return w

dynamic_shapes = {
    "w": (Dim.AUTO, Dim.AUTO),
    "x": (Dim.AUTO,),
    "y": (Dim.AUTO, Dim.AUTO),
    "z": (Dim.AUTO,),
}
ep = export(DynamicModel(), (w, x, y, z), dynamic_shapes=dynamic_shapes)
print(ep)

# Each of these statements emits an additional guard, and the exported program shows the changes: ``s0`` is eliminated in favor
# of ``s2 + 2``, and ``s2`` now has both lower and upper bounds, reflected in ``range_constraints``.
#
# For the if/else condition, you might ask why the True branch was taken, and why it wasn't the ``w.shape[0] != x.shape[0] + 2``
# guard that got emitted from tracing. The answer is that export is guided by the sample inputs provided for tracing, and it
# specializes on the branches taken. If sample input shapes were provided that fail the if condition, export would trace and
# emit guards corresponding to the else branch. Additionally, you might ask why we traced only the if branch, and whether it's
# possible to maintain control flow in your program and keep both branches alive. For that, refer to rewriting your model code
# as described in the ``Control Flow Ops`` section above.
#
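# As a quick check (an illustrative addition), the exported program enforces these new guards at runtime: the call below
# uses fresh shapes that satisfy ``w.shape[0] == x.shape[0] + 2``, ``x.shape[0] >= 4``, and ``w.shape[0] <= 512``, so it
# runs, while shapes violating any of them should be rejected.

ep.module()(
    torch.randn(8, 5),  # w: s2 + 2 = 8
    torch.randn(6),     # x: s2 = 6
    torch.randn(7, 6),  # y: s3 = 7, s2 = 6
    torch.randn(42),    # z: s2 * s3 = 42
)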

# Since we're talking about guards and specializations, it's a good time to talk about the 0/1 specialization issue we
# brought up earlier.
#

# Ops can have different specializations/behaviors for different tensor shapes, so by default,
# ``torch.export`` requires inputs to ``ExportedProgram`` to have the same shape as the respective
# example inputs given to the initial ``torch.export.export()`` call.