|
| 1 | +import platform |
1 | 2 | import re |
2 | 3 | from unittest.mock import patch |
3 | 4 |
|
4 | 5 | import numpy as np |
5 | 6 | import pandas as pd |
6 | 7 | import pytest |
| 8 | +import torch |
7 | 9 | from rdt.transformers import FloatFormatter, LabelEncoder |
8 | 10 |
|
9 | 11 | from sdv.cag import FixedCombinations |
10 | 12 | from sdv.datasets.demo import download_demo |
11 | 13 | from sdv.errors import InvalidDataTypeError |
12 | 14 | from sdv.evaluation.single_table import evaluate_quality, get_column_pair_plot, get_column_plot |
13 | 15 | from sdv.metadata.metadata import Metadata |
14 | | -from sdv.single_table import CTGANSynthesizer, TVAESynthesizer |
| 16 | +from sdv.single_table import CopulaGANSynthesizer, CTGANSynthesizer, TVAESynthesizer |
15 | 17 |
|
16 | 18 |
|
17 | 19 | def test__estimate_num_columns(): |
@@ -331,3 +333,61 @@ def test_tvae___init___without_torch(mock_import_error): |
331 | 333 | # Run and Assert |
332 | 334 | with pytest.raises(ModuleNotFoundError, match=msg): |
333 | 335 | TVAESynthesizer(metadata) |
| 336 | + |
| 337 | + |
@pytest.mark.parametrize(
    'synthesizer_class', [CTGANSynthesizer, TVAESynthesizer, CopulaGANSynthesizer]
)
def test_enable_gpu_parameter(synthesizer_class):
    """Test that the `enable_gpu` parameter is correctly passed to the underlying model."""
    # Setup
    data, metadata = download_demo(modality='single_table', dataset_name='fake_hotel_guests')
    deprecation_msg = re.escape(
        '`cuda` parameter is deprecated and will be removed in a future release. '
        'Please use `enable_gpu` instead.'
    )
    conflict_msg = re.escape(
        'Cannot resolve the provided values of `cuda` and `enable_gpu` parameters. '
        'Please use only `enable_gpu`.'
    )

    # Run
    default_synth = synthesizer_class(metadata)
    cpu_synth = synthesizer_class(metadata, enable_gpu=False)
    with pytest.warns(FutureWarning, match=deprecation_msg):
        legacy_synth = synthesizer_class(metadata, cuda=True)

    # Passing both knobs at once must be rejected outright.
    with pytest.raises(ValueError, match=conflict_msg):
        synthesizer_class(metadata, enable_gpu=False, cuda=True)

    synthesizers = [default_synth, cpu_synth, legacy_synth]
    samples = []
    for synth in synthesizers:
        synth.fit(data)
        samples.append(synth.sample(10))

    # Assert
    # Resolve the device a GPU-enabled synthesizer should land on for this host.
    mps_backend = getattr(torch.backends, 'mps', None)
    if platform.machine() == 'arm64' and mps_backend and torch.backends.mps.is_available():
        gpu_device = torch.device('mps')
    elif torch.cuda.is_available():
        gpu_device = torch.device('cuda')
    else:
        gpu_device = torch.device('cpu')

    expected_columns = data.columns.tolist()
    assert default_synth._model._enable_gpu is True
    assert default_synth._model._device == gpu_device
    assert cpu_synth._model._enable_gpu is False
    assert cpu_synth._model._device == torch.device('cpu')
    assert legacy_synth._model._enable_gpu is True
    assert legacy_synth._model._device == gpu_device
    for synthetic_data in samples:
        assert synthetic_data.columns.tolist() == expected_columns
        assert len(synthetic_data) == 10
0 commit comments