11"""Script to run a benchmark and upload results to S3."""
22
3- import argparse
4- import base64
53import json
64import os
75from datetime import datetime , timezone
8- from pathlib import Path
96
107from botocore .exceptions import ClientError
118
12- from sdgym ._benchmark .benchmark import _benchmark_multi_table_compute_gcp
139from sdgym .benchmark import benchmark_single_table_aws
1410from sdgym .run_benchmark .utils import (
15- GCP_PROJECT ,
16- GCP_ZONE ,
1711 KEY_DATE_FILE ,
1812 OUTPUT_DESTINATION_AWS ,
19- SYNTHESIZERS_SPLIT_MULTI_TABLE ,
20- SYNTHESIZERS_SPLIT_SINGLE_TABLE ,
13+ SYNTHESIZERS_SPLIT ,
2114 get_result_folder_name ,
2215 post_benchmark_launch_message ,
2316)
2417from sdgym .s3 import get_s3_client , parse_s3_path
2518
2619
27- def append_benchmark_run (
28- aws_access_key_id , aws_secret_access_key , date_str , modality = 'single_table'
29- ):
20+ def append_benchmark_run (aws_access_key_id , aws_secret_access_key , date_str ):
3021 """Append a new benchmark run to the benchmark dates file in S3."""
3122 s3_client = get_s3_client (
3223 aws_access_key_id = aws_access_key_id ,
3324 aws_secret_access_key = aws_secret_access_key ,
3425 )
3526 bucket , prefix = parse_s3_path (OUTPUT_DESTINATION_AWS )
3627 try :
37- object = s3_client .get_object (Bucket = bucket , Key = f'{ prefix } { modality } { KEY_DATE_FILE } ' )
28+ object = s3_client .get_object (Bucket = bucket , Key = f'{ prefix } { KEY_DATE_FILE } ' )
3829 body = object ['Body' ].read ().decode ('utf-8' )
3930 data = json .loads (body )
4031 except ClientError as e :
@@ -50,116 +41,23 @@ def append_benchmark_run(
     )
 
 
-def _load_gcp_service_account_from_env():
-    """Load GCP service account JSON from env.
-
-    Supports:
-    - raw JSON string
-    - base64-encoded JSON string
-    """
-    raw = os.getenv('GCP_SERVICE_ACCOUNT_JSON', '') or ''
-    if not raw.strip():
-        return {}
-
-    try:
-        return json.loads(raw)
-    except json.JSONDecodeError:
-        decoded = base64.b64decode(raw).decode('utf-8')
-        return json.loads(decoded)
-
-
-def create_credentials_file(filepath):
-    """Create credentials file used by the benchmark launcher."""
-    gcp_sa = _load_gcp_service_account_from_env()
-
-    credentials = {
-        'aws': {
-            'aws_access_key_id': os.getenv('AWS_ACCESS_KEY_ID'),
-            'aws_secret_access_key': os.getenv('AWS_SECRET_ACCESS_KEY'),
-        },
-        'gcp': {
-            **gcp_sa,
-            'gcp_project': GCP_PROJECT,
-            'gcp_zone': GCP_ZONE,
-        },
-        'sdv': {
-            'username': os.getenv('SDV_ENTERPRISE_USERNAME'),
-            'license_key': os.getenv('SDV_ENTERPRISE_LICENSE_KEY'),
-        },
-    }
-
-    Path(filepath).parent.mkdir(parents=True, exist_ok=True)
-    with open(filepath, 'w', encoding='utf-8') as f:
-        json.dump(credentials, f, indent=2, sort_keys=True)
-        f.write('\n')
-
-
-def _parse_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        '--modality',
-        choices=['single_table', 'multi_table'],
-        default='single_table',
-        help='Benchmark modality to run.',
-    )
-    parser.add_argument(
-        '--gcp-output-destination',
-        default='s3://sdgym-benchmark/Debug/GCP/',
-        help='Where to store GCP benchmark results (S3).',
-    )
-    return parser.parse_args()
-
-
 def main():
     """Main function to run the benchmark and upload results."""
-    args = _parse_args()
-
     aws_access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
     aws_secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY')
     date_str = datetime.now(timezone.utc).strftime('%Y-%m-%d')
-
-    if args.modality == 'single_table':
-        for synthesizer_group in SYNTHESIZERS_SPLIT_SINGLE_TABLE:
-            benchmark_single_table_aws(
-                output_destination=OUTPUT_DESTINATION_AWS,
-                aws_access_key_id=aws_access_key_id,
-                aws_secret_access_key=aws_secret_access_key,
-                synthesizers=synthesizer_group,
-                compute_privacy_score=False,
-                timeout=345600,  # 4 days
-            )
-
-        append_benchmark_run(
-            aws_access_key_id,
-            aws_secret_access_key,
-            date_str,
-            modality='single_table',
-        )
-        compute_service = 'AWS'
-
-    else:
-        runner_temp = os.environ.get('RUNNER_TEMP', '/tmp')
-        cred_path = os.path.join(runner_temp, 'credentials.json')
-        create_credentials_file(cred_path)
-
-        for synthesizer_group in SYNTHESIZERS_SPLIT_MULTI_TABLE:
-            _benchmark_multi_table_compute_gcp(
-                output_destination=args.gcp_output_destination,
-                credential_filepath=cred_path,
-                synthesizers=synthesizer_group,
-                compute_privacy_score=False,
-                timeout=345600,  # 4 days
-            )
-
-        append_benchmark_run(
-            aws_access_key_id,
-            aws_secret_access_key,
-            date_str,
-            modality='multi_table',
+    for synthesizer_group in SYNTHESIZERS_SPLIT:
+        benchmark_single_table_aws(
+            output_destination=OUTPUT_DESTINATION_AWS,
+            aws_access_key_id=aws_access_key_id,
+            aws_secret_access_key=aws_secret_access_key,
+            synthesizers=synthesizer_group,
+            compute_privacy_score=False,
+            timeout=345600,  # 4 days
         )
-        compute_service = 'GCP'
 
-    post_benchmark_launch_message(date_str, compute_service=compute_service)
+    append_benchmark_run(aws_access_key_id, aws_secret_access_key, date_str)
+    post_benchmark_launch_message(date_str)
 
 
 if __name__ == '__main__':
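
For reviewers who want to sanity-check the S3 read that `append_benchmark_run` now performs (single dates file, no `modality` segment in the key), the snippet below is a minimal standalone sketch, not part of this change. It reuses the same `get_s3_client`, `parse_s3_path`, and `get_object` calls that appear in the script and assumes `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` are set in the environment, as they are in the workflow that runs this script.

```python
import json
import os

from sdgym.run_benchmark.utils import KEY_DATE_FILE, OUTPUT_DESTINATION_AWS
from sdgym.s3 import get_s3_client, parse_s3_path

# Build the same S3 client the script uses, from environment credentials.
s3_client = get_s3_client(
    aws_access_key_id=os.getenv('AWS_ACCESS_KEY_ID'),
    aws_secret_access_key=os.getenv('AWS_SECRET_ACCESS_KEY'),
)

# Resolve the bucket/prefix and fetch the benchmark dates file at its new key.
bucket, prefix = parse_s3_path(OUTPUT_DESTINATION_AWS)
response = s3_client.get_object(Bucket=bucket, Key=f'{prefix}{KEY_DATE_FILE}')
data = json.loads(response['Body'].read().decode('utf-8'))
print(data)
```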