-
Notifications
You must be signed in to change notification settings - Fork 4.2k
Patched docs for torch_compile_tutorial #2936
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
204f9fc
e580c60
8251141
f4bb0fd
a5d38cf
288748f
02c90d7
fefbe1f
f36b152
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -73,17 +73,35 @@ def foo(x, y): | |
|
||
###################################################################### | ||
# Alternatively, we can decorate the function. | ||
t1 = torch.randn(10, 10) | ||
t2 = torch.randn(10, 10) | ||
|
||
@torch.compile | ||
def opt_foo2(x, y): | ||
a = torch.sin(x) | ||
b = torch.cos(y) | ||
return a + b | ||
print(opt_foo2(torch.randn(10, 10), torch.randn(10, 10))) | ||
print(opt_foo2(t1, t2)) | ||
|
||
# When using the decorator approach, nested function calls within the decorated | ||
# function will also be compiled. | ||
|
||
def nested_function(x): | ||
return torch.sin(x) | ||
|
||
@torch.compile | ||
def outer_function(x, y): | ||
a = nested_function(x) | ||
b = torch.cos(y) | ||
return a + b | ||
|
||
print(outer_function(t1, t2)) | ||
|
||
###################################################################### | ||
# We can also optimize ``torch.nn.Module`` instances. | ||
|
||
t = torch.randn(10, 100) | ||
|
||
class MyModule(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
@@ -94,7 +112,74 @@ def forward(self, x): | |
|
||
mod = MyModule() | ||
opt_mod = torch.compile(mod) | ||
print(opt_mod(torch.randn(10, 100))) | ||
print(opt_mod(t)) | ||
|
||
# In the same fashion, when compiling a module all sub-modules and methods | ||
# within it are also compiled. | ||
|
||
class OuterModule(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.inner_module = MyModule() | ||
self.outer_lin = torch.nn.Linear(10, 2) | ||
|
||
def forward(self, x): | ||
x = self.inner_module(x) | ||
return torch.nn.functional.relu(self.outer_lin(x)) | ||
|
||
outer_mod = OuterModule() | ||
opt_outer_mod = torch.compile(outer_mod) | ||
print(opt_outer_mod(t)) | ||
|
||
###################################################################### | ||
# We can also disable some functions from being compiled by using | ||
williamwen42 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# `torch.compiler.disable` | ||
|
||
@torch.compiler.disable | ||
def complex_function(real, imag): | ||
# Assuming this function cause problems in the compilation | ||
return torch.complex(real, imag) | ||
|
||
def outer_function(): | ||
real = torch.tensor([2, 3], dtype=torch.float32) | ||
imag = torch.tensor([4, 5], dtype=torch.float32) | ||
z = complex_function(real, imag) | ||
return torch.abs(z) | ||
|
||
# Try to compile the outer_function | ||
try: | ||
opt_outer_function = torch.compile(outer_function) | ||
print(opt_outer_function()) | ||
except Exception as e: | ||
print("Compilation of outer_function failed:", e) | ||
|
||
###################################################################### | ||
# Best Practices and Recommendations | ||
# ---------------------------------- | ||
# | ||
# Behavior of ``torch.compile`` with Nested Modules and Function Calls | ||
# | ||
# When you use ``torch.compile``, the compiler will try to recursively inline | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't have to mention inlining - going over how dynamo inlines is a little more involved. |
||
# and compile every function call inside the target function or module. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "every function call" is not exactly right - there is a skiplist of various functions (e.g. builtins, some functions in the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks, so I will modify the wording adding your comment to something like "compile every function call inside the target function or module that is not in a skiplist (e.g. builtins, some functions in the torch.* namespace)." |
||
# | ||
# This includes: | ||
# | ||
# - **Nested function calls:** All functions called within the decorated or compiled function will also be compiled. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There isn't really too much of a difference in the nested call behavior between torch.compile'd functions and torch.compile'd modules, so I don't think we need to highlight them as distinct cases. |
||
# | ||
# - **Nested modules:** If a ``torch.nn.Module`` is compiled, all sub-modules and functions within the module are also compiled. | ||
# | ||
# **Best Practices:** | ||
williamwen42 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# | ||
# 1. **Modular Testing:** Test individual functions and modules with ``torch.compile`` | ||
# before integrating them into larger models to isolate potential issues. | ||
# | ||
# 2. **Disable Compilation Selectively:** If certain functions or sub-modules | ||
# cannot be handled by `torch.compile`, use the `torch.compiler.disable` context | ||
# managers to recursively exclude them from compilation. | ||
# | ||
# 3. **Compile Leaf Functions First:** In complex models with multiple nested | ||
# functions and modules, start by compiling the leaf functions or modules first. | ||
# For more information see `TorchDynamo APIs for fine-grained tracing <https://pytorch.org/docs/stable/torch.compiler_fine_grain_apis.html>`__. | ||
|
||
###################################################################### | ||
# Demonstrating Speedups | ||
|
Uh oh!
There was an error while loading. Please reload this page.