Skip to content

Commit 6d659cf

Browse files
committed
tests
1 parent 2144211 commit 6d659cf

File tree

9 files changed: 257 additions and 71 deletions.

sdgym/synthesizers/uniform.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def _sample_from_synthesizer(self, synthesizer, n_samples):
6060
for name, column in transformed.items():
6161
kind = column.dtype.kind
6262
if kind == 'i':
63-
values = np.random.randint(column.min(), column.max() + 1, size=n_samples)
63+
values = np.random.randint(int(column.min()), int(column.max()) + 1, size=n_samples)
6464
elif kind in ['O', 'b']:
6565
values = np.random.choice(column.unique(), size=n_samples)
6666
else:
@@ -71,7 +71,11 @@ def _sample_from_synthesizer(self, synthesizer, n_samples):
7171

7272

7373
class MultiTableUniformSynthesizer(BaselineSynthesizer):
74-
"""Synthesizer that uses UniformSynthesizer for multi-table data."""
74+
"""Multi-table Uniform Synthesizer.
75+
76+
This synthesizer trains a UniformSynthesizer on each table in the multi-table dataset.
77+
It samples data from each table independently using the corresponding trained synthesizer.
78+
"""
7579

7680
_MODALITY_FLAG = 'multi_table'
7781

@@ -80,7 +84,7 @@ def __init__(self):
8084
self.num_rows_per_table = {}
8185

8286
def _get_trained_synthesizer(self, data, metadata):
83-
"""This function should train single table UniformSynthesizers on each table in the data.
87+
"""Train a UniformSynthesizer for each table in the multi-table dataset.
8488
8589
Args:
8690
data (dict):

sdgym/synthesizers/utils.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,6 @@
11
"""Utility functions for synthesizers in SDGym."""
22

33
from sdgym.synthesizers.base import BaselineSynthesizer
4-
from sdgym.synthesizers.sdv import (
5-
_get_all_sdv_synthesizers,
6-
_validate_modality,
7-
)
8-
9-
10-
def _get_sdgym_synthesizers(modality):
11-
"""Get SDGym synthesizers.
12-
13-
Returns:
14-
list:
15-
A list of available SDGym synthesizer names.
16-
"""
17-
_validate_modality(modality)
18-
synthesizers = BaselineSynthesizer._get_supported_synthesizers(modality)
19-
sdv_synthesizer = _get_all_sdv_synthesizers()
20-
sdgym_synthesizer = [
21-
synthesizer for synthesizer in synthesizers if synthesizer not in sdv_synthesizer
22-
]
23-
return sorted(sdgym_synthesizer)
244

255

266
def get_available_single_table_synthesizers():

tests/integration/synthesizers/test_uniform.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22

33
import numpy as np
44
import pandas as pd
5+
from pandas.api.types import is_numeric_dtype
6+
from sdv.datasets.demo import download_demo
57

6-
from sdgym.synthesizers.uniform import UniformSynthesizer
8+
from sdgym.synthesizers import MultiTableUniformSynthesizer, UniformSynthesizer
79

810

911
def test_uniform_synthesizer():
@@ -69,3 +71,25 @@ def test_uniform_synthesizer():
6971

7072
assert n_values_interval2 * 0.9 < n_values_interval1 < n_values_interval2 * 1.1
7173
assert n_values_interval3 * 0.9 < n_values_interval1 < n_values_interval3 * 1.1
74+
75+
76+
def test_multitable_uniform_synthesizer_end_to_end():
77+
"""Test the MultiTableUniformSynthesizer end to end."""
78+
# Setup
79+
data, metadata = download_demo(dataset_name='fake_hotels', modality='multi_table')
80+
synthesizer = MultiTableUniformSynthesizer()
81+
82+
# Run
83+
trained_synthesizer = synthesizer.get_trained_synthesizer(data, metadata.to_dict())
84+
sampled_data = synthesizer.sample_from_synthesizer(trained_synthesizer, scale=2)
85+
86+
# Assert
87+
for table_name, table_data in data.items():
88+
sampled_table = sampled_data[table_name]
89+
assert len(sampled_table) == len(table_data) * 2
90+
for column_name in table_data.columns:
91+
original_column = table_data[column_name]
92+
sampled_column = sampled_table[column_name]
93+
if is_numeric_dtype(original_column):
94+
assert sampled_column.min() >= original_column.min()
95+
assert sampled_column.max() <= original_column.max()

tests/integration/synthesizers/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def test_get_available_single_table_synthesizers():
2828
def test_get_available_multi_table_synthesizers():
2929
"""Test the `get_available_multi_table_synthesizers` method"""
3030
# Setup
31-
expected_synthesizers = ['HMASynthesizer']
31+
expected_synthesizers = ['HMASynthesizer', 'MultiTableUniformSynthesizer']
3232

3333
# Run
3434
synthesizers = get_available_multi_table_synthesizers()

tests/unit/synthesizers/test_base.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,17 @@
99

