-
Notifications
You must be signed in to change notification settings - Fork 4.2k
@torch.compile
some tutorials
#2984
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
e577bac
d8e6e12
3e557e7
c6d308e
da987e5
344861d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,6 +25,8 @@ | |
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from torch._dynamo import config | ||
config.inline_inbuilt_nn_modules = 1 | ||
torch.manual_seed(0) | ||
|
||
# Here's a simple MLP | ||
|
@@ -50,7 +52,7 @@ def forward(self, x): | |
# minibatch of size 64. Furthermore, let's say we want to combine the predictions | ||
# from 10 different models. | ||
|
||
device = 'cuda' | ||
device = 'cuda' if torch.cuda.device_count() > 0 else 'cpu' | ||
num_models = 10 | ||
|
||
data = torch.randn(100, 64, 1, 28, 28, device=device) | ||
|
@@ -125,7 +127,11 @@ def fmodel(params, buffers, x): | |
|
||
from torch import vmap | ||
|
||
predictions1_vmap = vmap(fmodel)(params, buffers, minibatches) | ||
@torch.compile | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Main comment is: all of these tutorials should have a separate section at the end that says "let's try to use torch.compile, and here are the speedups". |
||
def compute_predictions1(params, buffers, minibatches): | ||
return vmap(fmodel)(params, buffers, minibatches) | ||
|
||
predictions1_vmap = compute_predictions1(params, buffers, minibatches) | ||
|
||
# verify the ``vmap`` predictions match the | ||
assert torch.allclose(predictions1_vmap, torch.stack(predictions_diff_minibatch_loop), atol=1e-3, rtol=1e-5) | ||
|
@@ -137,7 +143,11 @@ def fmodel(params, buffers, x): | |
# By using ``None``, we tell ``vmap`` we want the same minibatch to apply for all of | ||
# the 10 models. | ||
|
||
predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch) | ||
@torch.compile | ||
def compute_predictions2(params, buffers, minibatch): | ||
return vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch) | ||
|
||
predictions2_vmap = compute_predictions2(params, buffers, minibatch) | ||
|
||
assert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-3, rtol=1e-5) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,8 @@ | |
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from torch._dynamo import config | ||
config.inline_inbuilt_nn_modules = 1 | ||
torch.manual_seed(0) | ||
|
||
# Here's a simple CNN and loss function: | ||
|
@@ -52,7 +54,7 @@ def loss_fn(predictions, targets): | |
# Let’s generate a batch of dummy data and pretend that we’re working with an MNIST dataset. | ||
# The dummy images are 28 by 28 and we use a minibatch of size 64. | ||
|
||
device = 'cuda' | ||
device = 'cuda' if torch.cuda.device_count() > 0 else 'cpu' | ||
|
||
num_models = 10 | ||
batch_size = 64 | ||
|
@@ -162,7 +164,12 @@ def compute_loss(params, buffers, sample, target): | |
###################################################################### | ||
# Finally, let's use our transformed function to compute per-sample-gradients: | ||
|
||
ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data, targets) | ||
@torch.compile | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Move this function to below the text "Finally, let's use..."? |
||
def vmap_ft_compute_grad(params, buffers, data, targets): | ||
ft_compute_sample_grad_ = vmap(ft_compute_grad, in_dims=(None, None, 0, 0)) | ||
return ft_compute_sample_grad_(params, buffers, data, targets) | ||
|
||
ft_per_sample_grads = vmap_ft_compute_grad(params, buffers, data, targets) | ||
|
||
###################################################################### | ||
# we can double check that the results using ``grad`` and ``vmap`` match the | ||
|
@@ -194,7 +201,7 @@ def get_perf(first, first_descriptor, second, second_descriptor): | |
first_res = first.times[0] | ||
|
||
gain = (first_res-second_res)/first_res | ||
if gain < 0: gain *=-1 | ||
if gain < 0: gain *=-1 | ||
final_gain = gain*100 | ||
|
||
print(f"Performance delta: {final_gain:.4f} percent improvement with {first_descriptor} ") | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove, also, don't merge this PR until the next release (2.5)