File tree Expand file tree Collapse file tree 1 file changed +7
-6
lines changed
Expand file tree Collapse file tree 1 file changed +7
-6
lines changed Original file line number Diff line number Diff line change @@ -33,12 +33,13 @@ def test__init___enable_gpu(self):
3333 model = BasicGANModel (epochs = 10 , enable_gpu = True )
3434
3535 # Assert
36- os_to_device = {
37- 'darwin' : torch .device ('mps' if torch .backends .mps .is_available () else 'cpu' ),
38- 'linux' : torch .device ('cuda' if torch .cuda .is_available () else 'cpu' ),
39- 'win32' : torch .device ('cuda' if torch .cuda .is_available () else 'cpu' ),
40- }
41- expected_device = os_to_device .get (sys .platform , torch .device ('cpu' ))
36+ if sys .platform == 'darwin' and torch .backends .mps .is_available ():
37+ expected_device = torch .device ('mps' )
38+ elif torch .cuda .is_available ():
39+ expected_device = torch .device ('cuda' )
40+ else :
41+ expected_device = torch .device ('cpu' )
42+
4243 assert model ._device == expected_device
4344 assert model ._enable_gpu is True
4445
You can’t perform that action at this time.
0 commit comments