Skip to content

Commit 753be57

Browse files
@torch.compile some tutorials
1 parent a0a9e3b commit 753be57

File tree

4 files changed

+96
-10
lines changed

4 files changed

+96
-10
lines changed

intermediate_source/ensembling.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@
2525
import torch
2626
import torch.nn as nn
2727
import torch.nn.functional as F
28+
from torch._dynamo import config
29+
config.inline_inbuilt_nn_modules = 1
30+
import profile_utils
2831
torch.manual_seed(0)
2932

3033
# Here's a simple MLP
@@ -50,7 +53,7 @@ def forward(self, x):
5053
# minibatch of size 64. Furthermore, let's say we want to combine the predictions
5154
# from 10 different models.
5255

53-
device = 'cuda'
56+
device = 'cuda' if torch.cuda.device_count() > 0 else 'cpu'
5457
num_models = 10
5558

5659
data = torch.randn(100, 64, 1, 28, 28, device=device)
@@ -125,7 +128,12 @@ def fmodel(params, buffers, x):
125128

126129
from torch import vmap
127130

128-
predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)
131+
@torch.compile
def compute_predictions1(params, buffers, minibatches):
    """Evaluate ``fmodel`` under ``vmap`` over all three arguments.

    ``vmap`` is used with its default ``in_dims``, so ``params``, ``buffers``
    and ``minibatches`` are each mapped along their leading dimension.
    """
    batched_fmodel = vmap(fmodel)
    return batched_fmodel(params, buffers, minibatches)
134+
135+
predictions1_vmap = compute_predictions1(params, buffers, minibatches)
136+
profile_utils.compute_speedup(compute_predictions1, (params, buffers, minibatches), device)
129137

130138
# verify the ``vmap`` predictions match the
131139
assert torch.allclose(predictions1_vmap, torch.stack(predictions_diff_minibatch_loop), atol=1e-3, rtol=1e-5)
@@ -137,7 +145,12 @@ def fmodel(params, buffers, x):
137145
# By using ``None``, we tell ``vmap`` we want the same minibatch to apply for all of
138146
# the 10 models.
139147

140-
predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)
148+
@torch.compile
def compute_predictions2(params, buffers, minibatch):
    """Evaluate ``fmodel`` under ``vmap`` with one shared minibatch.

    ``params`` and ``buffers`` are mapped along dimension 0, while the
    ``None`` entry in ``in_dims`` passes ``minibatch`` unchanged to every
    mapped call.
    """
    shared_input_fmodel = vmap(fmodel, in_dims=(0, 0, None))
    return shared_input_fmodel(params, buffers, minibatch)
151+
152+
predictions2_vmap = compute_predictions2(params, buffers, minibatch)
153+
profile_utils.compute_speedup(compute_predictions2, (params, buffers, minibatch), device)
141154

142155
assert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-3, rtol=1e-5)
143156

intermediate_source/neural_tangent_kernels.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,11 @@
2222
"""
2323

2424
import torch
25+
import profile_utils
2526
import torch.nn as nn
2627
from torch.func import functional_call, vmap, vjp, jvp, jacrev
28+
from torch._dynamo import config
29+
config.inline_inbuilt_nn_modules = 1
2730
device = 'cuda' if torch.cuda.device_count() > 0 else 'cpu'
2831

2932
class CNN(nn.Module):
@@ -95,6 +98,7 @@ def fnet_single(params, x):
9598
# The first method consists of doing just that - computing the two Jacobians,
9699
# and contracting them. Here's how to compute the NTK in the batched case:
97100

101+
@torch.compile
98102
def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2):
99103
# Compute J(x1)
100104
jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
@@ -113,14 +117,16 @@ def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2):
113117

114118
result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test)
115119
print(result.shape)
120+
profile_utils.compute_speedup(empirical_ntk_jacobian_contraction, (fnet_single, params, x_train, x_test), device)
116121

117122
######################################################################
118123
# In some cases, you may only want the diagonal or the trace of this quantity,
119124
# especially if you know beforehand that the network architecture results in an
120125
# NTK where the non-diagonal elements can be approximated by zero. It's easy to
121126
# adjust the above function to do that:
122127

123-
def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute='full'):
128+
@torch.compile
129+
def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute):
124130
# Compute J(x1)
125131
jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
126132
jac1 = jac1.values()
@@ -148,6 +154,7 @@ def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute='ful
148154

149155
result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test, 'trace')
150156
print(result.shape)
157+
profile_utils.compute_speedup(empirical_ntk_jacobian_contraction, (fnet_single, params, x_train, x_test, 'trace'), device)
151158

152159
######################################################################
153160
# The asymptotic time complexity of this method is :math:`N O [FP]` (time to
@@ -189,7 +196,8 @@ def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute='ful
189196
#
190197
# Let's code that up:
191198

192-
def empirical_ntk_ntk_vps(func, params, x1, x2, compute='full'):
199+
@torch.compile
200+
def empirical_ntk_ntk_vps(func, params, x1, x2, compute):
193201
def get_ntk(x1, x2):
194202
def func_x1(params):
195203
return func(params, x1)
@@ -226,8 +234,9 @@ def get_ntk_slice(vec):
226234

227235
# Disable TensorFloat-32 for convolutions on Ampere+ GPUs to sacrifice performance in favor of accuracy
228236
with torch.backends.cudnn.flags(allow_tf32=False):
229-
result_from_jacobian_contraction = empirical_ntk_jacobian_contraction(fnet_single, params, x_test, x_train)
230-
result_from_ntk_vps = empirical_ntk_ntk_vps(fnet_single, params, x_test, x_train)
237+
result_from_jacobian_contraction = empirical_ntk_jacobian_contraction(fnet_single, params, x_test, x_train, 'full')
238+
result_from_ntk_vps = empirical_ntk_ntk_vps(fnet_single, params, x_test, x_train, 'full')
239+
profile_utils.compute_speedup(empirical_ntk_ntk_vps, (fnet_single, params, x_train, x_test, 'full'), device)
231240

232241
assert torch.allclose(result_from_jacobian_contraction, result_from_ntk_vps, atol=1e-5)
233242

intermediate_source/per_sample_grads.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,12 @@
1616
1717
"""
1818

