Skip to content

Commit d9abea1

Browse files
committed
tests
1 parent 1938413 commit d9abea1

File tree

4 files changed

+125
-2
lines changed

4 files changed

+125
-2
lines changed

tests/integration/synthesizer/test_ctgan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def test_fixed_random_seed():
236236
})
237237
discrete_columns = ['discrete']
238238

239-
ctgan = CTGAN(epochs=1, cuda=False)
239+
ctgan = CTGAN(epochs=1, enable_gpu=False)
240240

241241
# Run
242242
ctgan.fit(data, discrete_columns)
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import re
2+
import sys
3+
from unittest.mock import patch
4+
5+
import pytest
6+
import torch
7+
8+
from ctgan.synthesizers._utils import _set_device, _validate_gpu_parameters, validate_and_set_device
9+
10+
11+
def test__validate_gpu_parameter():
12+
"""Test the ``_validate_gpu_parameters`` method."""
13+
# Setup
14+
expected_error = re.escape(
15+
'Cannot resolve the provided values of `cuda` and `enable_gpu` parameters. '
16+
'Please use only `enable_gpu`.'
17+
)
18+
expected_warning = re.escape(
19+
'`cuda` parameter is deprecated and will be removed in a future release. '
20+
'Please use `enable_gpu` instead.'
21+
)
22+
23+
# Run
24+
enable_gpu_1 = _validate_gpu_parameters(False, None)
25+
enable_gpu_2 = _validate_gpu_parameters(True, None)
26+
with pytest.warns(FutureWarning, match=expected_warning):
27+
enable_gpu_3 = _validate_gpu_parameters(True, False)
28+
29+
with pytest.raises(ValueError, match=expected_error):
30+
_validate_gpu_parameters(False, True)
31+
32+
# Assert
33+
assert enable_gpu_1 is False
34+
assert enable_gpu_2 is True
35+
assert enable_gpu_3 is False
36+
37+
38+
def test__set_device():
39+
"""Test the ``_set_device`` method."""
40+
# Run
41+
device_1 = _set_device(False)
42+
device_2 = _set_device(True)
43+
device_3 = _set_device(True, 'cpu')
44+
device_4 = _set_device(enable_gpu=False, device='cpu')
45+
46+
# Assert
47+
if (
48+
sys.platform == 'darwin'
49+
and getattr(torch.backends, 'mps', None)
50+
and torch.backends.mps.is_available()
51+
):
52+
expected_device_2 = torch.device('mps')
53+
elif torch.cuda.is_available():
54+
expected_device_2 = torch.device('cuda')
55+
else:
56+
expected_device_2 = torch.device('cpu')
57+
58+
assert device_1 == torch.device('cpu')
59+
assert device_2 == expected_device_2
60+
assert device_3 == torch.device('cpu')
61+
assert device_4 == torch.device('cpu')
62+
63+
64+
@patch('ctgan.synthesizers._utils._set_device')
65+
@patch('ctgan.synthesizers._utils._validate_gpu_parameters')
66+
def test_validate_and_set_device(mock_validate, mock_set_device):
67+
"""Test the ``validate_and_set_device`` method."""
68+
# Setup
69+
mock_validate.return_value = True
70+
mock_set_device.return_value = torch.device('cuda')
71+
72+
# Run
73+
device = validate_and_set_device(enable_gpu=True, cuda=None)
74+
75+
# Assert
76+
mock_validate.assert_called_once_with(True, None)
77+
mock_set_device.assert_called_once_with(True)
78+
assert device == torch.device('cuda')

tests/unit/synthesizer/test_ctgan.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""CTGAN unit testing module."""
22

33
from unittest import TestCase
4-
from unittest.mock import Mock
4+
from unittest.mock import Mock, patch
55

66
import numpy as np
77
import pandas as pd
@@ -175,6 +175,30 @@ def _assert_is_between(data, lower, upper):
175175

176176

177177
class TestCTGAN(TestCase):
178+
@patch('ctgan.synthesizers.ctgan.validate_and_set_device')
179+
def test___init__(self, mock_validate_and_set_device):
180+
"""Test the `__init__` method."""
181+
# Setup
182+
mock_validate_and_set_device.return_value = 'cpu'
183+
184+
# Run
185+
synth = CTGAN()
186+
187+
# Assert
188+
assert synth._embedding_dim == 128
189+
assert synth._generator_dim == (256, 256)
190+
assert synth._discriminator_dim == (256, 256)
191+
assert synth._batch_size == 500
192+
assert synth._epochs == 300
193+
assert synth.pac == 10
194+
assert synth.loss_values is None
195+
assert synth._generator is None
196+
assert synth._data_sampler is None
197+
assert synth._verbose is False
198+
assert synth._enable_gpu is True
199+
assert synth._device == 'cpu'
200+
mock_validate_and_set_device.assert_called_once_with(True, None)
201+
178202
def test__apply_activate_(self):
179203
"""Test `_apply_activate` for tables with both continuous and categoricals.
180204

tests/unit/synthesizer/test_tvae.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,27 @@
88

99

1010
class TestTVAE:
11+
@patch('ctgan.synthesizers.tvae.validate_and_set_device')
12+
def test___init__(self, mock_validate_and_set_device):
13+
"""Test the `__init__` method."""
14+
# Setup
15+
mock_validate_and_set_device.return_value = 'cpu'
16+
17+
# Run
18+
synth = TVAE()
19+
20+
# Assert
21+
assert synth.embedding_dim == 128
22+
assert synth.compress_dims == (128, 128)
23+
assert synth.decompress_dims == (128, 128)
24+
assert synth.batch_size == 500
25+
assert synth.epochs == 300
26+
assert synth.loss_values.equals(pd.DataFrame(columns=['Epoch', 'Batch', 'Loss']))
27+
assert synth.verbose is False
28+
assert synth._enable_gpu is True
29+
assert synth._device == 'cpu'
30+
mock_validate_and_set_device.assert_called_once_with(True, None)
31+
1132
@patch('ctgan.synthesizers.tvae._loss_function')
1233
@patch('ctgan.synthesizers.tvae.tqdm')
1334
def test_fit_verbose(self, tqdm_mock, loss_func_mock):

0 commit comments

Comments
 (0)