Skip to content

Commit 48b0a2f

Browse files
committed
improve code
1 parent 9bf32e5 commit 48b0a2f

File tree

2 files changed

+6
-9
lines changed

2 files changed

+6
-9
lines changed

deepecho/models/basic_gan.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -186,18 +186,15 @@ def __init__(
186186
if enable_gpu:
187187
if sys.platform == 'darwin': # macOS
188188
if getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available():
189-
device = torch.device('mps')
189+
device = 'mps'
190190
else:
191-
device = torch.device('cpu')
191+
device = 'cpu'
192192
else: # Linux/Windows
193-
if torch.cuda.is_available():
194-
device = torch.device('cuda')
195-
else:
196-
device = torch.device('cpu')
193+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
197194
else:
198-
device = torch.device('cpu')
195+
device = 'cpu'
199196

200-
self._device = device
197+
self._device = torch.device(device)
201198
self._verbose = verbose
202199

203200
LOGGER.info('%s instance created', self)

tests/integration/test_basic_gan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def test_deprecation_warning(self):
3030
def test__init___enable_gpu(self):
3131
"""Test when `enable_gpu` parameter in the constructor."""
3232
# Setup and Run
33-
model = BasicGANModel(epochs=10, enable_gpu=True)
33+
model = BasicGANModel(epochs=10)
3434

3535
# Assert
3636
if (

0 commit comments

Comments
 (0)