Skip to content

Commit 785e38c

Browse files
authored
Update per_sample_grads.py
1 parent de2609d commit 785e38c

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

intermediate_source/per_sample_grads.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def loss_fn(predictions, targets):
5252
# Let’s generate a batch of dummy data and pretend that we’re working with an MNIST dataset.
5353
# The dummy images are 28 by 28 and we use a minibatch of size 64.
5454

55-
device = 'cpu'
55+
device = 'cuda'
5656

5757
num_models = 10
5858
batch_size = 64

0 commit comments

Comments
 (0)