Skip to content

Commit 0875fbc

Browse files
authored
Always include UniformSynthesizer (#440)
1 parent bfa27e3 commit 0875fbc

File tree

3 files changed

+219
-74
lines changed

3 files changed

+219
-74
lines changed

sdgym/benchmark.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
write_csv,
5353
write_file,
5454
)
55-
from sdgym.synthesizers import CTGANSynthesizer, GaussianCopulaSynthesizer
55+
from sdgym.synthesizers import CTGANSynthesizer, GaussianCopulaSynthesizer, UniformSynthesizer
5656
from sdgym.synthesizers.base import BaselineSynthesizer
5757
from sdgym.utils import (
5858
calculate_score_time,
@@ -66,7 +66,7 @@
6666
)
6767

6868
LOGGER = logging.getLogger(__name__)
69-
DEFAULT_SYNTHESIZERS = [GaussianCopulaSynthesizer, CTGANSynthesizer]
69+
DEFAULT_SYNTHESIZERS = [GaussianCopulaSynthesizer, CTGANSynthesizer, UniformSynthesizer]
7070
DEFAULT_DATASETS = [
7171
'adult',
7272
'alarm',
@@ -1045,6 +1045,12 @@ def _update_run_id_file(run_file, result_writer=None):
10451045
result_writer.write_yaml(update, run_file, append=True)
10461046

10471047

1048+
def _ensure_uniform_included(synthesizers):
1049+
if UniformSynthesizer not in synthesizers and UniformSynthesizer.__name__ not in synthesizers:
1050+
LOGGER.info('Adding UniformSynthesizer to list of synthesizers.')
1051+
synthesizers.append(UniformSynthesizer)
1052+
1053+
10481054
def benchmark_single_table(
10491055
synthesizers=DEFAULT_SYNTHESIZERS,
10501056
custom_synthesizers=None,
@@ -1146,6 +1152,9 @@ def benchmark_single_table(
11461152
output_filepath, detailed_results_folder, multi_processing_config, run_on_ec2
11471153
)
11481154
_validate_output_destination(output_destination)
1155+
if not synthesizers:
1156+
synthesizers = []
1157+
_ensure_uniform_included(synthesizers)
11491158
result_writer = LocalResultsWriter()
11501159
if run_on_ec2:
11511160
print("This will create an instance for the current AWS user's account.") # noqa
@@ -1448,6 +1457,9 @@ def benchmark_single_table_aws(
14481457
'aws_secret_access_key': aws_secret_access_key,
14491458
},
14501459
)
1460+
if not synthesizers:
1461+
synthesizers = []
1462+
_ensure_uniform_included(synthesizers)
14511463
job_args_list = _generate_job_args_list(
14521464
limit_dataset_size=limit_dataset_size,
14531465
sdv_datasets=sdv_datasets,

tests/integration/test_benchmark.py

Lines changed: 95 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -301,21 +301,24 @@ def test_benchmark_single_table_timeout():
301301

302302
# Assert
303303
assert total_time < 50.0 # Buffer time for code not in timeout
304-
expected_scores = pd.DataFrame({
305-
'Synthesizer': {0: 'GaussianCopulaSynthesizer'},
306-
'Dataset': {0: 'insurance'},
307-
'Dataset_Size_MB': {0: 3.340128},
308-
'Train_Time': {0: None},
309-
'Peak_Memory_MB': {0: None},
310-
'Synthesizer_Size_MB': {0: None},
311-
'Sample_Time': {0: None},
312-
'Evaluate_Time': {0: None},
313-
'Diagnostic_Score': {0: None},
314-
'Quality_Score': {0: None},
315-
'Privacy_Score': {0: None},
316-
'error': {0: 'Synthesizer Timeout'},
317-
})
318-
pd.testing.assert_frame_equal(scores, expected_scores)
304+
timeout_scores = pd.Series(
305+
{
306+
'Synthesizer': 'GaussianCopulaSynthesizer',
307+
'Dataset': 'insurance',
308+
'Dataset_Size_MB': 3.340128,
309+
'Train_Time': None,
310+
'Peak_Memory_MB': None,
311+
'Synthesizer_Size_MB': None,
312+
'Sample_Time': None,
313+
'Evaluate_Time': None,
314+
'Diagnostic_Score': None,
315+
'Quality_Score': None,
316+
'Privacy_Score': None,
317+
'error': 'Synthesizer Timeout',
318+
},
319+
name=0,
320+
)
321+
pd.testing.assert_series_equal(scores.T[0], timeout_scores)
319322

320323

321324
def test_benchmark_single_table_only_datasets():
@@ -332,18 +335,23 @@ def test_benchmark_single_table_only_datasets():
332335

333336
# Assert
334337
assert len(scores.columns) == 12
335-
assert list(scores['Synthesizer']) == ['GaussianCopulaSynthesizer', 'CTGANSynthesizer']
336-
assert list(scores['Dataset']) == ['fake_companies'] * 2
337-
assert [round(score, 5) for score in scores['Dataset_Size_MB']] == [0.00128] * 2
338+
assert list(scores['Synthesizer']) == [
339+
'GaussianCopulaSynthesizer',
340+
'CTGANSynthesizer',
341+
'UniformSynthesizer',
342+
]
343+
assert list(scores['Dataset']) == ['fake_companies'] * 3
344+
assert [round(score, 5) for score in scores['Dataset_Size_MB']] == [0.00128] * 3
338345
assert scores['Train_Time'].between(0, 1000).all()
339346
assert scores['Peak_Memory_MB'].between(0, 1000).all()
340347
assert scores['Synthesizer_Size_MB'].between(0, 1000).all()
341348
assert scores['Sample_Time'].between(0, 1000).all()
342349
assert scores['Evaluate_Time'].between(0, 1000).all()
343350
assert scores['Quality_Score'].between(0.5, 1).all()
344351
assert scores['Privacy_Score'].between(0.5, 1).all()
345-
assert (scores['Diagnostic_Score'] == 1.0).all()
346-
assert list(scores['NewRowSynthesis']) == [1.0] * 2
352+
assert (scores['Diagnostic_Score'][0:2] == 1.0).all()
353+
assert scores['Diagnostic_Score'][2:].between(0.5, 1.0).all()
354+
assert list(scores['NewRowSynthesis']) == [1.0] * 3
347355

348356

349357
def test_benchmark_single_table_synthesizers_none():
@@ -361,50 +369,62 @@ def test_benchmark_single_table_synthesizers_none():
361369
)
362370

363371
# Assert
364-
assert scores.shape == (1, 11)
365-
scores = scores.iloc[0]
366-
assert scores['Synthesizer'] == 'Variant:test_synth'
367-
assert scores['Dataset'] == 'fake_companies'
368-
assert round(scores['Dataset_Size_MB'], 5) == 0.00128
369-
assert 0.5 < scores['Quality_Score'] < 1
370-
assert 0.5 < scores['Privacy_Score'] <= 1.0
371-
assert scores['Diagnostic_Score'] == 1.0
372-
assert (
373-
scores[
374-
['Train_Time', 'Peak_Memory_MB', 'Synthesizer_Size_MB', 'Sample_Time', 'Evaluate_Time']
375-
]
376-
.between(0, 1000)
377-
.all()
378-
)
372+
assert scores.shape == (2, 11)
373+
for name, iloc in (('UniformSynthesizer', 0), ('Variant:test_synth', 1)):
374+
_scores = scores.iloc[iloc]
375+
assert _scores['Synthesizer'] == name
376+
assert _scores['Dataset'] == 'fake_companies'
377+
assert round(_scores['Dataset_Size_MB'], 5) == 0.00128
378+
assert 0.5 < _scores['Quality_Score'] < 1
379+
assert 0.5 < _scores['Privacy_Score'] <= 1.0
380+
if name == 'Variant:test_synth':
381+
assert _scores['Diagnostic_Score'] == 1.0
382+
else:
383+
assert 0.5 < _scores['Diagnostic_Score'] <= 1.0
384+
assert (
385+
_scores[
386+
[
387+
'Train_Time',
388+
'Peak_Memory_MB',
389+
'Synthesizer_Size_MB',
390+
'Sample_Time',
391+
'Evaluate_Time',
392+
]
393+
]
394+
.between(0, 1000)
395+
.all()
396+
)
379397

380398

381399
def test_benchmark_single_table_no_synthesizers():
382400
"""Test it works when no synthesizers are passed.
383401
384-
It should return an empty dataframe.
402+
It should still run UniformSynthesizer.
385403
"""
386404
# Run
387405
result = benchmark_single_table(
388406
synthesizers=None,
407+
sdv_datasets=['fake_companies'],
389408
sdmetrics=[('NewRowSynthesis', {'synthetic_sample_size': 1000})],
390409
)
391410

392411
# Assert
393-
expected = pd.DataFrame({
394-
'Synthesizer': [],
395-
'Dataset': [],
396-
'Dataset_Size_MB': [],
397-
'Train_Time': [],
398-
'Peak_Memory_MB': [],
399-
'Synthesizer_Size_MB': [],
400-
'Sample_Time': [],
401-
'Evaluate_Time': [],
402-
'Diagnostic_Score': [],
403-
'Quality_Score': [],
404-
'Privacy_Score': [],
405-
'NewRowSynthesis': [],
406-
})
407-
pd.testing.assert_frame_equal(result, expected)
412+
assert result.shape == (1, 12)
413+
result = result.iloc[0]
414+
assert result['Synthesizer'] == 'UniformSynthesizer'
415+
assert result['Dataset'] == 'fake_companies'
416+
assert round(result['Dataset_Size_MB'], 5) == 0.00128
417+
assert 0.5 < result['Quality_Score'] < 1
418+
assert 0.5 < result['Privacy_Score'] <= 1.0
419+
assert 0.5 < result['Diagnostic_Score'] <= 1.0
420+
assert 0 < result['NewRowSynthesis'] <= 1.0
421+
assert (
422+
result[
423+
['Train_Time', 'Peak_Memory_MB', 'Synthesizer_Size_MB', 'Sample_Time', 'Evaluate_Time']
424+
]
425+
.between(0, 1000)
426+
.all()
427+
)
408428

409429

410430
def test_benchmark_single_table_no_datasets():
@@ -449,19 +469,18 @@ def test_benchmark_single_table_no_synthesizers_with_parameters():
449469
)
450470