1010
class TestBaselineSynthesizer:
1111
@patch('sdgym.synthesizers.utils.BaselineSynthesizer.get_subclasses')
12-
def test__get_supported_synthesizers_mock(self, mock_get_subclasses):
12+
@patch('sdgym.synthesizers.base._validate_modality')
13+
def test__get_supported_synthesizers_mock(self, mock_validate_modality, mock_get_subclasses):
1314
"""Test the `_get_supported_synthesizers` method with mocks."""
1415
# Setup
1516
mock_get_subclasses.return_value = {
16-
'Variant:ColumnSynthesizer': Mock(_NATIVELY_SUPPORTED=False),
17-
'Custom:MySynthesizer': Mock(_NATIVELY_SUPPORTED=False),
18-
'ColumnSynthesizer': Mock(_NATIVELY_SUPPORTED=True),
19-
'UniformSynthesizer': Mock(_NATIVELY_SUPPORTED=True),
20-
'DataIdentity': Mock(_NATIVELY_SUPPORTED=True),
17+
'Variant:Synthesizer': Mock(_NATIVELY_SUPPORTED=False, _MODALITY_FLAG='single_table'),
18+
'Custom:MySynthesizer': Mock(_NATIVELY_SUPPORTED=False, _MODALITY_FLAG='single_table'),
19+
'ColumnSynthesizer': Mock(_NATIVELY_SUPPORTED=True, _MODALITY_FLAG='single_table'),
20+
'UniformSynthesizer': Mock(_NATIVELY_SUPPORTED=True, _MODALITY_FLAG='single_table'),
21+
'MultiTableSynthesizer': Mock(_NATIVELY_SUPPORTED=True, _MODALITY_FLAG='multi_table'),
22+
'DataIdentity': Mock(_NATIVELY_SUPPORTED=True, _MODALITY_FLAG='single_table'),
2123
}
2224
expected_synthesizers = [
2325
'ColumnSynthesizer',
@@ -26,9 +28,11 @@ def test__get_supported_synthesizers_mock(self, mock_get_subclasses):
2628
]
2729

2830
# Run
29-
synthesizers = BaselineSynthesizer._get_supported_synthesizers()
31+
synthesizers = BaselineSynthesizer._get_supported_synthesizers('single_table')
3032

3133
# Assert
34+
mock_validate_modality.assert_called_once_with('single_table')
35+
mock_get_subclasses.assert_called_once_with(include_parents=True)
3236
assert synthesizers == expected_synthesizers
3337

3438
def test_get_trained_synthesizer(self):

tests/unit/synthesizers/test_generate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def test_create_sdv_variant_synthesizer():
5555

5656
# Assert
5757
assert out.__name__ == 'Variant:test_synth'
58-
assert out.modality == 'single_table'
58+
assert out._MODALITY_FLAG == 'single_table'
5959
assert out._MODEL_KWARGS == synthesizer_parameters
6060
assert out.SDV_NAME == synthesizer_class
6161
assert out._NATIVELY_SUPPORTED is False
@@ -85,7 +85,7 @@ def test_create_sdv_variant_synthesizer_multi_table():
8585

8686
# Assert
8787
assert out.__name__ == 'Variant:test_synth'
88-
assert out.modality == 'multi_table'
88+
assert out._MODALITY_FLAG == 'multi_table'
8989
assert out._MODEL_KWARGS == synthesizer_parameters
9090
assert out.SDV_NAME == synthesizer_class
9191
assert out._NATIVELY_SUPPORTED is False

tests/unit/synthesizers/test_sdv.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def test__get_trained_synthesizer(mock_logger):
9999
synthesizer = Mock()
100100
synthesizer.__class__.__name__ = 'GaussianCopulaClass'
101101
synthesizer._MODEL_KWARGS = {'enforce_min_max_values': False}
102-
synthesizer.modality = 'single_table'
102+
synthesizer._MODALITY_FLAG = 'single_table'
103103
synthesizer.SDV_NAME = 'GaussianCopulaSynthesizer'
104104

105105
# Run
@@ -121,7 +121,7 @@ def test__sample_from_synthesizer(mock_logger):
121121
})
122122
base_synthesizer = Mock()
123123
base_synthesizer.__class__.__name__ = 'GaussianCopulaSynthesizer'
124-
base_synthesizer.modality = 'single_table'
124+
base_synthesizer._MODALITY_FLAG = 'single_table'
125125
synthesizer = Mock()
126126
synthesizer.sample.return_value = data
127127
n_samples = 3
@@ -187,7 +187,7 @@ def test__create_sdv_class_mock(mock_get_modality, mock_sys_modules):
187187

188188
# Assert
189189
assert synt_class.__name__ == sdv_name
190-
assert synt_class.modality == 'single_table'
190+
assert synt_class._MODALITY_FLAG == 'single_table'
191191
assert synt_class._MODEL_KWARGS == {}
192192
assert synt_class.SDV_NAME == sdv_name
193193
assert issubclass(synt_class, BaselineSynthesizer)
@@ -212,7 +212,7 @@ def test__create_sdv_class():
212212

213213
# Assert
214214
assert synthesizer_class.__name__ == sdv_name
215-
assert synthesizer_class.modality == 'single_table'
215+
assert synthesizer_class._MODALITY_FLAG == 'single_table'
216216
assert synthesizer_class._MODEL_KWARGS == {}
217217
assert issubclass(synthesizer_class, BaselineSynthesizer)
218218

0 commit comments

Comments
 (0)