Skip to content

Commit 6baaf46

Browse files
authored
1 parent c391f15 commit 6baaf46

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

intermediate_source/per_sample_grads.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def loss_fn(predictions, targets):
7979
# pass to get an individual (per-sample) gradient.
8080

8181
def compute_grad(sample, target):
82+
model.zero_grad()
8283
sample = sample.unsqueeze(0) # prepend batch dimension for processing
8384
target = target.unsqueeze(0)
8485

0 commit comments

Comments
 (0)