Skip to content

Commit d894122

Browse files
committed
tests
1 parent 3680417 commit d894122

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

tests/integration/test_basic_gan.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,47 @@
11
"""Integration tests for ``BasicGANModel``."""
22

3+
import re
4+
import sys
35
import unittest
46

7+
import pytest
8+
import torch
9+
510
from deepecho.models.basic_gan import BasicGANModel
611

712

813
class TestBasicGANModel(unittest.TestCase):
914
"""Test class for the ``BasicGANModel``."""
1015

16+
def test_deprecation_warning(self):
17+
"""Test that using the deprecated `cuda` parameter raises a warning."""
18+
# Setup
19+
expected_message = re.escape(
20+
'`cuda` parameter is deprecated and will be removed in a future release. '
21+
'Please use `enable_gpu` instead.'
22+
)
23+
24+
# Run and Assert
25+
with pytest.warns(FutureWarning, match=expected_message):
26+
model = BasicGANModel(epochs=10, cuda=False)
27+
28+
assert model._enable_gpu is False
29+
30+
def test__init___enable_gpu(self):
31+
"""Test when `enable_gpu` parameter in the constructor."""
32+
# Setup and Run
33+
model = BasicGANModel(epochs=10, enable_gpu=True)
34+
35+
# Assert
36+
os_to_device = {
37+
'darwin': torch.device('mps' if torch.backends.mps.is_available() else 'cpu'),
38+
'linux': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
39+
'win32': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
40+
}
41+
expected_device = os_to_device.get(sys.platform, torch.device('cpu'))
42+
assert model._device == expected_device
43+
assert model._enable_gpu is True
44+
1145
def test_basic(self):
1246
"""Basic test for the ``BasicGANModel``."""
1347
sequences = [

0 commit comments

Comments
 (0)