19+
import profile_utils
1920
import torch
2021
import torch.nn as nn
2122
import torch.nn.functional as F
23+
from torch._dynamo import config
24+
config.inline_inbuilt_nn_modules = 1
2225
torch.manual_seed(0)
2326

2427
# Here's a simple CNN and loss function:
@@ -52,7 +55,7 @@ def loss_fn(predictions, targets):
5255
# Let’s generate a batch of dummy data and pretend that we’re working with an MNIST dataset.
5356
# The dummy images are 28 by 28 and we use a minibatch of size 64.
5457

55-
device = 'cuda'
58+
device = 'cuda' if torch.cuda.device_count() > 0 else 'cpu'
5659

5760
num_models = 10
5861
batch_size = 64
@@ -159,10 +162,16 @@ def compute_loss(params, buffers, sample, target):
159162

160163
ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))
161164

165+
@torch.compile
def vmap_ft_compute_grad(params, buffers, data, targets):
    """Apply ``ft_compute_grad`` under ``vmap`` to a batch of samples.

    ``params`` and ``buffers`` are shared by every mapped call
    (``in_dims`` entry ``None``); ``data`` and ``targets`` are mapped along
    dimension 0.
    """
    # The vmap transform is built inside the compiled function so that it is
    # part of what ``torch.compile`` traces.
    return vmap(ft_compute_grad, in_dims=(None, None, 0, 0))(
        params, buffers, data, targets
    )
169+
162170
######################################################################
163171
# Finally, let's use our transformed function to compute per-sample-gradients:
164172

165-
ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data, targets)
173+
ft_per_sample_grads = vmap_ft_compute_grad(params, buffers, data, targets)
174+
profile_utils.compute_speedup(vmap_ft_compute_grad, (params, buffers, data, targets), device)
166175

167176
######################################################################
168177
# we can double check that the results using ``grad`` and ``vmap`` match the
@@ -194,7 +203,7 @@ def get_perf(first, first_descriptor, second, second_descriptor):
194203
first_res = first.times[0]
195204

196205
gain = (first_res-second_res)/first_res
197-
if gain < 0: gain *=-1
206+
if gain < 0: gain *=-1
198207
final_gain = gain*100
199208

200209
print(f"Performance delta: {final_gain:.4f} percent improvement with {first_descriptor} ")

intermediate_source/profile_utils.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import torch
2+
from torch.fx.experimental.proxy_tensor import make_fx
3+
from torch.utils.benchmark import Timer, Compare
4+
5+
def profile(fn, inputs):
    """Run ``fn(*inputs)`` once under ``torch.profiler`` and print a summary.

    CPU activity is always recorded. CUDA activity is recorded only when
    CUDA is available, and the printed table is sorted by the self time of
    whichever device is being profiled, so the helper is also useful on
    CPU-only machines.

    Args:
        fn: Callable to profile.
        inputs: Sequence of positional arguments, unpacked as ``fn(*inputs)``.

    Returns:
        None. The profiler's key-averages table is printed to stdout.
    """
    cuda_available = torch.cuda.is_available()

    activities = [torch.profiler.ProfilerActivity.CPU]
    if cuda_available:
        activities.append(torch.profiler.ProfilerActivity.CUDA)

    with torch.profiler.profile(activities=activities, with_stack=True) as prof:
        fn(*inputs)

    # Sorting by CUDA time is only meaningful when CUDA events were recorded.
    sort_key = "self_cuda_time_total" if cuda_available else "self_cpu_time_total"
    print(prof.key_averages().table(sort_by=sort_key))
15+
16+
def compute_speedup(fn, inputs, device, times=100):
17+
lst = []
18+
19+
fn = fn._torchdynamo_orig_callable
20+
fn_opt = torch.compile(fullgraph=True)(fn)
21+
fx_g = make_fx(fn)
22+
23+
for nt in (1, 2, 4, 8, 16):
24+
opt = Timer(
25+
setup='fn_opt(*inputs)',
26+
stmt='fn_opt(*inputs)',
27+
globals={'fn_opt': fn_opt, 'inputs': inputs},
28+
label=fn.__name__,
29+
sub_label='@torch.compile',
30+
description=device,
31+
num_threads=nt,
32+
).timeit(times)
33+
34+
fx = Timer(
35+
setup='fx_g(*inputs)',
36+
stmt='fx_g(*inputs)',
37+
globals={'fx_g': fx_g, 'inputs': inputs},
38+
label=fn.__name__,
39+
sub_label='make_fx',
40+
description=device,
41+
num_threads=nt,
42+
).timeit(times)
43+
44+
eager = Timer(
45+
setup='fn(*inputs)',
46+
stmt='fn(*inputs)',
47+
globals={'fn': fn, 'inputs': inputs},
48+
label=fn.__name__,
49+
sub_label='eager',
50+
description=device,
51+
num_threads=nt,
52+
).timeit(times)
53+
lst.extend([opt, fx, eager])
54+
55+
Compare(lst).print()

0 commit comments

Comments
 (0)