|
1 | 1 | """Integration tests for ``BasicGANModel``.""" |
2 | 2 |
|
| 3 | +import re |
| 4 | +import sys |
3 | 5 | import unittest |
4 | 6 |
|
| 7 | +import pytest |
| 8 | +import torch |
| 9 | + |
5 | 10 | from deepecho.models.basic_gan import BasicGANModel |
6 | 11 |
|
7 | 12 |
|
8 | 13 | class TestBasicGANModel(unittest.TestCase): |
9 | 14 | """Test class for the ``BasicGANModel``.""" |
10 | 15 |
|
| 16 | + def test_deprecation_warning(self): |
| 17 | + """Test that using the deprecated `cuda` parameter raises a warning.""" |
| 18 | + # Setup |
| 19 | + expected_message = re.escape( |
| 20 | + '`cuda` parameter is deprecated and will be removed in a future release. ' |
| 21 | + 'Please use `enable_gpu` instead.' |
| 22 | + ) |
| 23 | + |
| 24 | + # Run and Assert |
| 25 | + with pytest.warns(FutureWarning, match=expected_message): |
| 26 | + model = BasicGANModel(epochs=10, cuda=False) |
| 27 | + |
| 28 | + assert model._enable_gpu is False |
| 29 | + |
| 30 | + def test__init___enable_gpu(self): |
| 31 | + """Test when `enable_gpu` parameter in the constructor.""" |
| 32 | + # Setup and Run |
| 33 | + model = BasicGANModel(epochs=10, enable_gpu=True) |
| 34 | + |
| 35 | + # Assert |
| 36 | + os_to_device = { |
| 37 | + 'darwin': torch.device('mps' if torch.backends.mps.is_available() else 'cpu'), |
| 38 | + 'linux': torch.device('cuda' if torch.cuda.is_available() else 'cpu'), |
| 39 | + 'win32': torch.device('cuda' if torch.cuda.is_available() else 'cpu'), |
| 40 | + } |
| 41 | + expected_device = os_to_device.get(sys.platform, torch.device('cpu')) |
| 42 | + assert model._device == expected_device |
| 43 | + assert model._enable_gpu is True |
| 44 | + |
11 | 45 | def test_basic(self): |
12 | 46 | """Basic test for the ``BasicGANModel``.""" |
13 | 47 | sequences = [ |
|
0 commit comments