File tree Expand file tree Collapse file tree 2 files changed +6
-9
lines changed
Expand file tree Collapse file tree 2 files changed +6
-9
lines changed Original file line number Diff line number Diff line change @@ -186,18 +186,15 @@ def __init__(
186186 if enable_gpu :
187187 if sys .platform == 'darwin' : # macOS
188188 if getattr (torch .backends , 'mps' , None ) and torch .backends .mps .is_available ():
189- device = torch . device ( 'mps' )
189+ device = 'mps'
190190 else :
191- device = torch . device ( 'cpu' )
191+ device = 'cpu'
192192 else : # Linux/Windows
193- if torch .cuda .is_available ():
194- device = torch .device ('cuda' )
195- else :
196- device = torch .device ('cpu' )
193+ device = 'cuda' if torch .cuda .is_available () else 'cpu'
197194 else :
198- device = torch . device ( 'cpu' )
195+ device = 'cpu'
199196
200- self ._device = device
197+ self ._device = torch . device ( device )
201198 self ._verbose = verbose
202199
203200 LOGGER .info ('%s instance created' , self )
Original file line number Diff line number Diff line change @@ -30,7 +30,7 @@ def test_deprecation_warning(self):
3030 def test__init___enable_gpu (self ):
3131 """Test when `enable_gpu` parameter in the constructor."""
3232 # Setup and Run
33- model = BasicGANModel (epochs = 10 , enable_gpu = True )
33+ model = BasicGANModel (epochs = 10 )
3434
3535 # Assert
3636 if (
You can’t perform that action at this time.
0 commit comments