Skip to content

Commit 6a38fbc

Browse files
authored
Always include UniformSynthesizer doesn't work on AWS (#447)
1 parent 0dde691 commit 6a38fbc

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

sdgym/benchmark.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1051,7 +1051,7 @@ def _update_run_id_file(run_file, result_writer=None):
10511051
def _ensure_uniform_included(synthesizers):
10521052
if UniformSynthesizer not in synthesizers and UniformSynthesizer.__name__ not in synthesizers:
10531053
LOGGER.info('Adding UniformSynthesizer to list of synthesizers.')
1054-
synthesizers.append(UniformSynthesizer)
1054+
synthesizers.append('UniformSynthesizer')
10551055

10561056

10571057
def _add_adjusted_scores(scores, timeout):
@@ -1201,6 +1201,7 @@ def benchmark_single_table(
12011201
_validate_output_destination(output_destination)
12021202
if not synthesizers:
12031203
synthesizers = []
1204+
12041205
_ensure_uniform_included(synthesizers)
12051206
result_writer = LocalResultsWriter()
12061207
if run_on_ec2:
@@ -1506,6 +1507,7 @@ def benchmark_single_table_aws(
15061507
)
15071508
if not synthesizers:
15081509
synthesizers = []
1510+
15091511
_ensure_uniform_included(synthesizers)
15101512
job_args_list = _generate_job_args_list(
15111513
limit_dataset_size=limit_dataset_size,

tests/unit/test_benchmark.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def test__ensure_uniform_included_adds_uniform(caplog):
230230
_ensure_uniform_included(synthesizers)
231231

232232
# Assert
233-
assert synthesizers == [GaussianCopulaSynthesizer, UniformSynthesizer]
233+
assert synthesizers == [GaussianCopulaSynthesizer, 'UniformSynthesizer']
234234
assert any(expected_message in record.message for record in caplog.records)
235235

236236

@@ -701,7 +701,7 @@ def test_benchmark_single_table_aws(
701701
)
702702

703703
# Assert
704-
assert UniformSynthesizer in synthesizers
704+
assert 'UniformSynthesizer' in synthesizers
705705
mock_validate_output_destination.assert_called_once_with(
706706
output_destination,
707707
aws_keys={
@@ -777,14 +777,14 @@ def test_benchmark_single_table_aws_synthesizers_none(
777777
compute_quality_score=True,
778778
compute_diagnostic_score=True,
779779
compute_privacy_score=True,
780-
synthesizers=[UniformSynthesizer],
780+
synthesizers=['UniformSynthesizer'],
781781
detailed_results_folder=None,
782782
custom_synthesizers=None,
783783
s3_client='s3_client_mock',
784784
)
785785
mock_run_on_aws.assert_called_once_with(
786786
output_destination=output_destination,
787-
synthesizers=[UniformSynthesizer],
787+
synthesizers=['UniformSynthesizer'],
788788
s3_client='s3_client_mock',
789789
job_args_list='job_args_list_mock',
790790
aws_access_key_id='12345',

0 commit comments

Comments
 (0)