Skip to content

Commit f4ae85b

Browse files
committed
cleaning
1 parent 5f7103f commit f4ae85b

File tree

3 files changed

+68
-10
lines changed

3 files changed

+68
-10
lines changed

sdgym/benchmark.py

Lines changed: 7 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -107,7 +107,7 @@
107107
'TVAESynthesizer',
108108
]
109109
SDV_MULTI_TABLE_SYNTHESIZERS = ['HMASynthesizer']
110-
110+
MODALITY_IDX = 10
111111
SDV_SYNTHESIZERS = SDV_SINGLE_TABLE_SYNTHESIZERS + SDV_MULTI_TABLE_SYNTHESIZERS
112112

113113

@@ -1486,7 +1486,7 @@ def _get_s3_script_content(
14861486
return f"""
14871487
import boto3
14881488
import cloudpickle
1489-
from sdgym.benchmark import _run_jobs, _write_metainfo_file, _update_metainfo_file
1489+
from sdgym.benchmark import _run_jobs, _write_metainfo_file, _update_metainfo_file, MODALITY_IDX
14901490
from io import StringIO
14911491
from sdgym.result_writer import S3ResultsWriter
14921492
@@ -1498,8 +1498,9 @@ def _get_s3_script_content(
14981498
)
14991499
response = s3_client.get_object(Bucket='{bucket_name}', Key='{job_args_key}')
15001500
job_args_list = cloudpickle.loads(response['Body'].read())
1501+
modality = job_args_list[0][MODALITY_IDX]
15011502
result_writer = S3ResultsWriter(s3_client=s3_client)
1502-
_write_metainfo_file({synthesizers}, job_args_list, 'single_table', result_writer)
1503+
_write_metainfo_file({synthesizers}, job_args_list, modality, result_writer)
15031504
scores = _run_jobs(None, job_args_list, False, result_writer=result_writer)
15041505
metainfo_filename = job_args_list[0][-1]['metainfo']
15051506
_update_metainfo_file(metainfo_filename, result_writer)
@@ -1876,15 +1877,13 @@ def benchmark_multi_table_aws(
18761877
Whether or not to evaluate an overall quality score. Defaults to ``True``.
18771878
compute_diagnostic_score (bool):
18781879
Whether or not to evaluate an overall diagnostic score. Defaults to ``True``.
1879-
compute_privacy_score (bool):
1880-
Whether or not to evaluate an overall privacy score. Defaults to ``True``.
18811880
timeout (int or ``None``):
18821881
The maximum number of seconds to wait for synthetic data creation. If ``None``, no
18831882
timeout is enforced.
18841883
18851884
Returns:
18861885
pandas.DataFrame:
1887-
A table containing one row per synthesizer + dataset + metric.
1886+
A table containing one row per synthesizer + dataset.
18881887
"""
18891888
s3_client = _validate_output_destination(
18901889
output_destination,
@@ -1901,17 +1900,17 @@ def benchmark_multi_table_aws(
19011900
limit_dataset_size=limit_dataset_size,
19021901
sdv_datasets=sdv_datasets,
19031902
additional_datasets_folder=additional_datasets_folder,
1903+
sdmetrics=None,
19041904
timeout=timeout,
19051905
output_destination=output_destination,
19061906
compute_quality_score=compute_quality_score,
19071907
compute_diagnostic_score=compute_diagnostic_score,
1908+
compute_privacy_score=None,
19081909
synthesizers=synthesizers,
19091910
detailed_results_folder=None,
19101911
custom_synthesizers=None,
19111912
s3_client=s3_client,
19121913
modality='multi_table',
1913-
sdmetrics=None,
1914-
compute_privacy_score=None,
19151914
)
19161915
if not job_args_list:
19171916
return _get_empty_dataframe(

sdgym/synthesizers/sdv.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -21,8 +21,6 @@
2121

2222
MODEL_KWARGS = {'HMASynthesizer': {'verbose': False}}
2323

24-
MODEL_KWARGS = {'HMASynthesizer': {'verbose': False}}
25-
2624

2725
def _get_sdv_synthesizers(modality):
2826
_validate_modality(modality)

tests/unit/test_benchmark.py

Lines changed: 61 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@
2929
_validate_output_destination,
3030
_write_metainfo_file,
3131
benchmark_multi_table,
32+
benchmark_multi_table_aws,
3233
benchmark_single_table,
3334
benchmark_single_table_aws,
3435
)
@@ -1298,3 +1299,63 @@ def test_benchmark_multi_table_no_jobs(
12981299
)
12991300
mock__write_metainfo_file.assert_called_once()
13001301
pd.testing.assert_frame_equal(scores, empty_scores)
1302+
1303+
1304+
@patch('sdgym.benchmark._validate_output_destination')
1305+
@patch('sdgym.benchmark._generate_job_args_list')
1306+
@patch('sdgym.benchmark._run_on_aws')
1307+
def test_benchmark_multi_table_aws(
1308+
mock_run_on_aws, mock_generate_job_args_list, mock_validate_output_destination
1309+
):
1310+
"""Test `benchmark_multi_table_aws` method."""
1311+
# Setup
1312+
output_destination = 's3://sdgym-benchmark/Debug/Issue_487_test_1'
1313+
synthesizers = ['HMASynthesizer']
1314+
datasets = ['financial', 'NBA']
1315+
aws_access_key_id = '12345'
1316+
aws_secret_access_key = '67890'
1317+
mock_validate_output_destination.return_value = 's3_client_mock'
1318+
mock_generate_job_args_list.return_value = 'job_args_list_mock'
1319+
1320+
# Run
1321+
benchmark_multi_table_aws(
1322+
output_destination=output_destination,
1323+
aws_access_key_id=aws_access_key_id,
1324+
aws_secret_access_key=aws_secret_access_key,
1325+
synthesizers=synthesizers,
1326+
sdv_datasets=datasets,
1327+
)
1328+
1329+
# Assert
1330+
assert 'MultiTableUniformSynthesizer' in synthesizers
1331+
mock_validate_output_destination.assert_called_once_with(
1332+
output_destination,
1333+
aws_keys={
1334+
'aws_access_key_id': aws_access_key_id,
1335+
'aws_secret_access_key': aws_secret_access_key,
1336+
},
1337+
)
1338+
mock_generate_job_args_list.assert_called_once_with(
1339+
limit_dataset_size=False,
1340+
sdv_datasets=datasets,
1341+
additional_datasets_folder=None,
1342+
sdmetrics=None,
1343+
timeout=None,
1344+
output_destination=output_destination,
1345+
compute_quality_score=True,
1346+
compute_diagnostic_score=True,
1347+
compute_privacy_score=None,
1348+
synthesizers=synthesizers,
1349+
detailed_results_folder=None,
1350+
custom_synthesizers=None,
1351+
s3_client='s3_client_mock',
1352+
modality='multi_table',
1353+
)
1354+
mock_run_on_aws.assert_called_once_with(
1355+
output_destination=output_destination,
1356+
synthesizers=synthesizers,
1357+
s3_client='s3_client_mock',
1358+
job_args_list='job_args_list_mock',
1359+
aws_access_key_id='12345',
1360+
aws_secret_access_key='67890',
1361+
)

0 commit comments

Comments
 (0)