 """
 
 import torch
-import profile_utils
 import torch.nn as nn
 from torch.func import functional_call, vmap, vjp, jvp, jacrev
 from torch._dynamo import config
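For readers following this diff out of context, here is a minimal sketch of the setup the later hunks assume. Only `fnet_single`, `params`, `x_train`, `x_test`, and the `functional_call` import appear in the file being patched; the model `net`, its layer sizes, and the input shapes below are hypothetical stand-ins.

```python
import torch
import torch.nn as nn
from torch.func import functional_call

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical model standing in for the tutorial's network.
net = nn.Sequential(
    nn.Conv2d(3, 32, 3), nn.ReLU(),
    nn.Flatten(),
    nn.Linear(32 * 30 * 30, 10),
).to(device)

x_train = torch.randn(20, 3, 32, 32, device=device)
x_test = torch.randn(5, 3, 32, 32, device=device)

# Detached copies of the parameters, passed explicitly to functional_call.
params = {k: v.detach() for k, v in net.named_parameters()}

def fnet_single(params, x):
    # Evaluate the network on a single sample so that per-sample Jacobians
    # can be taken with jacrev/vmap inside the NTK functions in this diff.
    return functional_call(net, params, (x.unsqueeze(0),)).squeeze(0)
```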
@@ -117,7 +116,6 @@ def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2):
 
 result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test)
 print(result.shape)
-profile_utils.compute_speedup(empirical_ntk_jacobian_contraction, (fnet_single, params, x_train, x_test), device)
 
 ######################################################################
 # In some cases, you may only want the diagonal or the trace of this quantity,
@@ -154,7 +152,6 @@ def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute):
 
 result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test, 'trace')
 print(result.shape)
-profile_utils.compute_speedup(empirical_ntk_jacobian_contraction, (fnet_single, params, x_train, x_test, 'trace'), device)
 
 ######################################################################
 # The asymptotic time complexity of this method is :math:`N O [FP]` (time to
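As a rough guide to the `compute` argument exercised in this hunk, under the hypothetical shapes from the sketch above (20 training samples, 5 test samples, 10 model outputs) the options only change how the output-by-output block of the NTK is reduced. `'trace'` and `'full'` appear in this diff; any other option name below is an assumption.

```python
# Shape sketch under the hypothetical setup above.
ntk_trace = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test, 'trace')
print(ntk_trace.shape)  # torch.Size([20, 5]) -- trace over each 10 x 10 output block
ntk_full = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test, 'full')
print(ntk_full.shape)   # torch.Size([20, 5, 10, 10]) -- one O x O block per sample pair
```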
@@ -236,7 +233,6 @@ def get_ntk_slice(vec):
 with torch.backends.cudnn.flags(allow_tf32=False):
     result_from_jacobian_contraction = empirical_ntk_jacobian_contraction(fnet_single, params, x_test, x_train, 'full')
     result_from_ntk_vps = empirical_ntk_ntk_vps(fnet_single, params, x_test, x_train, 'full')
-    profile_utils.compute_speedup(empirical_ntk_ntk_vps, (fnet_single, params, x_train, x_test, 'full'), device)
 
 assert torch.allclose(result_from_jacobian_contraction, result_from_ntk_vps, atol=1e-5)
 
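Since the `profile_utils.compute_speedup` calls are removed by this change, a reader who still wants the timing comparison could measure it directly with `torch.utils.benchmark`. This is a sketch, not a drop-in replacement for the removed helper, whose exact behavior is not shown in this diff; it assumes the functions and inputs defined in the tutorial are in scope.

```python
from torch.utils import benchmark

# Time both NTK implementations on the same inputs; torch.utils.benchmark
# handles warmup and, on CUDA, device synchronization.
common = dict(fnet_single=fnet_single, params=params, x_train=x_train, x_test=x_test)
t_jac = benchmark.Timer(
    stmt="empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test, 'full')",
    globals={"empirical_ntk_jacobian_contraction": empirical_ntk_jacobian_contraction, **common},
)
t_vps = benchmark.Timer(
    stmt="empirical_ntk_ntk_vps(fnet_single, params, x_train, x_test, 'full')",
    globals={"empirical_ntk_ntk_vps": empirical_ntk_ntk_vps, **common},
)
m_jac = t_jac.blocked_autorange()
m_vps = t_vps.blocked_autorange()
print(f"jacobian contraction: {m_jac.median * 1e3:.2f} ms")
print(f"ntk-vps:              {m_vps.median * 1e3:.2f} ms")
print(f"ntk-vps speedup:      {m_jac.median / m_vps.median:.2f}x")
```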