
Commit 5b1689e

init

1 parent ab2aafd
1 file changed: +25 -0


intermediate_source/torch_export_tutorial.py

@@ -489,6 +489,7 @@ def forward(self, w, x, y, z):
 # specify 0/1 sample inputs when you'd like your program to hardcode them, and non-0/1 sample inputs when dynamic behavior is desirable. See what happens
 # at runtime when we export this linear layer:

+torch._logging.set_logs(dynamic=0)
 ep = export(
     torch.nn.Linear(4, 3),
     (torch.randn(1, 4),),
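
Aside (not part of this commit): a minimal sketch of the specialization the comment above refers to, assuming the public torch.export API. Exporting with a size-1 batch dimension in the sample input hardcodes that dimension, so the exported program rejects other batch sizes:

import torch
from torch.export import export

# Sketch, not from the commit: default export specializes to the
# sample input's shapes, including the size-1 batch dimension.
ep = export(torch.nn.Linear(4, 3), (torch.randn(1, 4),))
ep.module()(torch.randn(1, 4))    # OK: batch size matches the specialization
# ep.module()(torch.randn(2, 4))  # would fail a shape guard: batch was specialized to 1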
@@ -591,6 +592,30 @@ def forward(self, x, y):
     "bool_val": None,
 }

+######################################################################
+# (experimental) Avoiding 0/1 specialization
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# Export provides an experimental option to avoid specializing on size 0/1 sample inputs. Users can turn on ``torch.fx.experimental._config.backed_size_oblivious = True`` to enable this behavior.
+# This allows the compiler to allocate a [0, inf] range for symbols, and to assume general-case semantics wherever compiler decisions would otherwise differ between size 0/1 and sizes >= 2.
+# This can lead to behavior divergence between eager mode and the exported program on size 0/1 inputs - for example, in broadcasting decisions we assume input shapes are not 1-specialized,
+# and therefore assume broadcasting does not apply (even if it does on the particular sample inputs). The same logic applies to other semantics (e.g. contiguity) and to size-0 tensors.
+#
+# The exact semantics under this flag are a work in progress, and usage is recommended only when the user is certain their model does not rely on 0/1-specialized semantics.
+# For now, export users can enable this with:
+
+class Foo(torch.nn.Module):
+    def forward(self, x, y):
+        return x + y  # nothing special about size 0/1 here
+
+x = torch.randn(0, 1)
+y = torch.randn(1)
+dynamic_shapes = {"x": (Dim.AUTO, Dim.AUTO), "y": (Dim.AUTO,)}
+with torch.fx.experimental._config.patch(backed_size_oblivious=True):
+    ep = export(Foo(), (x, y), dynamic_shapes=dynamic_shapes)
+ep.module()(torch.randn(8, 1), torch.randn(1))
+ep.module()(torch.randn(5, 6), torch.randn(6))
+
 ######################################################################
 # Data-dependent errors
 # ---------------------
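
Aside (not part of this commit): for contrast, a sketch of the flag-off behavior the new section describes, assuming the public torch.export API. Without backed_size_oblivious, Dim.AUTO still specializes the size-0 and size-1 sample dimensions, so the same export only accepts those sizes:

import torch
from torch.export import export, Dim

class Foo(torch.nn.Module):
    def forward(self, x, y):
        return x + y

x, y = torch.randn(0, 1), torch.randn(1)
dynamic_shapes = {"x": (Dim.AUTO, Dim.AUTO), "y": (Dim.AUTO,)}

# Sketch, not from the commit: with the flag left off, the 0 and 1 in
# x's sample shape are baked into the exported program.
ep = export(Foo(), (x, y), dynamic_shapes=dynamic_shapes)
ep.module()(torch.randn(0, 1), torch.randn(1))    # OK: matches specialized sizes
# ep.module()(torch.randn(5, 6), torch.randn(6))  # would fail a shape guard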
