|
| 1 | +import platform |
| 2 | +import re |
| 3 | +from unittest.mock import patch |
| 4 | + |
| 5 | +import pytest |
| 6 | +import torch |
| 7 | + |
| 8 | +from ctgan.synthesizers._utils import _get_enable_gpu_value, _set_device, validate_and_set_device |
| 9 | + |
| 10 | + |
| 11 | +def test__validate_gpu_parameter(): |
| 12 | + """Test the ``_get_enable_gpu_value`` method.""" |
| 13 | + # Setup |
| 14 | + expected_error = re.escape( |
| 15 | + 'Cannot resolve the provided values of `cuda` and `enable_gpu` parameters. ' |
| 16 | + 'Please use only `enable_gpu`.' |
| 17 | + ) |
| 18 | + expected_warning = re.escape( |
| 19 | + '`cuda` parameter is deprecated and will be removed in a future release. ' |
| 20 | + 'Please use `enable_gpu` instead.' |
| 21 | + ) |
| 22 | + |
| 23 | + # Run |
| 24 | + enable_gpu_1 = _get_enable_gpu_value(False, None) |
| 25 | + enable_gpu_2 = _get_enable_gpu_value(True, None) |
| 26 | + with pytest.warns(FutureWarning, match=expected_warning): |
| 27 | + enable_gpu_3 = _get_enable_gpu_value(True, False) |
| 28 | + |
| 29 | + with pytest.raises(ValueError, match=expected_error): |
| 30 | + _get_enable_gpu_value(False, True) |
| 31 | + |
| 32 | + # Assert |
| 33 | + assert enable_gpu_1 is False |
| 34 | + assert enable_gpu_2 is True |
| 35 | + assert enable_gpu_3 is False |
| 36 | + |
| 37 | + |
| 38 | +def test__set_device(): |
| 39 | + """Test the ``_set_device`` method.""" |
| 40 | + # Run |
| 41 | + device_1 = _set_device(False) |
| 42 | + device_2 = _set_device(True) |
| 43 | + device_3 = _set_device(True, 'cpu') |
| 44 | + device_4 = _set_device(enable_gpu=False, device='cpu') |
| 45 | + |
| 46 | + # Assert |
| 47 | + if ( |
| 48 | + platform.machine() == 'arm64' |
| 49 | + and getattr(torch.backends, 'mps', None) |
| 50 | + and torch.backends.mps.is_available() |
| 51 | + ): |
| 52 | + expected_device_2 = torch.device('mps') |
| 53 | + elif torch.cuda.is_available(): |
| 54 | + expected_device_2 = torch.device('cuda') |
| 55 | + else: |
| 56 | + expected_device_2 = torch.device('cpu') |
| 57 | + |
| 58 | + assert device_1 == torch.device('cpu') |
| 59 | + assert device_2 == expected_device_2 |
| 60 | + assert device_3 == torch.device('cpu') |
| 61 | + assert device_4 == torch.device('cpu') |
| 62 | + |
| 63 | + |
| 64 | +@patch('ctgan.synthesizers._utils._set_device') |
| 65 | +@patch('ctgan.synthesizers._utils._get_enable_gpu_value') |
| 66 | +def test_validate_and_set_device(mock_validate, mock_set_device): |
| 67 | + """Test the ``validate_and_set_device`` method.""" |
| 68 | + # Setup |
| 69 | + mock_validate.return_value = True |
| 70 | + mock_set_device.return_value = torch.device('cuda') |
| 71 | + |
| 72 | + # Run |
| 73 | + device = validate_and_set_device(enable_gpu=True, cuda=None) |
| 74 | + |
| 75 | + # Assert |
| 76 | + mock_validate.assert_called_once_with(True, None) |
| 77 | + mock_set_device.assert_called_once_with(True) |
| 78 | + assert device == torch.device('cuda') |
0 commit comments