11"""Main SDGym benchmarking module."""
22
3- import base64
43import concurrent
54import logging
65import math
98import pickle
109import re
1110import textwrap
11+ import threading
1212import tracemalloc
1313import warnings
1414from collections import defaultdict
2424import numpy as np
2525import pandas as pd
2626import tqdm
27+ from botocore .config import Config
2728from sdmetrics .reports .multi_table import (
2829 DiagnosticReport as MultiTableDiagnosticReport ,
2930)
4243from sdgym .errors import SDGymError
4344from sdgym .metrics import get_metrics
4445from sdgym .progress import TqdmLogger , progress
45- from sdgym .result_writer import LocalResultsWriter
46+ from sdgym .result_writer import LocalResultsWriter , S3ResultsWriter
4647from sdgym .s3 import (
4748 S3_PREFIX ,
49+ S3_REGION ,
4850 is_s3_path ,
4951 parse_s3_path ,
5052 write_csv ,
@@ -168,6 +170,11 @@ def _setup_output_destination_aws(output_destination, synthesizers, datasets, s3
         'run_id': f's3://{bucket_name}/{top_folder}/run_{today}_{increment}.yaml',
     }
 
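+    # Write an initial run file whose only content is 'completed_date: null'.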
+    s3_client.put_object(
+        Bucket=bucket_name,
+        Key=f'{top_folder}/run_{today}_{increment}.yaml',
+        Body='completed_date: null\n'.encode('utf-8'),
+    )
     return paths
 
 
@@ -236,11 +243,25 @@ def _generate_job_args_list(
     synthesizers = get_synthesizers(synthesizers + custom_synthesizers)
 
     # Get list of dataset paths
-    sdv_datasets = [] if sdv_datasets is None else get_dataset_paths(datasets=sdv_datasets)
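+    # Read AWS credentials from the environment and pass them through to get_dataset_paths.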
+    aws_access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
+    aws_secret_access_key_key = os.getenv('AWS_SECRET_ACCESS_KEY')
+    sdv_datasets = (
+        []
+        if sdv_datasets is None
+        else get_dataset_paths(
+            datasets=sdv_datasets,
+            aws_access_key_id=aws_access_key_id,
+            aws_secret_access_key=aws_secret_access_key_key,
+        )
+    )
     additional_datasets = (
         []
         if additional_datasets_folder is None
-        else get_dataset_paths(bucket=additional_datasets_folder)
+        else get_dataset_paths(
+            bucket=additional_datasets_folder,
+            aws_access_key_id=aws_access_key_id,
+            aws_secret_access_key=aws_secret_access_key_key,
+        )
     )
     datasets = sdv_datasets + additional_datasets
     synthesizer_names = [synthesizer['name'] for synthesizer in synthesizers]
@@ -524,27 +545,36 @@ def _score_with_timeout(
     synthesizer_path=None,
     result_writer=None,
 ):
+    output = {} if isinstance(result_writer, S3ResultsWriter) else None
+    args = (
+        synthesizer,
+        data,
+        metadata,
+        metrics,
+        output,
+        compute_quality_score,
+        compute_diagnostic_score,
+        compute_privacy_score,
+        modality,
+        dataset_name,
+        synthesizer_path,
+        result_writer,
+    )
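+    # With an S3ResultsWriter, run _score in a daemon thread and enforce the timeout via join();
+    # otherwise fall back to the multiprocessing path below.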
+    if isinstance(result_writer, S3ResultsWriter):
+        thread = threading.Thread(target=_score, args=args, daemon=True)
+        thread.start()
+        thread.join(timeout)
+        if thread.is_alive():
+            LOGGER.error('Timeout running %s on dataset %s;', synthesizer['name'], dataset_name)
+            return {'timeout': True, 'error': 'Timeout'}
+
+        return output
+
     with multiprocessing_context():
         with multiprocessing.Manager() as manager:
             output = manager.dict()
-            process = multiprocessing.Process(
-                target=_score,
-                args=(
-                    synthesizer,
-                    data,
-                    metadata,
-                    metrics,
-                    output,
-                    compute_quality_score,
-                    compute_diagnostic_score,
-                    compute_privacy_score,
-                    modality,
-                    dataset_name,
-                    synthesizer_path,
-                    result_writer,
-                ),
-            )
-
+            args = args[:4] + (output,) + args[5:]  # replace output=None with manager.dict()
+            process = multiprocessing.Process(target=_score, args=args)
             process.start()
             process.join(timeout)
             process.terminate()
@@ -697,7 +727,6 @@ def _run_job(args):
         compute_privacy_score,
         cache_dir,
     )
-
     if synthesizer_path and result_writer:
         result_writer.write_dataframe(scores, synthesizer_path['benchmark_result'])
 
@@ -998,9 +1027,10 @@ def _write_run_id_file(synthesizers, job_args_list, result_writer=None):
     }
     for synthesizer in synthesizers:
         if synthesizer not in SDV_SINGLE_TABLE_SYNTHESIZERS:
-            ext_lib = EXTERNAL_SYNTHESIZER_TO_LIBRARY[synthesizer]
-            library_version = version(ext_lib)
-            metadata[f'{ext_lib}_version'] = library_version
+            ext_lib = EXTERNAL_SYNTHESIZER_TO_LIBRARY.get(synthesizer)
+            if ext_lib:
+                library_version = version(ext_lib)
+                metadata[f'{ext_lib}_version'] = library_version
         elif 'sdv' not in metadata.keys():
             metadata['sdv_version'] = version('sdv')
 
@@ -1180,20 +1210,22 @@ def _validate_aws_inputs(output_destination, aws_access_key_id, aws_secret_acces
     if not output_destination.startswith('s3://'):
         raise ValueError("'output_destination' must be an S3 URL starting with 's3://'.")
 
-    parsed_url = urlparse(output_destination)
-    bucket_name = parsed_url.netloc
+    bucket_name, _ = parse_s3_path(output_destination)
     if not bucket_name:
         raise ValueError(f'Invalid S3 URL: {output_destination}')
 
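+    # Build the S3 client with explicit connect/read timeouts.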
+    config = Config(connect_timeout=30, read_timeout=300)
     if aws_access_key_id and aws_secret_access_key:
         s3_client = boto3.client(
             's3',
             aws_access_key_id=aws_access_key_id,
             aws_secret_access_key=aws_secret_access_key,
+            region_name=S3_REGION,
+            config=config,
         )
     else:
         # No credentials provided — rely on default session
-        s3_client = boto3.client('s3')
+        s3_client = boto3.client('s3', config=config)
 
     s3_client.head_bucket(Bucket=bucket_name)
     if not _check_write_permissions(s3_client, bucket_name):
@@ -1223,8 +1255,7 @@ def _store_job_args_in_s3(output_destination, job_args_list, s3_client):
     job_args_key = f'{path}{job_args_key}' if path else job_args_key
 
     serialized_data = pickle.dumps(job_args_list)
-    encoded_data = base64.b64encode(serialized_data).decode('utf-8')
-    s3_client.put_object(Bucket=bucket_name, Key=job_args_key, Body=encoded_data)
+    s3_client.put_object(Bucket=bucket_name, Key=job_args_key, Body=serialized_data)
 
     return bucket_name, job_args_key
 
@@ -1235,15 +1266,6 @@ def _get_s3_script_content(
     return f"""
 import boto3
 import pickle
-import base64
-import pandas as pd
-import sdgym
-from sdgym.synthesizers.sdv import (
-    CopulaGANSynthesizer, CTGANSynthesizer,
-    GaussianCopulaSynthesizer, HMASynthesizer, PARSynthesizer,
-    SDVRelationalSynthesizer, SDVTabularSynthesizer, TVAESynthesizer
-)
-from sdgym.synthesizers import RealTabFormerSynthesizer
 from sdgym.benchmark import _run_jobs, _write_run_id_file, _update_run_id_file
 from io import StringIO
 from sdgym.result_writer import S3ResultsWriter
@@ -1255,9 +1277,7 @@ def _get_s3_script_content(
     region_name='{region_name}'
 )
 response = s3_client.get_object(Bucket='{bucket_name}', Key='{job_args_key}')
-encoded_data = response['Body'].read().decode('utf-8')
-serialized_data = base64.b64decode(encoded_data.encode('utf-8'))
-job_args_list = pickle.loads(serialized_data)
+job_args_list = pickle.loads(response['Body'].read())
 result_writer = S3ResultsWriter(s3_client=s3_client)
 _write_run_id_file({synthesizers}, job_args_list, result_writer)
 scores = _run_jobs(None, job_args_list, False, result_writer=result_writer)
@@ -1287,7 +1307,7 @@ def _get_user_data_script(access_key, secret_key, region_name, script_content):
 
     echo "======== Install Dependencies in venv ============"
     pip install --upgrade pip
-    pip install "sdgym[all]"
+    pip install "sdgym[all] @ git+https://github.com/sdv-dev/SDGym.git@issue-425-workflow-sdgym#egg=sdgym"
     pip install s3fs
 
     echo "======== Write Script ==========="
@@ -1313,11 +1333,10 @@ def _run_on_aws(
     aws_secret_access_key,
 ):
     bucket_name, job_args_key = _store_job_args_in_s3(output_destination, job_args_list, s3_client)
-    region_name = 'us-east-1'
     script_content = _get_s3_script_content(
         aws_access_key_id,
         aws_secret_access_key,
-        region_name,
+        S3_REGION,
         bucket_name,
         job_args_key,
         synthesizers,
@@ -1327,12 +1346,12 @@ def _run_on_aws(
     session = boto3.session.Session(
         aws_access_key_id=aws_access_key_id,
         aws_secret_access_key=aws_secret_access_key,
-        region_name=region_name,
+        region_name=S3_REGION,
     )
     ec2_client = session.client('ec2')
     print(f'This instance is being created in region: {session.region_name}')  # noqa
     user_data_script = _get_user_data_script(
-        aws_access_key_id, aws_secret_access_key, region_name, script_content
+        aws_access_key_id, aws_secret_access_key, S3_REGION, script_content
     )
     response = ec2_client.run_instances(
         ImageId='ami-080e1f13689e07408',