451471
# Assert
452-
expected = pd.DataFrame({
453-
'Synthesizer': [],
454-
'Dataset': [],
455-
'Dataset_Size_MB': [],
456-
'Train_Time': [],
457-
'Peak_Memory_MB': [],
458-
'Synthesizer_Size_MB': [],
459-
'Sample_Time': [],
460-
'Evaluate_Time': [],
461-
'a': [],
462-
'b': [],
463-
})
464-
pd.testing.assert_frame_equal(result, expected)
472+
assert result.shape == (1, 9)
473+
result = result.iloc[0]
474+
assert result['Synthesizer'] == 'UniformSynthesizer'
475+
assert result['Dataset'] == 'fake_companies'
476+
assert round(result['Dataset_Size_MB'], 5) == 0.00128
477+
assert (
478+
result[['Train_Time', 'Peak_Memory_MB', 'Synthesizer_Size_MB', 'Sample_Time']]
479+
.between(0, 1000)
480+
.all()
481+
)
482+
assert result['Evaluate_Time'] is None
483+
assert result['error'] == 'ValueError: Unknown single-table metric: a'
465484

466485

467486
def test_benchmark_single_table_custom_synthesizer():
@@ -489,7 +508,7 @@ def sample_from_synthesizer(synthesizer, n_samples):
489508
)
490509

