diff --git a/intermediate_source/ensembling.py b/intermediate_source/ensembling.py index 9199daf13a..cb2f42df68 100644 --- a/intermediate_source/ensembling.py +++ b/intermediate_source/ensembling.py @@ -50,7 +50,7 @@ def forward(self, x): # minibatch of size 64. Furthermore, lets say we want to combine the predictions # from 10 different models. -device = 'cuda' +device = torch.accelerator.current_accelerator() num_models = 10 data = torch.randn(100, 64, 1, 28, 28, device=device)