11"""Script to run a benchmark and upload results to S3."""
22
3+ import argparse
34import json
45import os
56from datetime import datetime , timezone
67
78from botocore .exceptions import ClientError
89
9- from sdgym .benchmark import benchmark_single_table_aws
10+ from sdgym .benchmark import _benchmark_multi_table_compute_gcp , benchmark_single_table_aws
1011from sdgym .run_benchmark .utils import (
1112 KEY_DATE_FILE ,
1213 OUTPUT_DESTINATION_AWS ,
13- SYNTHESIZERS_SPLIT ,
14+ SYNTHESIZERS_SPLIT_MULTI_TABLE ,
15+ SYNTHESIZERS_SPLIT_SINGLE_TABLE ,
1416 get_result_folder_name ,
1517 post_benchmark_launch_message ,
1618)
1719from sdgym .s3 import get_s3_client , parse_s3_path
1820
1921
20- def append_benchmark_run (aws_access_key_id , aws_secret_access_key , date_str ):
22+ def append_benchmark_run (aws_access_key_id , aws_secret_access_key , date_str , modality = 'single_table' ):
2123 """Append a new benchmark run to the benchmark dates file in S3."""
2224 s3_client = get_s3_client (
2325 aws_access_key_id = aws_access_key_id ,
2426 aws_secret_access_key = aws_secret_access_key ,
2527 )
2628 bucket , prefix = parse_s3_path (OUTPUT_DESTINATION_AWS )
2729 try :
28- object = s3_client .get_object (Bucket = bucket , Key = f'{ prefix } { KEY_DATE_FILE } ' )
30+ object = s3_client .get_object (Bucket = bucket , Key = f'{ prefix } { modality } { KEY_DATE_FILE } ' )
2931 body = object ['Body' ].read ().decode ('utf-8' )
3032 data = json .loads (body )
3133 except ClientError as e :
@@ -41,23 +43,50 @@ def append_benchmark_run(aws_access_key_id, aws_secret_access_key, date_str):
4143 )
4244
4345
46+ def _parse_args ():
47+ parser = argparse .ArgumentParser ()
48+ parser .add_argument (
49+ '--modality' ,
50+ choices = ['single_table' , 'multi_table' ],
51+ default = 'single_table' ,
52+ help = 'Benchmark modality to run.' ,
53+ )
54+ return parser .parse_args ()
55+
56+
4457def main ():
4558 """Main function to run the benchmark and upload results."""
59+ args = _parse_args ()
4660 aws_access_key_id = os .getenv ('AWS_ACCESS_KEY_ID' )
4761 aws_secret_access_key = os .getenv ('AWS_SECRET_ACCESS_KEY' )
4862 date_str = datetime .now (timezone .utc ).strftime ('%Y-%m-%d' )
49- for synthesizer_group in SYNTHESIZERS_SPLIT :
50- benchmark_single_table_aws (
51- output_destination = OUTPUT_DESTINATION_AWS ,
52- aws_access_key_id = aws_access_key_id ,
53- aws_secret_access_key = aws_secret_access_key ,
54- synthesizers = synthesizer_group ,
55- compute_privacy_score = False ,
56- timeout = 345600 , # 4 days
57- )
58-
59- append_benchmark_run (aws_access_key_id , aws_secret_access_key , date_str )
60- post_benchmark_launch_message (date_str )
63+
64+ if args .modality == 'single_table' :
65+ for synthesizer_group in SYNTHESIZERS_SPLIT_SINGLE_TABLE :
66+ benchmark_single_table_aws (
67+ output_destination = OUTPUT_DESTINATION_AWS ,
68+ aws_access_key_id = aws_access_key_id ,
69+ aws_secret_access_key = aws_secret_access_key ,
70+ synthesizers = synthesizer_group ,
71+ compute_privacy_score = False ,
72+ timeout = 345600 , # 4 days
73+ )
74+
75+ append_benchmark_run (aws_access_key_id , aws_secret_access_key , date_str , modality = 'single_table' )
76+
77+ else :
78+ for synthesizer_group in SYNTHESIZERS_SPLIT_MULTI_TABLE :
79+ _benchmark_multi_table_compute_gcp (
80+ output_destination = 's3://sdgym-benchmark/Debug/GCP/' ,
81+ aws_access_key_id = aws_access_key_id ,
82+ aws_secret_access_key = aws_secret_access_key ,
83+ synthesizers = synthesizer_group ,
84+ compute_privacy_score = False ,
85+ timeout = 345600 , # 4 days
86+ )
87+ append_benchmark_run (aws_access_key_id , aws_secret_access_key , date_str , modality = 'multi_table' )
88+
89+ post_benchmark_launch_message (date_str , compute_service = 'GCP' )
6190
6291
6392if __name__ == '__main__' :
0 commit comments