2 changes: 2 additions & 0 deletions sdgym/__init__.py
@@ -16,6 +16,7 @@
benchmark_multi_table,
benchmark_single_table,
benchmark_single_table_aws,
benchmark_multi_table_aws,
)
from sdgym.cli.collect import collect_results
from sdgym.cli.summary import make_summary_spreadsheet
@@ -36,6 +37,7 @@
'DatasetExplorer',
'ResultsExplorer',
'benchmark_multi_table',
'benchmark_multi_table_aws',
'benchmark_single_table',
'benchmark_single_table_aws',
'collect_results',
145 changes: 131 additions & 14 deletions sdgym/benchmark.py
@@ -107,7 +107,7 @@
'TVAESynthesizer',
]
SDV_MULTI_TABLE_SYNTHESIZERS = ['HMASynthesizer']

MODALITY_IDX = 10
SDV_SYNTHESIZERS = SDV_SINGLE_TABLE_SYNTHESIZERS + SDV_MULTI_TABLE_SYNTHESIZERS
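Note: `MODALITY_IDX` records where the modality string sits inside each job-args tuple built by `_generate_job_args_list`, so the generated EC2 script can recover it (see the `_get_s3_script_content` hunk below). A minimal sketch of the lookup, with a hypothetical tuple whose first ten positions are elided:

# Hypothetical job-args tuple: positions 0-9 hold other job parameters,
# position 10 (MODALITY_IDX) holds the modality string.
job_args = tuple([None] * 10 + ['multi_table'])
job_args_list = [job_args]
modality = job_args_list[0][MODALITY_IDX]  # -> 'multi_table'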


@@ -220,32 +220,46 @@ def _get_metainfo_increment(top_folder, s3_client=None):
return max(increments) + 1 if increments else 0


def _setup_output_destination_aws(output_destination, synthesizers, datasets, s3_client):
def _setup_output_destination_aws(
output_destination,
synthesizers,
datasets,
modality,
s3_client,
):
paths = defaultdict(dict)
s3_path = output_destination[len(S3_PREFIX) :].rstrip('/')
parts = s3_path.split('/')
bucket_name = parts[0]
prefix_parts = parts[1:]
paths['bucket_name'] = bucket_name
today = datetime.today().strftime('%m_%d_%Y')
top_folder = '/'.join(prefix_parts + [f'SDGym_results_{today}'])

modality_prefix = '/'.join(prefix_parts + [modality])
top_folder = f'{modality_prefix}/SDGym_results_{today}'
increment = _get_metainfo_increment(f's3://{bucket_name}/{top_folder}', s3_client)
suffix = f'({increment})' if increment >= 1 else ''
s3_client.put_object(Bucket=bucket_name, Key=top_folder + '/')
synthetic_data_extension = 'zip' if modality == 'multi_table' else 'csv'
for dataset in datasets:
dataset_folder = f'{top_folder}/{dataset}_{today}'
s3_client.put_object(Bucket=bucket_name, Key=dataset_folder + '/')
paths[dataset]['meta'] = f's3://{bucket_name}/{dataset_folder}/meta.yaml'

for synth_name in synthesizers:
final_synth_name = f'{synth_name}{suffix}'
synth_folder = f'{dataset_folder}/{final_synth_name}'
s3_client.put_object(Bucket=bucket_name, Key=synth_folder + '/')
paths[dataset][final_synth_name] = {
'synthesizer': f's3://{bucket_name}/{synth_folder}/{final_synth_name}.pkl',
'synthetic_data': f's3://{bucket_name}/{synth_folder}/{final_synth_name}_synthetic_data.csv',
'benchmark_result': f's3://{bucket_name}/{synth_folder}/{final_synth_name}_benchmark_result.csv',
'results': f's3://{bucket_name}/{top_folder}/results{suffix}.csv',
'metainfo': f's3://{bucket_name}/{top_folder}/metainfo{suffix}.yaml',
'synthesizer': (f's3://{bucket_name}/{synth_folder}/{final_synth_name}.pkl'),
'synthetic_data': (
f's3://{bucket_name}/{synth_folder}/'
f'{final_synth_name}_synthetic_data.{synthetic_data_extension}'
),
'benchmark_result': (
f's3://{bucket_name}/{synth_folder}/{final_synth_name}_benchmark_result.csv'
),
'metainfo': (f's3://{bucket_name}/{top_folder}/metainfo{suffix}.yaml'),
'results': (f's3://{bucket_name}/{top_folder}/results{suffix}.csv'),
}

s3_client.put_object(
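For orientation, here is the `paths` mapping this helper would return for a hypothetical call with `output_destination='s3://my-bucket/prefix'`, one dataset, one synthesizer, `modality='multi_table'`, and increment 0 (all names below are illustrative):

# paths = _setup_output_destination_aws(
#     's3://my-bucket/prefix', ['HMASynthesizer'], ['fake_hotels'],
#     'multi_table', s3_client)
top = 's3://my-bucket/prefix/multi_table/SDGym_results_05_12_2025'
paths = {
    'bucket_name': 'my-bucket',
    'fake_hotels': {
        'meta': f'{top}/fake_hotels_05_12_2025/meta.yaml',
        'HMASynthesizer': {
            'synthesizer': f'{top}/fake_hotels_05_12_2025/HMASynthesizer/HMASynthesizer.pkl',
            'synthetic_data': f'{top}/fake_hotels_05_12_2025/HMASynthesizer/'
                              'HMASynthesizer_synthetic_data.zip',  # .csv for single_table runs
            'benchmark_result': f'{top}/fake_hotels_05_12_2025/HMASynthesizer/'
                                'HMASynthesizer_benchmark_result.csv',
            'metainfo': f'{top}/metainfo.yaml',
            'results': f'{top}/results.csv',
        },
    },
}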
@@ -279,7 +293,9 @@ def _setup_output_destination(
The s3 client that can be used to read / write to s3. Defaults to ``None``.
"""
if s3_client:
return _setup_output_destination_aws(output_destination, synthesizers, datasets, s3_client)
return _setup_output_destination_aws(
output_destination, synthesizers, datasets, modality, s3_client
)

if output_destination is None:
return {}
@@ -1571,7 +1587,7 @@ def _get_s3_script_content(
return f"""
import boto3
import cloudpickle
from sdgym.benchmark import _run_jobs, _write_metainfo_file, _update_metainfo_file
from sdgym.benchmark import _run_jobs, _write_metainfo_file, _update_metainfo_file, MODALITY_IDX
from io import StringIO
from sdgym.result_writer import S3ResultsWriter

@@ -1583,8 +1599,9 @@
)
response = s3_client.get_object(Bucket='{bucket_name}', Key='{job_args_key}')
job_args_list = cloudpickle.loads(response['Body'].read())
modality = job_args_list[0][MODALITY_IDX]
result_writer = S3ResultsWriter(s3_client=s3_client)
_write_metainfo_file({synthesizers}, job_args_list, 'single_table', result_writer)
_write_metainfo_file({synthesizers}, job_args_list, modality, result_writer)
scores = _run_jobs(None, job_args_list, False, result_writer=result_writer)
metainfo_filename = job_args_list[0][-1]['metainfo']
_update_metainfo_file(metainfo_filename, result_writer)
@@ -1619,7 +1636,7 @@ def _get_user_data_script(access_key, secret_key, region_name, script_content):

echo "======== Install Dependencies in venv ============"
pip install --upgrade pip
pip install sdgym[all]
pip install "sdgym[all] @ git+https://github.com/sdv-dev/SDGym.git@feature_branch/mutli_table_benchmark"
pip install s3fs

echo "======== Write Script ==========="
@@ -1644,13 +1661,14 @@ def _run_on_aws(
aws_secret_access_key,
):
bucket_name, job_args_key = _store_job_args_in_s3(output_destination, job_args_list, s3_client)
synthesizer_names = [{'name': synthesizer['name']} for synthesizer in synthesizers]
script_content = _get_s3_script_content(
aws_access_key_id,
aws_secret_access_key,
S3_REGION,
bucket_name,
job_args_key,
synthesizers,
synthesizer_names,
)

# Create a session and EC2 client using the provided S3 client's credentials
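Why pass `synthesizer_names` instead of the full `synthesizers` list: `_get_s3_script_content` interpolates the value straight into the generated Python source, so only entries that `repr` into valid literals can be embedded. A small illustration, assuming the validated synthesizer entries carry a class object alongside the name:

class FakeSynth:  # stand-in for an imported synthesizer class
    pass

synthesizers = [{'name': 'HMASynthesizer', 'synthesizer': FakeSynth}]
repr(synthesizers)       # "[{'name': 'HMASynthesizer', 'synthesizer': <class ...>}]" -- not valid source
synthesizer_names = [{'name': s['name']} for s in synthesizers]
repr(synthesizer_names)  # "[{'name': 'HMASynthesizer'}]" -- safe to embed in the script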
@@ -1917,3 +1935,102 @@ def benchmark_multi_table(
_update_metainfo_file(metainfo_filename, result_writer)

return scores


def benchmark_multi_table_aws(
output_destination,
aws_access_key_id=None,
aws_secret_access_key=None,
synthesizers=DEFAULT_MULTI_TABLE_SYNTHESIZERS,
sdv_datasets=DEFAULT_MULTI_TABLE_DATASETS,
additional_datasets_folder=None,
limit_dataset_size=False,
compute_quality_score=True,
compute_diagnostic_score=True,
timeout=None,
):
"""Run the SDGym benchmark on multi-table datasets.

Args:
output_destination (str):
An S3 path. The results output folder will be written here.
Should be structured as:
s3://{s3_bucket_name}/{path_to_file} or s3://{s3_bucket_name}.
aws_access_key_id (str):
The AWS access key ID. Optional.
aws_secret_access_key (str):
The AWS secret access key. Optional.
synthesizers (list[str]):
The synthesizer(s) to evaluate. Defaults to
``[HMASynthesizer, MultiTableUniformSynthesizer]``. The available options
are:
- ``HMASynthesizer``
- ``MultiTableUniformSynthesizer``
sdv_datasets (list[str] or ``None``):
Names of the SDV multi-table demo datasets to use for the benchmark.
Defaults to ``DEFAULT_MULTI_TABLE_DATASETS``. Use ``None`` to disable
using any SDV datasets.
additional_datasets_folder (str or ``None``):
The path to an S3 bucket. Datasets found in this folder are
run in addition to the SDV datasets. If ``None``, no additional datasets are used.
limit_dataset_size (bool):
Use this flag to limit the size of the datasets for faster evaluation. If ``True``,
limit the size of every table to 1,000 rows (randomly sampled) and the first 10
columns.
compute_quality_score (bool):
Whether or not to evaluate an overall quality score. Defaults to ``True``.
compute_diagnostic_score (bool):
Whether or not to evaluate an overall diagnostic score. Defaults to ``True``.
timeout (int or ``None``):
The maximum number of seconds to wait for synthetic data creation. If ``None``, no
timeout is enforced.

Returns:
pandas.DataFrame or None:
An empty scores table if there are no jobs to run; otherwise ``None``,
since the benchmark executes on an EC2 instance and writes its results
to ``output_destination``.
s3_client = _validate_output_destination(
output_destination,
aws_keys={
'aws_access_key_id': aws_access_key_id,
'aws_secret_access_key': aws_secret_access_key,
},
)
if not synthesizers:
synthesizers = []

_ensure_uniform_included(synthesizers, modality='multi_table')
synthesizers = _import_and_validate_synthesizers(
synthesizers=synthesizers,
custom_synthesizers=None,
modality='multi_table',
)
job_args_list = _generate_job_args_list(
limit_dataset_size=limit_dataset_size,
sdv_datasets=sdv_datasets,
additional_datasets_folder=additional_datasets_folder,
sdmetrics=None,
timeout=timeout,
output_destination=output_destination,
compute_quality_score=compute_quality_score,
compute_diagnostic_score=compute_diagnostic_score,
compute_privacy_score=None,
synthesizers=synthesizers,
detailed_results_folder=None,
s3_client=s3_client,
modality='multi_table',
)
if not job_args_list:
return _get_empty_dataframe(
compute_diagnostic_score=compute_diagnostic_score,
compute_quality_score=compute_quality_score,
compute_privacy_score=None,
sdmetrics=None,
)

_run_on_aws(
output_destination=output_destination,
synthesizers=synthesizers,
s3_client=s3_client,
job_args_list=job_args_list,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
)
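A minimal usage sketch for the new entry point (bucket name and credentials are placeholders):

from sdgym import benchmark_multi_table_aws

benchmark_multi_table_aws(
    output_destination='s3://my-sdgym-bucket/experiments',
    aws_access_key_id='AKIA...',         # placeholder credentials
    aws_secret_access_key='...',
    synthesizers=['HMASynthesizer'],
    limit_dataset_size=True,             # 1,000 rows / 10 columns per table
    timeout=3600,                        # cap each synthesis run at one hour
)
# The EC2 instance writes results.csv and metainfo.yaml under
# s3://my-sdgym-bucket/experiments/multi_table/SDGym_results_<date>/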
13 changes: 13 additions & 0 deletions sdgym/result_writer.py
@@ -154,3 +154,16 @@ def write_yaml(self, data, file_path, append=False):
run_data.update(data)
new_content = yaml.dump(run_data)
self.s3_client.put_object(Body=new_content.encode(), Bucket=bucket, Key=key)

def write_zipped_dataframes(self, data, file_path, index=False):
"""Write a dictionary of DataFrames to a ZIP file in S3."""
bucket, key = parse_s3_path(file_path)
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, mode='w', compression=zipfile.ZIP_DEFLATED) as zf:
for table_name, table in data.items():
csv_buf = io.StringIO()
table.to_csv(csv_buf, index=index)
zf.writestr(f'{table_name}.csv', csv_buf.getvalue())

zip_buffer.seek(0)
self.s3_client.upload_fileobj(zip_buffer, bucket, key)
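A quick usage sketch for the new method (bucket, key, and tables are illustrative; the writer is constructed as elsewhere in this PR):

import pandas as pd

data = {
    'users': pd.DataFrame({'user_id': [1, 2]}),
    'sessions': pd.DataFrame({'session_id': [10], 'user_id': [1]}),
}
writer = S3ResultsWriter(s3_client=s3_client)
writer.write_zipped_dataframes(data, 's3://my-bucket/run/HMASynthesizer_synthetic_data.zip')
# Produces a ZIP containing users.csv and sessions.csv, one entry per table.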