Skip to content

Commit d4d86b3

Browse files
committed
[4/N] Refine beginner tutorial by accelerator api
1 parent 540bd0c commit d4d86b3

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

beginner_source/transfer_learning_tutorial.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,11 @@
9898
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
9999
class_names = image_datasets['train'].classes
100100

101-
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
101+
# We want to be able to train our model on an `accelerator <https://pytorch.org/docs/stable/torch.html#accelerators>`__
102+
# such as CUDA, MPS, MTIA, or XPU. If the current accelerator is available, we will use it. Otherwise, we use the CPU.
103+
104+
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
105+
print(f"Using {device} device")
102106

103107
######################################################################
104108
# Visualize a few images

0 commit comments

Comments
 (0)