         'token_uri',
         'auth_provider_x509_cert_url',
         'client_x509_cert_url',
+        'universe_domain',
+        'gcp_project',
+        'gcp_zone',
     },
-    'sdv': {'username', 'sdv_license_key'},
+    'sdv': {'username', 'license_key'},
 }
 LOGGER = logging.getLogger(__name__)
 DEFAULT_SINGLE_TABLE_SYNTHESIZERS = [
@@ -1628,66 +1631,80 @@ def _get_s3_script_content(
     """


-def _get_user_data_script(access_key, secret_key, region_name, script_content):
-    return textwrap.dedent(f"""\
-        #!/bin/bash
-        set -e
-
-        # Always terminate the instance when the script exits (success or failure)
-        trap '
-            INSTANCE_ID=$(curl -s http://169.254.169.254/latest/meta-data/instance-id);
-            echo "======== Terminating EC2 instance: $INSTANCE_ID ==========";
-            aws ec2 terminate-instances --instance-ids $INSTANCE_ID;
-        ' EXIT
-
-        exec > >(tee /var/log/user-data.log|logger -t user-data -s 2>/dev/console) 2>&1
-        echo "======== Update and Install Dependencies ============"
-        sudo apt update -y
-        sudo apt install -y python3-pip python3-venv awscli
-        echo "======== Configure AWS CLI ============"
-        aws configure set aws_access_key_id '{access_key}'
-        aws configure set aws_secret_access_key '{secret_key}'
-        aws configure set default.region '{region_name}'
-
-        echo "======== Create Virtual Environment ============"
-        python3 -m venv ~/env
-        source ~/env/bin/activate
-
-        echo "======== Install Dependencies in venv ============"
-        pip install --upgrade pip
-        pip install "sdgym[all] @ git+https://github.com/sdv-dev/SDGym.git@feature_branch/mutli_table_benchmark"
-        pip install s3fs
-
-        echo "======== Write Script ==========="
-        cat << 'EOF' > ~/sdgym_script.py
-        {script_content}
-        EOF
-
-        echo "======== Run Script ==========="
-        python ~/sdgym_script.py
-        echo "======== Complete ==========="
-        INSTANCE_ID=$(curl -s http://169.254.169.254/latest/meta-data/instance-id)
-        aws ec2 terminate-instances --instance-ids $INSTANCE_ID
-    """).strip()
+def _get_user_data_script(credentials, script_content, compute_service='aws'):
+    """Generate the user-data startup script for either AWS or GCP.
 
+    Args:
+        credentials: dict with AWS and optional SDV credentials.
+        script_content: Python script to write and execute on the instance.
+        compute_service: Either 'aws' or 'gcp'.
+    """
+    aws_key = credentials['aws']['aws_access_key_id']
+    aws_secret = credentials['aws']['aws_secret_access_key']
+
+    # Install bundle-xsynthesizers only when SDV credentials are provided
+    sdv_creds = credentials.get('sdv', {})
+    bundle_install = ''
+    if sdv_creds.get('username') and sdv_creds.get('license_key'):
+        bundle_install = (
+            f'pip install bundle-xsynthesizers '
+            f'--index-url https://{sdv_creds["username"]}:{sdv_creds["license_key"]}@pypi.datacebo.com'
+        )
 
-def _get_gcp_script(credentials, script_content):
+    # --- Build termination trap depending on compute service ---
+    if compute_service == 'aws':
+        termination_trap = textwrap.dedent("""\
+            # AWS termination
+            INSTANCE_ID=$(curl -s http://169.254.169.254/latest/meta-data/instance-id || true)
+            if [ ! -z "$INSTANCE_ID" ]; then
+                echo "Terminating AWS EC2 instance: $INSTANCE_ID"
+                aws ec2 terminate-instances --instance-ids $INSTANCE_ID || true
+            fi
+        """)
+
+    elif compute_service == 'gcp':
+        # This text is spliced into the single-quoted bash trap below, so it
+        # must not itself contain single quotes; jq and cut are used here
+        # because they need none.
+        termination_trap = textwrap.dedent("""\
+            # GCP termination via Compute API (no gcloud required)
+            echo "Detected GCP environment, terminating instance via Compute API"
+
+            TOKEN=$(curl -s -H "Metadata-Flavor: Google" \
+                http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default/token \
+                | jq -r .access_token)
+
+            PROJECT_ID=$(curl -s -H "Metadata-Flavor: Google" \
+                http://169.254.169.254/computeMetadata/v1/project/project-id)
+
+            ZONE=$(curl -s -H "Metadata-Flavor: Google" \
+                http://169.254.169.254/computeMetadata/v1/instance/zone | cut -d/ -f4)
+
+            curl -s -X DELETE \
+                -H "Authorization: Bearer $TOKEN" \
+                -H "Content-Type: application/json" \
+                "https://compute.googleapis.com/compute/v1/projects/$PROJECT_ID/zones/$ZONE/instances/$HOSTNAME" \
+                || true
+        """)
+    else:
+        raise ValueError(f"compute_service must be 'aws' or 'gcp', got '{compute_service}'.")
+
+    # --- Final script assembly ---
     return textwrap.dedent(f"""\
         #!/bin/bash
         set -e
 
-        # Always terminate the instance when the script exits (success or failure)
+        # Auto-termination trap
         trap '
-            gcloud compute instances delete "$HOSTNAME" --zone=us-central1-a --quiet
+        echo "======== Auto-Termination Triggered =========="
+        {termination_trap}
         ' EXIT
 
-        exec > >(tee /var/log/user-data.log|logger -t user-data -s 2>/dev/console) 2>&1
+        exec > >(tee /var/log/user-data.log | logger -t user-data -s 2>/dev/console) 2>&1
+
         echo "======== Update and Install Dependencies ============"
         sudo apt update -y
-        sudo apt install -y python3-pip python3-venv awscli git
-        echo "======== Configure AWS CLI ============"
-        aws configure set aws_access_key_id '{credentials['aws']['aws_access_key_id']}'
-        aws configure set aws_secret_access_key '{credentials['aws']['aws_secret_access_key']}'
+        sudo apt install -y python3-pip python3-venv awscli git jq
+
+        echo "======== Configure AWS CLI (always needed for S3 output) ============"
+        aws configure set aws_access_key_id '{aws_key}'
+        aws configure set aws_secret_access_key '{aws_secret}'
         aws configure set default.region '{S3_REGION}'
 
         echo "======== Create Virtual Environment ============"
@@ -1696,8 +1713,9 @@ def _get_gcp_script(credentials, script_content):
 
         echo "======== Install Dependencies in venv ============"
         pip install --upgrade pip
-        pip install bundle-xsynthesizers --index-url https://{credentials['sdv']['username']}:{credentials['sdv']['license_key']}@pypi.datacebo.com
-        pip install "sdgym[all] @ git+https://github.com/sdv-dev/SDGym.git@gcp_benchmark-romain"
+        {bundle_install}
+        pip install "sdgym[all] @ git+https://github.com/sdv-dev/SDGym.git@gcp-benchmark-romain"
+        pip install s3fs
 
         echo "======== Write Script ==========="
         cat << 'EOF' > ~/sdgym_script.py
@@ -1706,8 +1724,8 @@ def _get_gcp_script(credentials, script_content):
 
         echo "======== Run Script ==========="
         python ~/sdgym_script.py
+
         echo "======== Complete ==========="
-        gcloud compute instances delete "$HOSTNAME" --zone=us-central1-a --quiet
     """).strip()


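For context, the AWS path of the unified helper is consumed elsewhere in this module and is not shown in this diff. A minimal sketch of how it might be wired into an EC2 launch, assuming boto3 and placeholder AMI, region, and instance-type values that are not taken from this change:

import boto3

# Sketch only: ImageId, region, and InstanceType are illustrative placeholders.
user_data = _get_user_data_script(
    credentials=credentials,       # the dict validated by _get_credentials()
    script_content=script_content,
    compute_service='aws',
)
ec2 = boto3.client('ec2', region_name='us-east-1')
ec2.run_instances(
    ImageId='ami-00000000000000000',
    InstanceType='m5.xlarge',
    MinCount=1,
    MaxCount=1,
    UserData=user_data,  # boto3 base64-encodes user data automatically
)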
@@ -1730,9 +1748,10 @@ def _run_on_gcp(output_destination, synthesizers, s3_client, job_args_list, cred
 
     machine_type = f'zones/{gcp_zone}/machineTypes/e2-standard-8'
     source_disk_image = 'projects/debian-cloud/global/images/family/debian-12'
-    startup_script = _get_gcp_script(
-        credentials=gcp_credentials,
+    startup_script = _get_user_data_script(
+        credentials=credentials,
         script_content=script_content,
+        compute_service='gcp',
     )
 
     instance_client = compute_v1.InstancesClient(credentials=gcp_credentials)
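The remainder of _run_on_gcp is elided from this hunk. For orientation only, a sketch of how these values are typically combined with the google-cloud-compute client; instance_name, gcp_project, and the omitted network configuration are assumptions, not part of this change:

# Sketch: attach the startup script via instance metadata; network
# interfaces omitted for brevity.
instance = compute_v1.Instance(
    name=instance_name,  # assumed to be built earlier in _run_on_gcp
    machine_type=machine_type,
    disks=[
        compute_v1.AttachedDisk(
            boot=True,
            auto_delete=True,
            initialize_params=compute_v1.AttachedDiskInitializeParams(
                source_image=source_disk_image,
            ),
        )
    ],
    metadata=compute_v1.Metadata(
        items=[compute_v1.Items(key='startup-script', value=startup_script)]
    ),
)
operation = instance_client.insert(
    project=gcp_project, zone=gcp_zone, instance_resource=instance
)
operation.result()  # block until the instance is created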
@@ -2182,8 +2201,7 @@ def _get_credentials(credential_filepath):
             f'Found: {actual_sections}.'
         )
 
-    for section, keys in CREDENTIAL_KEYS.items():
-        expected_keys = set(keys.keys())
+    for section, expected_keys in CREDENTIAL_KEYS.items():
         actual_keys = set(credentials[section].keys())
         if expected_keys != actual_keys:
             raise ValueError(
@@ -2194,6 +2212,72 @@ def _get_credentials(credential_filepath):
     return credentials


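Since _get_credentials now validates the exact key set per section, it may help to see the shape the parsed credentials must take. A hypothetical example matching CREDENTIAL_KEYS, with every value a placeholder and the gcp section truncated to the keys visible in this diff:

credentials = {
    'aws': {
        'aws_access_key_id': 'AKIA...',        # placeholder
        'aws_secret_access_key': '<secret>',   # placeholder
    },
    'gcp': {
        # ...the service-account fields listed in CREDENTIAL_KEYS['gcp'],
        # e.g. token_uri, auth_provider_x509_cert_url, client_x509_cert_url...
        'universe_domain': 'googleapis.com',
        'gcp_project': 'my-gcp-project',       # new key in this change
        'gcp_zone': 'us-central1-a',           # new key in this change
    },
    'sdv': {'username': '<user>', 'license_key': '<license>'},
}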
+def _benchmark_compute_gcp(
+    output_destination,
+    credential_filepath,
+    synthesizers,
+    sdv_datasets,
+    additional_datasets_folder,
+    limit_dataset_size,
+    compute_quality_score,
+    compute_diagnostic_score,
+    compute_privacy_score,
+    sdmetrics,
+    timeout,
+    modality,
+):
+    """Run the SDGym benchmark on GCP for the given modality."""
+    credentials = _get_credentials(credential_filepath)
+    s3_client = _validate_output_destination(
+        output_destination,
+        aws_keys={
+            'aws_access_key_id': credentials['aws']['aws_access_key_id'],
+            'aws_secret_access_key': credentials['aws']['aws_secret_access_key'],
+        },
+    )
+
+    if not synthesizers:
+        synthesizers = []
+
+    _ensure_uniform_included(synthesizers, modality)
+    synthesizers = _import_and_validate_synthesizers(
+        synthesizers=synthesizers,
+        custom_synthesizers=None,
+        modality=modality,
+    )
+
+    job_args_list = _generate_job_args_list(
+        limit_dataset_size=limit_dataset_size,
+        sdv_datasets=sdv_datasets,
+        additional_datasets_folder=additional_datasets_folder,
+        sdmetrics=sdmetrics,
+        timeout=timeout,
+        output_destination=output_destination,
+        compute_quality_score=compute_quality_score,
+        compute_diagnostic_score=compute_diagnostic_score,
+        compute_privacy_score=compute_privacy_score,
+        synthesizers=synthesizers,
+        detailed_results_folder=None,
+        s3_client=s3_client,
+        modality=modality,
+    )
+    if not job_args_list:
+        return _get_empty_dataframe(
+            compute_diagnostic_score=compute_diagnostic_score,
+            compute_quality_score=compute_quality_score,
+            compute_privacy_score=compute_privacy_score,
+            sdmetrics=sdmetrics,
+        )
+
+    _run_on_gcp(
+        output_destination=output_destination,
+        synthesizers=synthesizers,
+        s3_client=s3_client,
+        job_args_list=job_args_list,
+        credentials=credentials,
+    )
+
+
 def _benchmark_single_table_compute_gcp(
     output_destination,
     credential_filepath,
@@ -2254,50 +2338,85 @@ def _benchmark_single_table_compute_gcp(
         pandas.DataFrame:
             A table containing one row per synthesizer + dataset + metric.
     """
-    credentials = _get_credentials(credential_filepath)
-    s3_client = _validate_output_destination(
-        output_destination,
-        aws_keys={
-            'aws_access_key_id': credentials['aws']['aws_access_key_id'],
-            'aws_secret_access_key': credentials['aws']['aws_secret_access_key'],
-        },
-    )
-    if not synthesizers:
-        synthesizers = []
-
-    _ensure_uniform_included(synthesizers, 'single_table')
-    synthesizers = _import_and_validate_synthesizers(
+    return _benchmark_compute_gcp(
+        output_destination=output_destination,
+        credential_filepath=credential_filepath,
         synthesizers=synthesizers,
-        custom_synthesizers=None,
-        modality='single_table',
-    )
-    job_args_list = _generate_job_args_list(
-        limit_dataset_size=limit_dataset_size,
         sdv_datasets=sdv_datasets,
         additional_datasets_folder=additional_datasets_folder,
-        sdmetrics=sdmetrics,
-        timeout=timeout,
-        output_destination=output_destination,
+        limit_dataset_size=limit_dataset_size,
         compute_quality_score=compute_quality_score,
         compute_diagnostic_score=compute_diagnostic_score,
         compute_privacy_score=compute_privacy_score,
-        synthesizers=synthesizers,
-        detailed_results_folder=None,
-        s3_client=s3_client,
+        sdmetrics=sdmetrics,
+        timeout=timeout,
         modality='single_table',
     )
-    if not job_args_list:
-        return _get_empty_dataframe(
-            compute_diagnostic_score=compute_diagnostic_score,
-            compute_quality_score=compute_quality_score,
-            compute_privacy_score=compute_privacy_score,
-            sdmetrics=sdmetrics,
-        )
 
-    _run_on_gcp(
+
+def _benchmark_multi_table_compute_gcp(
+    output_destination,
+    credential_filepath,
+    synthesizers=DEFAULT_MULTI_TABLE_SYNTHESIZERS,
+    sdv_datasets=DEFAULT_MULTI_TABLE_DATASETS,
+    additional_datasets_folder=None,
+    limit_dataset_size=False,
+    compute_quality_score=True,
+    compute_diagnostic_score=True,
+    compute_privacy_score=True,
+    sdmetrics=None,
+    timeout=None,
+):
+    """Run the SDGym benchmark on multi-table datasets.
+
+    Args:
+        output_destination (str):
+            An S3 bucket or filepath. The results output folder will be written here.
+            Should be structured as:
+            s3://{s3_bucket_name}/{path_to_file} or s3://{s3_bucket_name}.
+        credential_filepath (str):
+            The path to the credential file for GCP, AWS and SDV-Enterprise.
+        synthesizers (list[string]):
+            The synthesizer(s) to evaluate. Defaults to
+            ``[HMASynthesizer, MultiTableUniformSynthesizer]``.
+        sdv_datasets (list[str] or ``None``):
+            Names of the SDV demo datasets to use for the benchmark.
+        additional_datasets_folder (str or ``None``):
+            The path to an S3 bucket. Datasets found in this folder are
+            run in addition to the SDV datasets. If ``None``, no additional datasets are used.
+        limit_dataset_size (bool):
+            Use this flag to limit the size of the datasets for faster evaluation. If ``True``,
+            limit the size of every table to 1,000 rows (randomly sampled) and the first 10
+            columns.
+        compute_quality_score (bool):
+            Whether or not to evaluate an overall quality score. Defaults to ``True``.
+        compute_diagnostic_score (bool):
+            Whether or not to evaluate an overall diagnostic score. Defaults to ``True``.
+        compute_privacy_score (bool):
+            Whether or not to evaluate an overall privacy score. Defaults to ``True``.
+        sdmetrics (list[str]):
+            A list of the different SDMetrics to use.
+            If you'd like to input specific parameters into the metric, provide a tuple with
+            the metric name followed by a dictionary of the parameters.
+        timeout (int or ``None``):
+            The maximum number of seconds to wait for synthetic data creation. If ``None``, no
+            timeout is enforced.
+
+    Returns:
+        pandas.DataFrame:
+            A table containing one row per synthesizer + dataset + metric.
+    """
+    return _benchmark_compute_gcp(
         output_destination=output_destination,
+        credential_filepath=credential_filepath,
         synthesizers=synthesizers,
-        s3_client=s3_client,
-        job_args_list=job_args_list,
-        credentials=credentials,
+        sdv_datasets=sdv_datasets,
+        additional_datasets_folder=additional_datasets_folder,
+        limit_dataset_size=limit_dataset_size,
+        compute_quality_score=compute_quality_score,
+        compute_diagnostic_score=compute_diagnostic_score,
+        compute_privacy_score=compute_privacy_score,
+        sdmetrics=sdmetrics,
+        timeout=timeout,
+        modality='multi_table',
     )
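Taken together, a hypothetical invocation of the new multi-table entry point might look like the following; the bucket, credential path, and dataset name are placeholders, not taken from this change:

# Launches GCP instances that run the benchmark jobs and write results to
# the S3 destination; each instance deletes itself via the termination trap
# when its job exits.
_benchmark_multi_table_compute_gcp(
    output_destination='s3://my-results-bucket/sdgym/',
    credential_filepath='credentials.json',
    sdv_datasets=['fake_hotels'],  # placeholder dataset name
)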