Skip to content

Commit e782774

Browse files
committed
fix tests
1 parent d894122 commit e782774

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

tests/integration/test_basic_gan.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,13 @@ def test__init___enable_gpu(self):
3333
model = BasicGANModel(epochs=10, enable_gpu=True)
3434

3535
# Assert
36-
os_to_device = {
37-
'darwin': torch.device('mps' if torch.backends.mps.is_available() else 'cpu'),
38-
'linux': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
39-
'win32': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
40-
}
41-
expected_device = os_to_device.get(sys.platform, torch.device('cpu'))
36+
if sys.platform == 'darwin' and torch.backends.mps.is_available():
37+
expected_device = torch.device('mps')
38+
elif torch.cuda.is_available():
39+
expected_device = torch.device('cuda')
40+
else:
41+
expected_device = torch.device('cpu')
42+
4243
assert model._device == expected_device
4344
assert model._enable_gpu is True
4445

0 commit comments

Comments
 (0)