# Let's understand each of the operations and the emitted guards:
#
# - ``x0 = x + y``: This is an element-wise add with broadcasting, since ``x`` is a 1-d tensor and ``y`` a 2-d tensor. ``x`` is broadcasted along the last dimension of ``y``, emitting the guard ``s2 == s4``.
# - ``x1 = self.l(w)``: Calling ``nn.Linear()`` performs a matrix multiplication with model parameters. In export, parameters, buffers, and constants are considered program state, which is considered static, and so this is a matmul between a dynamic input (``w: [s0, s1]``), and a statically-shaped tensor. This emits the guard ``s1 == 5``.
# - ``x2 = x0.flatten()``: This call actually doesn't emit any guards! (at least none relevant to input shapes)
# - ``x3 = x2 + z``: ``x2`` has shape ``[s3*s4]`` after flattening, and this element-wise add emits ``s3 * s4 == s5``.
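# The four operations above can be sketched as a runnable module. This is a
# hypothetical reconstruction from the guard descriptions, not necessarily the
# tutorial's exact code; the concrete sizes (``s0=7, s2=4, s3=3``) are assumed
# values that satisfy every guard.

```python
import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(5, 3)  # static weight of shape [3, 5]

    def forward(self, w, x, y, z):
        x0 = x + y         # broadcast 1-d x over 2-d y: guard s2 == s4
        x1 = self.l(w)     # matmul with static weight: guard s1 == 5
        x2 = x0.flatten()  # shape [s3*s4]; no new input-shape guards
        x3 = x2 + z        # element-wise add: guard s3 * s4 == s5
        return x1, x3

# Inputs consistent with those guards: w=[7, 5], x=[4], y=[3, 4], z=[12]
out1, out3 = M()(torch.randn(7, 5), torch.randn(4),
                 torch.randn(3, 4), torch.randn(12))
```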
#
# Writing all of these guards down and summarizing is almost like a mathematical proof, which is what the symbolic shapes
# subsystem tries to do! In summary, we can conclude that the program must have the following input shapes to be valid:
#
# - ``w: [s0, 5]``
# - ``x: [s2]``
# - ``y: [s3, s2]``
# - ``z: [s2*s3]``
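# One way to read the summary above: pick any values for the free symbols and
# the remaining shapes are fully determined. The values below are assumptions
# for illustration, not part of the tutorial.

```python
# Free symbols chosen arbitrarily; all other sizes follow from the guards.
sizes = {"s0": 7, "s2": 4, "s3": 3}
shapes = {
    "w": (sizes["s0"], 5),                 # second dim pinned by s1 == 5
    "x": (sizes["s2"],),
    "y": (sizes["s3"], sizes["s2"]),       # last dim pinned by s2 == s4
    "z": (sizes["s2"] * sizes["s3"],),     # pinned by s3 * s4 == s5
}
```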
#
# And when we do finally print out the exported program to see our result, those shapes are what we see annotated on the
# corresponding inputs:
# This style of dynamic shapes allows the user to specify what symbols are allocated for input dimensions, min/max bounds on those symbols, and places restrictions on the
# dynamic behavior of the ``ExportedProgram`` produced; ``ConstraintViolation`` errors will be raised if model tracing emits guards that conflict with the relations or static/dynamic
# specifications given. For example, in the above specification, the following is asserted:
#
# - ``x.shape[0]`` has range ``[4, 256]``, and is related to ``y.shape[0]`` by ``y.shape[0] == 2 * x.shape[0]``.
# - ``x.shape[1]`` is static.
# - ``y.shape[1]`` has range ``[2, 512]``, and is unrelated to any other dimension.
# The expectation with suggested fixes is that the user can interactively copy-paste the changes into their dynamic shapes specification, and successfully export afterwards.
#
# Lastly, there are a couple of nice-to-knows about the options for specification:
#
# - ``None`` is a good option for static behavior:
#   - ``dynamic_shapes=None`` (default) exports with the entire model being static.
#   - specifying ``None`` at the input level exports with all of that input's tensor dimensions static; it is also required for non-tensor inputs.