
@torch.compile some tutorials #2984

Closed · wants to merge 6 commits
16 changes: 13 additions & 3 deletions intermediate_source/ensembling.py
@@ -25,6 +25,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
+from torch._dynamo import config
+config.inline_inbuilt_nn_modules = 1
Comment on lines +28 to +29 (Contributor): Remove, also, don't merge this PR until the next release (2.5)
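For context, a minimal sketch of what this setting toggles (an assumption on our part: the flag exists from PyTorch 2.4 and is presumably on by default in 2.5, which would explain the request to remove it):

    from torch._dynamo import config
    # Roughly: when enabled, Dynamo traces into the Python source of built-in
    # nn.Module forwards instead of special-casing them, which composes better
    # with torch.func transforms such as vmap.
    config.inline_inbuilt_nn_modules = True  # the diff's `= 1` is a truthy equivalent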

torch.manual_seed(0)

# Here's a simple MLP
@@ -50,7 +52,7 @@ def forward(self, x):
# minibatch of size 64. Furthermore, let's say we want to combine the predictions
# from 10 different models.

-device = 'cuda'
+device = 'cuda' if torch.cuda.device_count() > 0 else 'cpu'
num_models = 10

data = torch.randn(100, 64, 1, 28, 28, device=device)
@@ -125,7 +127,11 @@ def fmodel(params, buffers, x):

from torch import vmap

-predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)
+@torch.compile
Comment (Contributor): Main comment is: all of these tutorials should have a separate section at the end that says "let's try to use torch.compile, and here are the speedups".

+def compute_predictions1(params, buffers, minibatches):
+    return vmap(fmodel)(params, buffers, minibatches)
+
+predictions1_vmap = compute_predictions1(params, buffers, minibatches)

# verify the ``vmap`` predictions match the
assert torch.allclose(predictions1_vmap, torch.stack(predictions_diff_minibatch_loop), atol=1e-3, rtol=1e-5)
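A sketch of the kind of speedup section the reviewer asks for, using ``torch.utils.benchmark`` (hypothetical placement and names; timings depend on hardware and PyTorch version):

    from torch.utils.benchmark import Timer

    compute_predictions1(params, buffers, minibatches)  # warm-up: triggers compilation

    eager_timer = Timer(
        stmt="vmap(fmodel)(params, buffers, minibatches)",
        globals={"vmap": vmap, "fmodel": fmodel, "params": params,
                 "buffers": buffers, "minibatches": minibatches},
    )
    compiled_timer = Timer(
        stmt="compute_predictions1(params, buffers, minibatches)",
        globals={"compute_predictions1": compute_predictions1, "params": params,
                 "buffers": buffers, "minibatches": minibatches},
    )
    print("eager:   ", eager_timer.timeit(100))
    print("compiled:", compiled_timer.timeit(100))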
@@ -137,7 +143,11 @@ def fmodel(params, buffers, x):
# By using ``None``, we tell ``vmap`` we want the same minibatch to apply for all of
# the 10 models.

-predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)
+@torch.compile
+def compute_predictions2(params, buffers, minibatch):
+    return vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)
+
+predictions2_vmap = compute_predictions2(params, buffers, minibatch)

assert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-3, rtol=1e-5)
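The ``in_dims=(0, 0, None)`` semantics can be illustrated with a self-contained toy example (not part of the tutorial; arbitrary shapes):

    import torch
    from torch import vmap

    def scale(w, x):
        return w * x

    ws = torch.arange(3.0)  # one scalar weight per "model"
    x = torch.ones(5)       # a single shared input
    # in_dims=(0, None): map over dim 0 of ws, reuse x unchanged for every call
    out = vmap(scale, in_dims=(0, None))(ws, x)
    print(out.shape)  # torch.Size([3, 5])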

13 changes: 9 additions & 4 deletions intermediate_source/neural_tangent_kernels.py
@@ -24,6 +24,8 @@
import torch
import torch.nn as nn
from torch.func import functional_call, vmap, vjp, jvp, jacrev
+from torch._dynamo import config
+config.inline_inbuilt_nn_modules = 1
device = 'cuda' if torch.cuda.device_count() > 0 else 'cpu'

class CNN(nn.Module):
@@ -95,6 +97,7 @@ def fnet_single(params, x):
# The first method consists of doing just that - computing the two Jacobians,
# and contracting them. Here's how to compute the NTK in the batched case:

+@torch.compile
def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2):
    # Compute J(x1)
    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
@@ -120,7 +123,8 @@ def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2):
# NTK where the non-diagonal elements can be approximated by zero. It's easy to
# adjust the above function to do that:

-def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute='full'):
+@torch.compile
+def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute):
    # Compute J(x1)
    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
    jac1 = jac1.values()
@@ -189,7 +193,8 @@ def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute='full'):
#
# Let's code that up:

-def empirical_ntk_ntk_vps(func, params, x1, x2, compute='full'):
+@torch.compile
+def empirical_ntk_ntk_vps(func, params, x1, x2, compute):
    def get_ntk(x1, x2):
        def func_x1(params):
            return func(params, x1)
@@ -226,8 +231,8 @@ def get_ntk_slice(vec):

# Disable TensorFloat-32 for convolutions on Ampere+ GPUs to sacrifice performance in favor of accuracy
with torch.backends.cudnn.flags(allow_tf32=False):
-    result_from_jacobian_contraction = empirical_ntk_jacobian_contraction(fnet_single, params, x_test, x_train)
-    result_from_ntk_vps = empirical_ntk_ntk_vps(fnet_single, params, x_test, x_train)
+    result_from_jacobian_contraction = empirical_ntk_jacobian_contraction(fnet_single, params, x_test, x_train, 'full')
+    result_from_ntk_vps = empirical_ntk_ntk_vps(fnet_single, params, x_test, x_train, 'full')

assert torch.allclose(result_from_jacobian_contraction, result_from_ntk_vps, atol=1e-5)
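For reference, the same TF32 behavior can also be controlled globally rather than through the context manager (a sketch; defaults vary across PyTorch versions, and the setting only matters on Ampere or newer GPUs):

    import torch

    # Run matmuls and cuDNN convolutions in full float32 precision
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False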

13 changes: 10 additions & 3 deletions intermediate_source/per_sample_grads.py
@@ -19,6 +19,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
+from torch._dynamo import config
+config.inline_inbuilt_nn_modules = 1
torch.manual_seed(0)

# Here's a simple CNN and loss function:
@@ -52,7 +54,7 @@ def loss_fn(predictions, targets):
# Let’s generate a batch of dummy data and pretend that we’re working with an MNIST dataset.
# The dummy images are 28 by 28 and we use a minibatch of size 64.

-device = 'cuda'
+device = 'cuda' if torch.cuda.device_count() > 0 else 'cpu'

num_models = 10
batch_size = 64
@@ -162,7 +164,12 @@ def compute_loss(params, buffers, sample, target):
######################################################################
# Finally, let's use our transformed function to compute per-sample-gradients:

-ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data, targets)
+@torch.compile
Comment (Member): Move this function to below the text "Finally, let's use..."?

+def vmap_ft_compute_grad(params, buffers, data, targets):
+    ft_compute_sample_grad_ = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))
+    return ft_compute_sample_grad_(params, buffers, data, targets)
+
+ft_per_sample_grads = vmap_ft_compute_grad(params, buffers, data, targets)

######################################################################
# we can double check that the results using ``grad`` and ``vmap`` match the
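The folded check presumably resembles the following sketch (an assumption: ``per_sample_grads`` stands for the results of the tutorial's earlier loop-based computation, which this diff does not show):

    for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads,
                                                   ft_per_sample_grads.values()):
        assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)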
Expand Down Expand Up @@ -194,7 +201,7 @@ def get_perf(first, first_descriptor, second, second_descriptor):
first_res = first.times[0]

gain = (first_res-second_res)/first_res
if gain < 0: gain *=-1
if gain < 0: gain *=-1
final_gain = gain*100

print(f"Performance delta: {final_gain:.4f} percent improvement with {first_descriptor} ")