diff --git a/intermediate_source/ensembling.py b/intermediate_source/ensembling.py
index 9199daf13a3..fb6edf02284 100644
--- a/intermediate_source/ensembling.py
+++ b/intermediate_source/ensembling.py
@@ -25,6 +25,8 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch._dynamo import config
+config.inline_inbuilt_nn_modules = 1
 torch.manual_seed(0)
 
 # Here's a simple MLP
@@ -50,7 +52,7 @@ def forward(self, x):
 # minibatch of size 64. Furthermore, lets say we want to combine the predictions
 # from 10 different models.
 
-device = 'cuda'
+device = 'cuda' if torch.cuda.device_count() > 0 else 'cpu'
 num_models = 10
 
 data = torch.randn(100, 64, 1, 28, 28, device=device)
@@ -125,7 +127,11 @@ def fmodel(params, buffers, x):
 
 from torch import vmap
 
-predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)
+@torch.compile
+def compute_predictions1(params, buffers, minibatches):
+    return vmap(fmodel)(params, buffers, minibatches)
+
+predictions1_vmap = compute_predictions1(params, buffers, minibatches)
 
 # verify the ``vmap`` predictions match the
 assert torch.allclose(predictions1_vmap, torch.stack(predictions_diff_minibatch_loop), atol=1e-3, rtol=1e-5)
@@ -137,7 +143,11 @@ def fmodel(params, buffers, x):
 # By using ``None``, we tell ``vmap`` we want the same minibatch to apply for all of
 # the 10 models.
 
-predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)
+@torch.compile
+def compute_predictions2(params, buffers, minibatch):
+    return vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)
+
+predictions2_vmap = compute_predictions2(params, buffers, minibatch)
 
 assert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-3, rtol=1e-5)
 
diff --git a/intermediate_source/neural_tangent_kernels.py b/intermediate_source/neural_tangent_kernels.py
index 62a49794af5..56f169946fe 100644
--- a/intermediate_source/neural_tangent_kernels.py
+++ b/intermediate_source/neural_tangent_kernels.py
@@ -24,6 +24,8 @@
 import torch
 import torch.nn as nn
 from torch.func import functional_call, vmap, vjp, jvp, jacrev
+from torch._dynamo import config
+config.inline_inbuilt_nn_modules = 1
 device = 'cuda' if torch.cuda.device_count() > 0 else 'cpu'
 
 class CNN(nn.Module):
@@ -95,6 +97,7 @@ def fnet_single(params, x):
 # The first method consists of doing just that - computing the two Jacobians,
 # and contracting them. Here's how to compute the NTK in the batched case:
 
+@torch.compile
 def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2):
     # Compute J(x1)
     jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
@@ -120,7 +123,8 @@ def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2):
 # NTK where the non-diagonal elements can be approximated by zero. It's easy to
 # adjust the above function to do that:
 
-def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute='full'):
+@torch.compile
+def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute):
     # Compute J(x1)
     jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
     jac1 = jac1.values()
@@ -189,7 +193,8 @@ def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute='ful
 #
 # Let's code that up:
 
-def empirical_ntk_ntk_vps(func, params, x1, x2, compute='full'):
+@torch.compile
+def empirical_ntk_ntk_vps(func, params, x1, x2, compute):
     def get_ntk(x1, x2):
         def func_x1(params):
             return func(params, x1)
@@ -226,8 +231,8 @@ def get_ntk_slice(vec):
 
 # Disable TensorFloat-32 for convolutions on Ampere+ GPUs to sacrifice performance in favor of accuracy
 with torch.backends.cudnn.flags(allow_tf32=False):
-    result_from_jacobian_contraction = empirical_ntk_jacobian_contraction(fnet_single, params, x_test, x_train)
-    result_from_ntk_vps = empirical_ntk_ntk_vps(fnet_single, params, x_test, x_train)
+    result_from_jacobian_contraction = empirical_ntk_jacobian_contraction(fnet_single, params, x_test, x_train, 'full')
+    result_from_ntk_vps = empirical_ntk_ntk_vps(fnet_single, params, x_test, x_train, 'full')
 
 assert torch.allclose(result_from_jacobian_contraction, result_from_ntk_vps, atol=1e-5)
 
diff --git a/intermediate_source/per_sample_grads.py b/intermediate_source/per_sample_grads.py
index ece80d3f94f..01f9931cb28 100644
--- a/intermediate_source/per_sample_grads.py
+++ b/intermediate_source/per_sample_grads.py
@@ -19,6 +19,8 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch._dynamo import config
+config.inline_inbuilt_nn_modules = 1
 torch.manual_seed(0)
 
 # Here's a simple CNN and loss function:
@@ -52,7 +54,7 @@ def loss_fn(predictions, targets):
 # Let’s generate a batch of dummy data and pretend that we’re working with an MNIST dataset.
 # The dummy images are 28 by 28 and we use a minibatch of size 64.
 
-device = 'cuda'
+device = 'cuda' if torch.cuda.device_count() > 0 else 'cpu'
 
 num_models = 10
 batch_size = 64
@@ -162,7 +164,12 @@ def compute_loss(params, buffers, sample, target):
 ######################################################################
 # Finally, let's used our transformed function to compute per-sample-gradients:
 
-ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data, targets)
+@torch.compile
+def vmap_ft_compute_grad(params, buffers, data, targets):
+    ft_compute_sample_grad_ = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))
+    return ft_compute_sample_grad_(params, buffers, data, targets)
+
+ft_per_sample_grads = vmap_ft_compute_grad(params, buffers, data, targets)
 
 ######################################################################
 # we can double check that the results using ``grad`` and ``vmap`` match the
@@ -194,7 +201,7 @@ def get_perf(first, first_descriptor, second, second_descriptor):
     first_res = first.times[0]
 
     gain = (first_res-second_res)/first_res
-    if gain < 0: gain *=-1 
+    if gain < 0: gain *=-1
     final_gain = gain*100
     print(f"Performance delta: {final_gain:.4f} percent improvement with {first_descriptor} ")
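All three patches apply the same pattern: enable Dynamo's inline_inbuilt_nn_modules flag, then wrap the vmap-transformed call in torch.compile rather than calling it eagerly. Below is a minimal, self-contained sketch of that pattern for reference; the TinyMLP module, tensor shapes, and variable names are illustrative stand-ins and do not come from the tutorials being patched.

# Minimal sketch (assumed names): compile a vmap-ed ensemble forward pass,
# with inline_inbuilt_nn_modules enabled as in the patches above.
import torch
import torch.nn as nn
from torch.func import functional_call, stack_module_state
from torch._dynamo import config

config.inline_inbuilt_nn_modules = 1

class TinyMLP(nn.Module):  # illustrative model, not from the tutorials
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc(x)

models = [TinyMLP() for _ in range(4)]
# Stack the per-model parameters/buffers along a new leading dimension
params, buffers = stack_module_state(models)
# A stateless "base" copy on the meta device, as in ensembling.py
base = TinyMLP().to('meta')

def fmodel(params, buffers, x):
    return functional_call(base, (params, buffers), (x,))

@torch.compile
def compute_predictions(params, buffers, x):
    # vmap over the stacked model dimension; in_dims=None feeds the same
    # minibatch to every model, mirroring predictions2 in ensembling.py
    return torch.vmap(fmodel, in_dims=(0, 0, None))(params, buffers, x)

out = compute_predictions(params, buffers, torch.randn(16, 8))
print(out.shape)  # torch.Size([4, 16, 2])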