Skip to content

Commit da987e5

Browse files
remove compute_speedup calls from tutorials
1 parent c6d308e commit da987e5

File tree

3 files changed

+0
-9
lines changed

3 files changed

+0
-9
lines changed

intermediate_source/ensembling.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import torch.nn.functional as F
2828
from torch._dynamo import config
2929
config.inline_inbuilt_nn_modules = 1
30-
import profile_utils
3130
torch.manual_seed(0)
3231

3332
# Here's a simple MLP
@@ -133,7 +132,6 @@ def compute_predictions1(params, buffers, minibatches):
133132
return vmap(fmodel)(params, buffers, minibatches)
134133

135134
predictions1_vmap = compute_predictions1(params, buffers, minibatches)
136-
profile_utils.compute_speedup(compute_predictions1, (params, buffers, minibatches), device)
137135

138136
# verify the ``vmap`` predictions match the
139137
assert torch.allclose(predictions1_vmap, torch.stack(predictions_diff_minibatch_loop), atol=1e-3, rtol=1e-5)
@@ -150,7 +148,6 @@ def compute_predictions2(params, buffers, minibatch):
150148
return vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)
151149

152150
predictions2_vmap = compute_predictions2(params, buffers, minibatch)
153-
profile_utils.compute_speedup(compute_predictions2, (params, buffers, minibatch), device)
154151

155152
assert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-3, rtol=1e-5)
156153

intermediate_source/neural_tangent_kernels.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
"""
2323

2424
import torch
25-
import profile_utils
2625
import torch.nn as nn
2726
from torch.func import functional_call, vmap, vjp, jvp, jacrev
2827
from torch._dynamo import config
@@ -117,7 +116,6 @@ def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2):
117116

118117
result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test)
119118
print(result.shape)
120-
profile_utils.compute_speedup(empirical_ntk_jacobian_contraction, (fnet_single, params, x_train, x_test), device)
121119

122120
######################################################################
123121
# In some cases, you may only want the diagonal or the trace of this quantity,
@@ -154,7 +152,6 @@ def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute):
154152

155153
result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test, 'trace')
156154
print(result.shape)
157-
profile_utils.compute_speedup(empirical_ntk_jacobian_contraction, (fnet_single, params, x_train, x_test, 'trace'), device)
158155

159156
######################################################################
160157
# The asymptotic time complexity of this method is :math:`N O [FP]` (time to
@@ -236,7 +233,6 @@ def get_ntk_slice(vec):
236233
with torch.backends.cudnn.flags(allow_tf32=False):
237234
result_from_jacobian_contraction = empirical_ntk_jacobian_contraction(fnet_single, params, x_test, x_train, 'full')
238235
result_from_ntk_vps = empirical_ntk_ntk_vps(fnet_single, params, x_test, x_train, 'full')
239-
profile_utils.compute_speedup(empirical_ntk_ntk_vps, (fnet_single, params, x_train, x_test, 'full'), device)
240236

241237
assert torch.allclose(result_from_jacobian_contraction, result_from_ntk_vps, atol=1e-5)
242238

intermediate_source/per_sample_grads.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
1717
"""
1818

19-
import profile_utils
2019
import torch
2120
import torch.nn as nn
2221
import torch.nn.functional as F
@@ -171,7 +170,6 @@ def vmap_ft_compute_grad(params, buffers, data, targets):
171170
return ft_compute_sample_grad_(params, buffers, data, targets)
172171

173172
ft_per_sample_grads = vmap_ft_compute_grad(params, buffers, data, targets)
174-
profile_utils.compute_speedup(vmap_ft_compute_grad, (params, buffers, data, targets), device)
175173

176174
######################################################################
177175
# we can double check that the results using ``grad`` and ``vmap`` match the

0 commit comments

Comments (0)