Skip to content

Commit 561b1cb

Browse files
committed
Address comments and add more unit tests
1 parent 7e2a5d3 commit 561b1cb

File tree

5 files changed

+270
-97
lines changed

5 files changed

+270
-97
lines changed

sdgym/benchmark.py

Lines changed: 96 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@
111111
SDV_SYNTHESIZERS = SDV_SINGLE_TABLE_SYNTHESIZERS + SDV_MULTI_TABLE_SYNTHESIZERS
112112

113113

114-
def _validate_inputs(output_filepath, detailed_results_folder, synthesizers, custom_synthesizers):
114+
def _validate_output_filepath_and_detailed_results_folder(output_filepath, detailed_results_folder):
115115
if output_filepath and os.path.exists(output_filepath):
116116
raise ValueError(
117117
f'{output_filepath} already exists. Please provide a file that does not already exist.'
@@ -123,15 +123,65 @@ def _validate_inputs(output_filepath, detailed_results_folder, synthesizers, cus
123123
'Please provide a folder that does not already exist.'
124124
)
125125

126-
duplicates = get_duplicates(synthesizers) if synthesizers else set()
127-
if custom_synthesizers:
128-
duplicates.update(get_duplicates(custom_synthesizers))
129-
if len(duplicates) > 0:
126+
127+
def _import_and_validate_synthesizers(synthesizers, custom_synthesizers, modality):
128+
"""Import user-provided synthesizer and validate modality and uniqueness.
129+
130+
This function takes lists of synthesizer, imports them as synthesizer classes,
131+
and validates two conditions:
132+
- Modality match – all synthesizers must match the expected `modality`.
133+
A `ValueError` is raised if any synthesizer has a different modality
134+
flag.
135+
136+
- Uniqueness – duplicate synthesizer across the two input lists
137+
(`synthesizers` and `custom_synthesizers`) are not allowed. A
138+
`ValueError` is raised if duplicates are found.
139+
140+
Args:
141+
synthesizers (list | None):
142+
A list of synthesizer strings or classes. May be ``None``, in which case it
143+
is treated as an empty list.
144+
custom_synthesizers (list | None):
145+
A list of custom synthesizer.
146+
modality (str):
147+
The required modality that all synthesizers must match.
148+
149+
Returns:
150+
list:
151+
A list of synthesizer classes.
152+
153+
Raises:
154+
ValueError:
155+
If any synthesizer does not match the expected modality.
156+
ValueError:
157+
If duplicate synthesizer are found across the provided lists.
158+
"""
159+
# Get list of synthesizer objects
160+
synthesizers = synthesizers or []
161+
custom_synthesizers = custom_synthesizers or []
162+
resolved_synthesizers = get_synthesizers(synthesizers + custom_synthesizers)
163+
mismatched = [
164+
synth['synthesizer']
165+
for synth in resolved_synthesizers
166+
if synth['synthesizer']._MODALITY_FLAG != modality
167+
]
168+
if mismatched:
169+
raise ValueError(
170+
f"Synthesizers must be of modality '{modality}'. "
171+
"Found this synthesizers that don't match: "
172+
f'{", ".join([type(synth).__name__ for synth in mismatched])}'
173+
)
174+
175+
# Check duplicate input values
176+
duplicates = get_duplicates(synthesizers + custom_synthesizers)
177+
if duplicates:
130178
raise ValueError(
131179
'Synthesizers must be unique. Please remove repeated values in the `synthesizers` '
132180
'and `custom_synthesizers` parameters.'
133181
)
134182

183+
return resolved_synthesizers
184+
135185

136186
def _create_detailed_results_directory(detailed_results_folder):
137187
if detailed_results_folder and not is_s3_path(detailed_results_folder):
@@ -276,15 +326,9 @@ def _generate_job_args_list(
276326
compute_diagnostic_score,
277327
compute_privacy_score,
278328
synthesizers,
279-
custom_synthesizers,
280329
s3_client,
281330
modality,
282331
):
283-
# Get list of synthesizer objects
284-
synthesizers = [] if synthesizers is None else synthesizers
285-
custom_synthesizers = [] if custom_synthesizers is None else custom_synthesizers
286-
synthesizers = get_synthesizers(synthesizers + custom_synthesizers)
287-
288332
# Get list of dataset paths
289333
aws_access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
290334
aws_secret_access_key_key = os.getenv('AWS_SECRET_ACCESS_KEY')
@@ -427,7 +471,7 @@ def _compute_scores(
427471
sdmetrics_metadata = metadata
428472

429473
if len(metrics) > 0:
430-
metrics, metric_kwargs = get_metrics(metrics, modality=modality.replace('_', '-'))
474+
metrics, metric_kwargs = get_metrics(metrics, modality=modality)
431475
scores = []
432476
output['scores'] = scores
433477
for metric_name, metric in metrics.items():
@@ -1130,8 +1174,8 @@ def _write_metainfo_file(synthesizers, job_args_list, modality, result_writer=No
11301174
}
11311175

11321176
for synthesizer in synthesizers:
1133-
if synthesizer not in SDV_SYNTHESIZERS:
1134-
ext_lib = EXTERNAL_SYNTHESIZER_TO_LIBRARY.get(synthesizer)
1177+
if synthesizer['name'] not in SDV_SYNTHESIZERS:
1178+
ext_lib = EXTERNAL_SYNTHESIZER_TO_LIBRARY.get(synthesizer['name'])
11351179
if ext_lib:
11361180
library_version = version(ext_lib)
11371181
metadata[f'{ext_lib}_version'] = library_version
@@ -1150,20 +1194,17 @@ def _update_metainfo_file(run_file, result_writer=None):
11501194
result_writer.write_yaml(update, run_file, append=True)
11511195

11521196

1153-
def _ensure_uniform_included(synthesizers):
1154-
if UniformSynthesizer not in synthesizers and UniformSynthesizer.__name__ not in synthesizers:
1155-
LOGGER.info('Adding UniformSynthesizer to list of synthesizers.')
1156-
synthesizers.append('UniformSynthesizer')
1157-
1197+
def _ensure_uniform_included(synthesizers, modality):
1198+
uniform_class = UniformSynthesizer
1199+
if modality == 'multi_table':
1200+
uniform_class = MultiTableUniformSynthesizer
11581201

1159-
def _ensure_multi_table_uniform_is_included(synthesizers):
11601202
uniform_not_included = bool(
1161-
MultiTableUniformSynthesizer not in synthesizers
1162-
and MultiTableUniformSynthesizer.__name__ not in synthesizers
1203+
uniform_class not in synthesizers and uniform_class.__name__ not in synthesizers
11631204
)
11641205
if uniform_not_included:
1165-
LOGGER.info('Adding MultiTableUniformSynthesizer to the list of synthesizers.')
1166-
synthesizers.append('MultiTableUniformSynthesizer')
1206+
LOGGER.info(f'Adding {uniform_class.__name__} to the list of synthesizers.')
1207+
synthesizers.append(uniform_class.__name__)
11671208

11681209

11691210
def _fill_adjusted_scores_with_none(scores):
@@ -1331,7 +1372,7 @@ def benchmark_single_table(
13311372
if not synthesizers:
13321373
synthesizers = []
13331374

1334-
_ensure_uniform_included(synthesizers)
1375+
_ensure_uniform_included(synthesizers, 'single_table')
13351376
result_writer = LocalResultsWriter()
13361377
if run_on_ec2:
13371378
print("This will create an instance for the current AWS user's account.") # noqa
@@ -1343,21 +1384,25 @@ def benchmark_single_table(
13431384

13441385
return None
13451386

1346-
_validate_inputs(output_filepath, detailed_results_folder, synthesizers, custom_synthesizers)
1347-
_create_detailed_results_directory(detailed_results_folder)
1348-
job_args_list = _generate_job_args_list(
1349-
limit_dataset_size,
1350-
sdv_datasets,
1351-
additional_datasets_folder,
1352-
sdmetrics,
1353-
detailed_results_folder,
1354-
timeout,
1355-
output_destination,
1356-
compute_quality_score,
1357-
compute_diagnostic_score,
1358-
compute_privacy_score,
1387+
_validate_output_filepath_and_detailed_results_folder(output_filepath, detailed_results_folder)
1388+
synthesizers = _import_and_validate_synthesizers(
13591389
synthesizers,
13601390
custom_synthesizers,
1391+
'single_table',
1392+
)
1393+
_create_detailed_results_directory(detailed_results_folder)
1394+
job_args_list = _generate_job_args_list(
1395+
limit_dataset_size=limit_dataset_size,
1396+
sdv_datasets=sdv_datasets,
1397+
additional_datasets_folder=additional_datasets_folder,
1398+
sdmetrics=sdmetrics,
1399+
detailed_results_folder=detailed_results_folder,
1400+
timeout=timeout,
1401+
output_destination=output_destination,
1402+
compute_quality_score=compute_quality_score,
1403+
compute_diagnostic_score=compute_diagnostic_score,
1404+
compute_privacy_score=compute_privacy_score,
1405+
synthesizers=synthesizers,
13611406
s3_client=None,
13621407
modality='single_table',
13631408
)
@@ -1650,7 +1695,13 @@ def benchmark_single_table_aws(
16501695
if not synthesizers:
16511696
synthesizers = []
16521697

1653-
_ensure_uniform_included(synthesizers)
1698+
_ensure_uniform_included(synthesizers, 'single_table')
1699+
synthesizers = _import_and_validate_synthesizers(
1700+
synthesizers=synthesizers,
1701+
custom_synthesizers=None,
1702+
modality='single_table',
1703+
)
1704+
16541705
job_args_list = _generate_job_args_list(
16551706
limit_dataset_size=limit_dataset_size,
16561707
sdv_datasets=sdv_datasets,
@@ -1663,7 +1714,6 @@ def benchmark_single_table_aws(
16631714
compute_privacy_score=compute_privacy_score,
16641715
synthesizers=synthesizers,
16651716
detailed_results_folder=None,
1666-
custom_synthesizers=None,
16671717
s3_client=s3_client,
16681718
modality='single_table',
16691719
)
@@ -1745,20 +1795,19 @@ def benchmark_multi_table(
17451795
17461796
Returns:
17471797
pandas.DataFrame:
1748-
A table containing one row per synthesizer + dataset + metric.
1798+
A table containing one row per synthesizer + dataset.
17491799
"""
17501800
_validate_output_destination(output_destination)
17511801
if not synthesizers:
17521802
synthesizers = []
17531803

1754-
_ensure_multi_table_uniform_is_included(synthesizers)
1804+
_ensure_uniform_included(synthesizers, 'multi_table')
17551805
result_writer = LocalResultsWriter()
17561806

1757-
_validate_inputs(
1758-
output_filepath=None,
1759-
detailed_results_folder=None,
1760-
synthesizers=synthesizers,
1761-
custom_synthesizers=custom_synthesizers,
1807+
synthesizers = _import_and_validate_synthesizers(
1808+
synthesizers,
1809+
custom_synthesizers,
1810+
'multi_table',
17621811
)
17631812
job_args_list = _generate_job_args_list(
17641813
limit_dataset_size=limit_dataset_size,
@@ -1772,7 +1821,6 @@ def benchmark_multi_table(
17721821
compute_diagnostic_score=compute_diagnostic_score,
17731822
compute_privacy_score=None,
17741823
synthesizers=synthesizers,
1775-
custom_synthesizers=custom_synthesizers,
17761824
s3_client=None,
17771825
modality='multi_table',
17781826
)

sdgym/metrics.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,15 +80,15 @@ def normalize(self, raw_score):
8080
],
8181
}
8282
DATA_MODALITY_METRICS = {
83-
'single-table': [
83+
'single_table': [
8484
'CSTest',
8585
'KSComplement',
8686
],
87-
'multi-table': [
87+
'multi_table': [
8888
'CSTest',
8989
'KSComplement',
9090
],
91-
'timeseries': [
91+
'sequential': [
9292
'TSFClassifierEfficacy',
9393
'LSTMClassifierEfficacy',
9494
'TSFCDetection',
@@ -104,17 +104,17 @@ def get_metrics(metrics, modality):
104104
metrics (list):
105105
List of strings or tuples ``(metric, metric_args)`` describing the metrics.
106106
modality (str):
107-
It must be ``'single-table'``, ``'multi-table'`` or ``'timeseries'``.
107+
It must be ``'single_table'``, ``'multi_table'`` or ``'sequential'``.
108108
109109
Returns:
110110
list, kwargs:
111111
A list of metrics for the given modality, and their corresponding kwargs.
112112
"""
113-
if modality == 'multi-table':
113+
if modality == 'multi_table':
114114
metric_classes = sdmetrics.multi_table.MultiTableMetric.get_subclasses()
115-
elif modality == 'single-table':
115+
elif modality == 'single_table':
116116
metric_classes = sdmetrics.single_table.SingleTableMetric.get_subclasses()
117-
elif modality == 'timeseries':
117+
elif modality == 'sequential':
118118
metric_classes = sdmetrics.timeseries.TimeSeriesMetric.get_subclasses()
119119

120120
if not metrics:

tests/integration/test_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -510,7 +510,7 @@ def test_benchmark_single_table_no_synthesizers_with_parameters():
510510
.all()
511511
)
512512
assert result['Evaluate_Time'] is None
513-
assert result['error'] == 'ValueError: Unknown single-table metric: a'
513+
assert result['error'] == 'ValueError: Unknown single_table metric: a'
514514

515515

516516
def test_benchmark_single_table_custom_synthesizer():

tests/unit/test__dataset_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def test__get_dataset_subset_single_table():
7878
metadata = {'tables': {'table': {'columns': {f'c{i}': {} for i in range(15)}}}}
7979

8080
# Run
81-
result_df, result_meta = _get_dataset_subset(df, metadata, modality='regular')
81+
result_df, result_meta = _get_dataset_subset(df, metadata, modality='single_table')
8282

8383
# Assert
8484
assert len(result_df) <= 1000
@@ -162,7 +162,7 @@ def test__read_zipped_data_single(mock_read):
162162

163163
# Run
164164
with patch('sdgym._dataset_utils.ZipFile', return_value=mock_zip):
165-
data_single = _read_zipped_data('fake.zip', modality='single')
165+
data_single = _read_zipped_data('fake.zip', modality='single_table')
166166

167167
# Assert
168168
assert isinstance(data_single, pd.DataFrame)

0 commit comments

Comments
 (0)