Skip to content

Commit 684a40f

Browse files
committed
cleaning
1 parent 7624b2c commit 684a40f

File tree

3 files changed

+68
-10
lines changed

3 files changed

+68
-10
lines changed

sdgym/benchmark.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@
107107
'TVAESynthesizer',
108108
]
109109
SDV_MULTI_TABLE_SYNTHESIZERS = ['HMASynthesizer']
110-
110+
MODALITY_IDX = 10
111111
SDV_SYNTHESIZERS = SDV_SINGLE_TABLE_SYNTHESIZERS + SDV_MULTI_TABLE_SYNTHESIZERS
112112

113113

@@ -1587,7 +1587,7 @@ def _get_s3_script_content(
15871587
return f"""
15881588
import boto3
15891589
import cloudpickle
1590-
from sdgym.benchmark import _run_jobs, _write_metainfo_file, _update_metainfo_file
1590+
from sdgym.benchmark import _run_jobs, _write_metainfo_file, _update_metainfo_file, MODALITY_IDX
15911591
from io import StringIO
15921592
from sdgym.result_writer import S3ResultsWriter
15931593
@@ -1599,8 +1599,9 @@ def _get_s3_script_content(
15991599
)
16001600
response = s3_client.get_object(Bucket='{bucket_name}', Key='{job_args_key}')
16011601
job_args_list = cloudpickle.loads(response['Body'].read())
1602+
modality = job_args_list[0][MODALITY_IDX]
16021603
result_writer = S3ResultsWriter(s3_client=s3_client)
1603-
_write_metainfo_file({synthesizers}, job_args_list, 'single_table', result_writer)
1604+
_write_metainfo_file({synthesizers}, job_args_list, modality, result_writer)
16041605
scores = _run_jobs(None, job_args_list, False, result_writer=result_writer)
16051606
metainfo_filename = job_args_list[0][-1]['metainfo']
16061607
_update_metainfo_file(metainfo_filename, result_writer)
@@ -1977,15 +1978,13 @@ def benchmark_multi_table_aws(
19771978
Whether or not to evaluate an overall quality score. Defaults to ``True``.
19781979
compute_diagnostic_score (bool):
19791980
Whether or not to evaluate an overall diagnostic score. Defaults to ``True``.
1980-
compute_privacy_score (bool):
1981-
Whether or not to evaluate an overall privacy score. Defaults to ``True``.
19821981
timeout (int or ``None``):
19831982
The maximum number of seconds to wait for synthetic data creation. If ``None``, no
19841983
timeout is enforced.
19851984
19861985
Returns:
19871986
pandas.DataFrame:
1988-
A table containing one row per synthesizer + dataset + metric.
1987+
A table containing one row per synthesizer + dataset.
19891988
"""
19901989
s3_client = _validate_output_destination(
19911990
output_destination,
@@ -2002,17 +2001,17 @@ def benchmark_multi_table_aws(
20022001
limit_dataset_size=limit_dataset_size,
20032002
sdv_datasets=sdv_datasets,
20042003
additional_datasets_folder=additional_datasets_folder,
2004+
sdmetrics=None,
20052005
timeout=timeout,
20062006
output_destination=output_destination,
20072007
compute_quality_score=compute_quality_score,
20082008
compute_diagnostic_score=compute_diagnostic_score,
2009+
compute_privacy_score=None,
20092010
synthesizers=synthesizers,
20102011
detailed_results_folder=None,
20112012
custom_synthesizers=None,
20122013
s3_client=s3_client,
20132014
modality='multi_table',
2014-
sdmetrics=None,
2015-
compute_privacy_score=None,
20162015
)
20172016
if not job_args_list:
20182017
return _get_empty_dataframe(

sdgym/synthesizers/sdv.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121

2222
MODEL_KWARGS = {'HMASynthesizer': {'verbose': False}}
2323

24-
MODEL_KWARGS = {'HMASynthesizer': {'verbose': False}}
25-
2624

2725
def _get_sdv_synthesizers(modality):
2826
_validate_modality(modality)

tests/unit/test_benchmark.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
_validate_output_destination,
3131
_write_metainfo_file,
3232
benchmark_multi_table,
33+
benchmark_multi_table_aws,
3334
benchmark_single_table,
3435
benchmark_single_table_aws,
3536
)
@@ -1433,3 +1434,63 @@ def test_benchmark_multi_table_no_jobs(
14331434
)
14341435
mock__write_metainfo_file.assert_called_once()
14351436
pd.testing.assert_frame_equal(scores, empty_scores)
1437+
1438+
1439+
@patch('sdgym.benchmark._validate_output_destination')
1440+
@patch('sdgym.benchmark._generate_job_args_list')
1441+
@patch('sdgym.benchmark._run_on_aws')
1442+
def test_benchmark_multi_table_aws(
1443+
mock_run_on_aws, mock_generate_job_args_list, mock_validate_output_destination
1444+
):
1445+
"""Test `benchmark_multi_table_aws` method."""
1446+
# Setup
1447+
output_destination = 's3://sdgym-benchmark/Debug/Issue_487_test_1'
1448+
synthesizers = ['HMASynthesizer']
1449+
datasets = ['financial', 'NBA']
1450+
aws_access_key_id = '12345'
1451+
aws_secret_access_key = '67890'
1452+
mock_validate_output_destination.return_value = 's3_client_mock'
1453+
mock_generate_job_args_list.return_value = 'job_args_list_mock'
1454+
1455+
# Run
1456+
benchmark_multi_table_aws(
1457+
output_destination=output_destination,
1458+
aws_access_key_id=aws_access_key_id,
1459+
aws_secret_access_key=aws_secret_access_key,
1460+
synthesizers=synthesizers,
1461+
sdv_datasets=datasets,
1462+
)
1463+
1464+
# Assert
1465+
assert 'MultiTableUniformSynthesizer' in synthesizers
1466+
mock_validate_output_destination.assert_called_once_with(
1467+
output_destination,
1468+
aws_keys={
1469+
'aws_access_key_id': aws_access_key_id,
1470+
'aws_secret_access_key': aws_secret_access_key,
1471+
},
1472+
)
1473+
mock_generate_job_args_list.assert_called_once_with(
1474+
limit_dataset_size=False,
1475+
sdv_datasets=datasets,
1476+
additional_datasets_folder=None,
1477+
sdmetrics=None,
1478+
timeout=None,
1479+
output_destination=output_destination,
1480+
compute_quality_score=True,
1481+
compute_diagnostic_score=True,
1482+
compute_privacy_score=None,
1483+
synthesizers=synthesizers,
1484+
detailed_results_folder=None,
1485+
custom_synthesizers=None,
1486+
s3_client='s3_client_mock',
1487+
modality='multi_table',
1488+
)
1489+
mock_run_on_aws.assert_called_once_with(
1490+
output_destination=output_destination,
1491+
synthesizers=synthesizers,
1492+
s3_client='s3_client_mock',
1493+
job_args_list='job_args_list_mock',
1494+
aws_access_key_id='12345',
1495+
aws_secret_access_key='67890',
1496+
)

0 commit comments

Comments
 (0)