Skip to content

Commit 3e557e7

Browse files
address reviewer comment
1 parent d8e6e12 commit 3e557e7

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

intermediate_source/per_sample_grads.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,14 +162,14 @@ def compute_loss(params, buffers, sample, target):
162162

163163
ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))
164164

165+
######################################################################
166+
# Finally, let's used our transformed function to compute per-sample-gradients:
167+
165168
@torch.compile
166169
def vmap_ft_compute_grad(params, buffers, data, targets):
167170
ft_compute_sample_grad_ = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))
168171
return ft_compute_sample_grad_(params, buffers, data, targets)
169172

170-
######################################################################
171-
# Finally, let's used our transformed function to compute per-sample-gradients:
172-
173173
ft_per_sample_grads = vmap_ft_compute_grad(params, buffers, data, targets)
174174
profile_utils.compute_speedup(vmap_ft_compute_grad, (params, buffers, data, targets), device)
175175

0 commit comments

Comments
 (0)