491510
# Assert
492-
results = results.iloc[0]
511+
results = results.iloc[1]
493512
assert results['Synthesizer'] == 'Custom:TestSynthesizer'
494513
assert results['Dataset'] == 'fake_companies'
495514
assert round(results['Dataset_Size_MB'], 5) == 0.00128
@@ -563,7 +582,7 @@ def sample_from_synthesizer(synthesizer, n_samples):
563582
)
564583

565584
# Assert
566-
results = results.iloc[0]
585+
results = results.iloc[1]
567586
assert results['Synthesizer'] == 'Custom:TestSynthesizer'
568587
assert results['Dataset'] == 'fake_companies'
569588
assert round(results['Dataset_Size_MB'], 5) == 0.00128
@@ -622,7 +641,11 @@ def test_benchmark_single_table_with_output_destination(tmp_path):
622641
synthesizer_directions = os.listdir(
623642
os.path.join(output_destination, directions[0], f'fake_companies_{today_date}')
624643
)
625-
assert set(synthesizer_directions) == {'TVAESynthesizer', 'GaussianCopulaSynthesizer'}
644+
assert set(synthesizer_directions) == {
645+
'TVAESynthesizer',
646+
'GaussianCopulaSynthesizer',
647+
'UniformSynthesizer',
648+
}
626649
for synthesizer in sorted(synthesizer_directions):
627650
synthesizer_files = os.listdir(
628651
os.path.join(
@@ -698,7 +721,11 @@ def test_benchmark_single_table_with_output_destination_multiple_runs(tmp_path):
698721
synthesizer_directions = os.listdir(
699722
os.path.join(output_destination, directions[0], f'fake_companies_{today_date}')
700723
)
701-
assert set(synthesizer_directions) == {'TVAESynthesizer', 'GaussianCopulaSynthesizer'}
724+
assert set(synthesizer_directions) == {
725+
'TVAESynthesizer',
726+
'GaussianCopulaSynthesizer',
727+
'UniformSynthesizer',
728+
}
702729
for synthesizer in sorted(synthesizer_directions):
703730
synthesizer_files = os.listdir(
704731
os.path.join(

0 commit comments

Comments (0)