Skip to content

Commit 4c5e7d8

Browse files
authored
Fix model cuda (#2484)
1 parent df07f8e commit 4c5e7d8

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

model_zoo/mnist/mnist_pytorch.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,8 @@ def train(args):
162162
epoch = 0
163163
# Use the elastic function to wrap the training function with a batch.
164164
elastic_train_one_batch = allreduce_controller.elastic_run(train_one_batch)
165+
if torch.cuda.is_available():
166+
model.cuda()
165167
with allreduce_controller.scope():
166168
for batch_idx, (data, target) in enumerate(train_loader):
167169
model.train()

0 commit comments

Comments
 (0)