Commit cf576fb

update

1 parent 8f0f319 commit cf576fb


sdgym/benchmark.py

Lines changed: 211 additions & 92 deletions
@@ -81,8 +81,11 @@
         'token_uri',
         'auth_provider_x509_cert_url',
         'client_x509_cert_url',
+        'universe_domain',
+        'gcp_project',
+        'gcp_zone',
     },
-    'sdv': {'username', 'sdv_license_key'},
+    'sdv': {'username', 'license_key'},
 }
 LOGGER = logging.getLogger(__name__)
 DEFAULT_SINGLE_TABLE_SYNTHESIZERS = [
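
The exact-key check later in this commit (see the _get_credentials hunk below) compares each section of the credentials file against these sets, so the 'sdv' section now expects 'license_key' rather than 'sdv_license_key', and the three new GCP keys become required. A minimal sketch of that check follows; _check_credentials is a hypothetical name, and the 'gcp' set is truncated because this hunk only shows part of it:

    # Sketch only: mirrors the validation loop in _get_credentials below.
    CREDENTIAL_KEYS = {
        'sdv': {'username', 'license_key'},
        # 'gcp': {...},  # full service-account key set, truncated here
    }

    def _check_credentials(credentials):
        for section, expected_keys in CREDENTIAL_KEYS.items():
            actual_keys = set(credentials[section].keys())
            if expected_keys != actual_keys:
                raise ValueError(f'Invalid keys in section {section!r}: {sorted(actual_keys)}')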
@@ -1628,66 +1631,80 @@ def _get_s3_script_content(
     """
 
 
-def _get_user_data_script(access_key, secret_key, region_name, script_content):
-    return textwrap.dedent(f"""\
-        #!/bin/bash
-        set -e
-
-        # Always terminate the instance when the script exits (success or failure)
-        trap '
-            INSTANCE_ID=$(curl -s http://169.254.169.254/latest/meta-data/instance-id);
-            echo "======== Terminating EC2 instance: $INSTANCE_ID ==========";
-            aws ec2 terminate-instances --instance-ids $INSTANCE_ID;
-        ' EXIT
-
-        exec > >(tee /var/log/user-data.log|logger -t user-data -s 2>/dev/console) 2>&1
-        echo "======== Update and Install Dependencies ============"
-        sudo apt update -y
-        sudo apt install -y python3-pip python3-venv awscli
-        echo "======== Configure AWS CLI ============"
-        aws configure set aws_access_key_id '{access_key}'
-        aws configure set aws_secret_access_key '{secret_key}'
-        aws configure set default.region '{region_name}'
-
-        echo "======== Create Virtual Environment ============"
-        python3 -m venv ~/env
-        source ~/env/bin/activate
-
-        echo "======== Install Dependencies in venv ============"
-        pip install --upgrade pip
-        pip install "sdgym[all] @ git+https://github.com/sdv-dev/SDGym.git@feature_branch/mutli_table_benchmark"
-        pip install s3fs
-
-        echo "======== Write Script ==========="
-        cat << 'EOF' > ~/sdgym_script.py
-        {script_content}
-        EOF
-
-        echo "======== Run Script ==========="
-        python ~/sdgym_script.py
-        echo "======== Complete ==========="
-        INSTANCE_ID=$(curl -s http://169.254.169.254/latest/meta-data/instance-id)
-        aws ec2 terminate-instances --instance-ids $INSTANCE_ID
-    """).strip()
+def _get_user_data_script(credentials, script_content, compute_service='aws'):
+    """Generate user-data for either AWS or GCP.
 
+    Args:
+        credentials: dict with AWS and optional SDV credentials.
+        script_content: Python script to write and execute.
+        compute_service: "aws" or "gcp".
+    """
+    aws_key = credentials['aws']['aws_access_key_id']
+    aws_secret = credentials['aws']['aws_secret_access_key']
+
+    # Conditional bundle-xsynthesizers install
+    sdv_creds = credentials.get('sdv', {})
+    bundle_install = ''
+    if sdv_creds.get('username') and sdv_creds.get('license_key'):
+        bundle_install = (
+            f'pip install bundle-xsynthesizers '
+            f'--index-url https://{sdv_creds["username"]}:{sdv_creds["license_key"]}@pypi.datacebo.com'
+        )
 
-def _get_gcp_script(credentials, script_content):
+    # --- Build termination trap depending on compute service ---
+    if compute_service == 'aws':
+        termination_trap = textwrap.dedent("""\
+            # AWS termination
+            INSTANCE_ID=$(curl -s http://169.254.169.254/latest/meta-data/instance-id || true)
+            if [ ! -z "$INSTANCE_ID" ]; then
+                echo "Terminating AWS EC2 instance: $INSTANCE_ID"
+                aws ec2 terminate-instances --instance-ids $INSTANCE_ID || true
+            fi
+        """)
+
+    elif compute_service == 'gcp':
+        termination_trap = textwrap.dedent("""\
+            # GCP termination via Compute API (no gcloud required)
+            echo "Detected GCP environment — terminating instance via Compute API"
+
+            TOKEN=$(curl -s -H "Metadata-Flavor: Google" \
+                http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default/token \
+                | python3 -c "import sys, json; print(json.load(sys.stdin)['access_token'])")
+
+            PROJECT_ID=$(curl -s -H "Metadata-Flavor: Google" \
+                http://169.254.169.254/computeMetadata/v1/project/project-id)
+
+            ZONE=$(curl -s -H "Metadata-Flavor: Google" \
+                http://169.254.169.254/computeMetadata/v1/instance/zone | awk -F/ '{print $4}')
+
+            INSTANCE_URL="https://compute.googleapis.com/compute/v1/projects/$PROJECT_ID/zones/$ZONE/instances/$HOSTNAME"
+            curl -s -X DELETE \
+                -H "Authorization: Bearer $TOKEN" \
+                -H "Content-Type: application/json" \
+                "$INSTANCE_URL" \
+                || true
+        """)
+
+    # --- Final script assembly ---
     return textwrap.dedent(f"""\
         #!/bin/bash
         set -e
 
-        # Always terminate the instance when the script exits (success or failure)
+        # Auto-termination trap
         trap '
-            gcloud compute instances delete "$HOSTNAME" --zone=us-central1-a --quiet
+            echo "======== Auto-Termination Triggered =========="
+            {termination_trap}
         ' EXIT
 
-        exec > >(tee /var/log/user-data.log|logger -t user-data -s 2>/dev/console) 2>&1
+        exec > >(tee /var/log/user-data.log | logger -t user-data -s 2>/dev/console) 2>&1
+
        echo "======== Update and Install Dependencies ============"
         sudo apt update -y
-        sudo apt install -y python3-pip python3-venv awscli git
-        echo "======== Configure AWS CLI ============"
-        aws configure set aws_access_key_id '{credentials['aws']['aws_access_key_id']}'
-        aws configure set aws_secret_access_key '{credentials['aws']['aws_secret_access_key']}'
+        sudo apt install -y python3-pip python3-venv awscli git jq
+
+        echo "======== Configure AWS CLI (always needed for S3 output) ============"
+        aws configure set aws_access_key_id '{aws_key}'
+        aws configure set aws_secret_access_key '{aws_secret}'
         aws configure set default.region '{S3_REGION}'
 
         echo "======== Create Virtual Environment ============"
@@ -1696,8 +1713,9 @@ def _get_gcp_script(credentials, script_content):
 
         echo "======== Install Dependencies in venv ============"
         pip install --upgrade pip
-        pip install bundle-xsynthesizers --index-url https://{credentials['sdv']['username']}:{credentials['sdv']['license_key']}@pypi.datacebo.com
-        pip install "sdgym[all] @ git+https://github.com/sdv-dev/SDGym.git@gcp_benchmark-romain"
+        {bundle_install}
+        pip install "sdgym[all] @ git+https://github.com/sdv-dev/SDGym.git@gcp-benchmark-romain"
+        pip install s3fs
 
         echo "======== Write Script ==========="
         cat << 'EOF' > ~/sdgym_script.py
@@ -1706,8 +1724,8 @@ def _get_gcp_script(credentials, script_content):
 
         echo "======== Run Script ==========="
         python ~/sdgym_script.py
+
         echo "======== Complete ==========="
-        gcloud compute instances delete "$HOSTNAME" --zone=us-central1-a --quiet
     """).strip()
 
 
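With the AWS and GCP generators merged, the same function now serves both providers. A hypothetical call to the refactored generator; every credential value below is a placeholder:

    credentials = {
        'aws': {
            'aws_access_key_id': 'PLACEHOLDER_KEY_ID',
            'aws_secret_access_key': 'PLACEHOLDER_SECRET',
        },
        # Optional: without this section, bundle_install stays an empty string.
        'sdv': {'username': 'user@example.com', 'license_key': 'PLACEHOLDER'},
    }
    user_data = _get_user_data_script(
        credentials=credentials,
        script_content='print("hello from the benchmark VM")',
        compute_service='gcp',
    )
    assert user_data.startswith('#!/bin/bash')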
@@ -1730,9 +1748,10 @@ def _run_on_gcp(output_destination, synthesizers, s3_client, job_args_list, cred
 
     machine_type = f'zones/{gcp_zone}/machineTypes/e2-standard-8'
     source_disk_image = 'projects/debian-cloud/global/images/family/debian-12'
-    startup_script = _get_gcp_script(
-        credentials=gcp_credentials,
+    startup_script = _get_user_data_script(
+        credentials=credentials,
         script_content=script_content,
+        compute_service='gcp',
     )
 
     instance_client = compute_v1.InstancesClient(credentials=gcp_credentials)
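The generated startup script is handed to the new VM as instance metadata. The creation call itself is outside this diff; below is a sketch of the usual google-cloud-compute pattern, with every name hypothetical:

    from google.cloud import compute_v1

    # Sketch only: attach a startup script to a new GCP VM via metadata.
    machine_type = 'zones/us-central1-a/machineTypes/e2-standard-8'
    startup_script = '#!/bin/bash\necho hello'
    instance = compute_v1.Instance(
        name='sdgym-benchmark-vm',  # hypothetical instance name
        machine_type=machine_type,
        metadata=compute_v1.Metadata(
            items=[compute_v1.Items(key='startup-script', value=startup_script)]
        ),
    )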
@@ -2182,8 +2201,7 @@ def _get_credentials(credential_filepath):
             f'Found: {actual_sections}.'
         )
 
-    for section, keys in CREDENTIAL_KEYS.items():
-        expected_keys = set(keys.keys())
+    for section, expected_keys in CREDENTIAL_KEYS.items():
         actual_keys = set(credentials[section].keys())
         if expected_keys != actual_keys:
             raise ValueError(
@@ -2194,6 +2212,72 @@ def _get_credentials(credential_filepath):
     return credentials
 
 
+def _benchmark_compute_gcp(
+    output_destination,
+    credential_filepath,
+    synthesizers,
+    sdv_datasets,
+    additional_datasets_folder,
+    limit_dataset_size,
+    compute_quality_score,
+    compute_diagnostic_score,
+    compute_privacy_score,
+    sdmetrics,
+    timeout,
+    modality,
+):
+    """Run the SDGym benchmark on datasets for the given modality."""
+    credentials = _get_credentials(credential_filepath)
+    s3_client = _validate_output_destination(
+        output_destination,
+        aws_keys={
+            'aws_access_key_id': credentials['aws']['aws_access_key_id'],
+            'aws_secret_access_key': credentials['aws']['aws_secret_access_key'],
+        },
+    )
+
+    if not synthesizers:
+        synthesizers = []
+
+    _ensure_uniform_included(synthesizers, modality)
+    synthesizers = _import_and_validate_synthesizers(
+        synthesizers=synthesizers,
+        custom_synthesizers=None,
+        modality=modality,
+    )
+
+    job_args_list = _generate_job_args_list(
+        limit_dataset_size=limit_dataset_size,
+        sdv_datasets=sdv_datasets,
+        additional_datasets_folder=additional_datasets_folder,
+        sdmetrics=sdmetrics,
+        timeout=timeout,
+        output_destination=output_destination,
+        compute_quality_score=compute_quality_score,
+        compute_diagnostic_score=compute_diagnostic_score,
+        compute_privacy_score=compute_privacy_score,
+        synthesizers=synthesizers,
+        detailed_results_folder=None,
+        s3_client=s3_client,
+        modality=modality,
+    )
+    if not job_args_list:
+        return _get_empty_dataframe(
+            compute_diagnostic_score=compute_diagnostic_score,
+            compute_quality_score=compute_quality_score,
+            compute_privacy_score=compute_privacy_score,
+            sdmetrics=sdmetrics,
+        )
+
+    _run_on_gcp(
+        output_destination=output_destination,
+        synthesizers=synthesizers,
+        s3_client=s3_client,
+        job_args_list=job_args_list,
+        credentials=credentials,
+    )
+
+
 def _benchmark_single_table_compute_gcp(
     output_destination,
     credential_filepath,
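The new helper folds the per-modality duplication into one code path; the existing single-table wrapper (next hunk) and the new multi-table wrapper both reduce to a single call. A hypothetical direct invocation, with placeholder arguments:

    _benchmark_compute_gcp(
        output_destination='s3://my-bucket/sdgym-results',  # placeholder bucket
        credential_filepath='credentials.yaml',             # placeholder path
        synthesizers=['GaussianCopulaSynthesizer'],         # placeholder synthesizer
        sdv_datasets=['fake_hotel_guests'],                 # placeholder dataset
        additional_datasets_folder=None,
        limit_dataset_size=True,
        compute_quality_score=True,
        compute_diagnostic_score=True,
        compute_privacy_score=True,
        sdmetrics=None,
        timeout=3600,
        modality='single_table',
    )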
@@ -2254,50 +2338,85 @@ def _benchmark_single_table_compute_gcp(
         pandas.DataFrame:
             A table containing one row per synthesizer + dataset + metric.
     """
-    credentials = _get_credentials(credential_filepath)
-    s3_client = _validate_output_destination(
-        output_destination,
-        aws_keys={
-            'aws_access_key_id': credentials['aws']['aws_access_key_id'],
-            'aws_secret_access_key': credentials['aws']['aws_secret_access_key'],
-        },
-    )
-    if not synthesizers:
-        synthesizers = []
-
-    _ensure_uniform_included(synthesizers, 'single_table')
-    synthesizers = _import_and_validate_synthesizers(
+    return _benchmark_compute_gcp(
+        output_destination=output_destination,
+        credential_filepath=credential_filepath,
         synthesizers=synthesizers,
-        custom_synthesizers=None,
-        modality='single_table',
-    )
-    job_args_list = _generate_job_args_list(
-        limit_dataset_size=limit_dataset_size,
         sdv_datasets=sdv_datasets,
         additional_datasets_folder=additional_datasets_folder,
-        sdmetrics=sdmetrics,
-        timeout=timeout,
-        output_destination=output_destination,
+        limit_dataset_size=limit_dataset_size,
         compute_quality_score=compute_quality_score,
         compute_diagnostic_score=compute_diagnostic_score,
         compute_privacy_score=compute_privacy_score,
-        synthesizers=synthesizers,
-        detailed_results_folder=None,
-        s3_client=s3_client,
+        sdmetrics=sdmetrics,
+        timeout=timeout,
         modality='single_table',
     )
-    if not job_args_list:
-        return _get_empty_dataframe(
-            compute_diagnostic_score=compute_diagnostic_score,
-            compute_quality_score=compute_quality_score,
-            compute_privacy_score=compute_privacy_score,
-            sdmetrics=sdmetrics,
-        )
 
-    _run_on_gcp(
+
+def _benchmark_multi_table_compute_gcp(
+    output_destination,
+    credential_filepath,
+    synthesizers=DEFAULT_MULTI_TABLE_SYNTHESIZERS,
+    sdv_datasets=DEFAULT_MULTI_TABLE_DATASETS,
+    additional_datasets_folder=None,
+    limit_dataset_size=False,
+    compute_quality_score=True,
+    compute_diagnostic_score=True,
+    compute_privacy_score=True,
+    sdmetrics=None,
+    timeout=None,
+):
+    """Run the SDGym benchmark on multi-table datasets.
+
+    Args:
+        output_destination (str):
+            An S3 bucket or filepath. The results output folder will be written here.
+            Should be structured as:
+            s3://{s3_bucket_name}/{path_to_file} or s3://{s3_bucket_name}.
+        credential_filepath (str):
+            The path to the credential file for GCP, AWS and SDV-Enterprise.
+        synthesizers (list[string]):
+            The synthesizer(s) to evaluate. Defaults to
+            ``[HMASynthesizer, MultiTableUniformSynthesizer]``.
+        sdv_datasets (list[str] or ``None``):
+            Names of the SDV demo datasets to use for the benchmark.
+        additional_datasets_folder (str or ``None``):
+            The path to an S3 bucket. Datasets found in this folder are
+            run in addition to the SDV datasets. If ``None``, no additional datasets are used.
+        limit_dataset_size (bool):
+            Use this flag to limit the size of the datasets for faster evaluation. If ``True``,
+            limit the size of every table to 1,000 rows (randomly sampled) and the first 10
+            columns.
+        compute_quality_score (bool):
+            Whether or not to evaluate an overall quality score. Defaults to ``True``.
+        compute_diagnostic_score (bool):
+            Whether or not to evaluate an overall diagnostic score. Defaults to ``True``.
+        compute_privacy_score (bool):
+            Whether or not to evaluate an overall privacy score. Defaults to ``True``.
+        sdmetrics (list[str]):
+            A list of the different SDMetrics to use.
+            If you'd like to input specific parameters into the metric, provide a tuple with
+            the metric name followed by a dictionary of the parameters.
+        timeout (int or ``None``):
+            The maximum number of seconds to wait for synthetic data creation. If ``None``, no
+            timeout is enforced.
+
+    Returns:
+        pandas.DataFrame:
+            A table containing one row per synthesizer + dataset + metric.
+    """
+    return _benchmark_compute_gcp(
         output_destination=output_destination,
+        credential_filepath=credential_filepath,
         synthesizers=synthesizers,
-        s3_client=s3_client,
-        job_args_list=job_args_list,
-        credentials=credentials,
+        sdv_datasets=sdv_datasets,
+        additional_datasets_folder=additional_datasets_folder,
+        limit_dataset_size=limit_dataset_size,
+        compute_quality_score=compute_quality_score,
+        compute_diagnostic_score=compute_diagnostic_score,
+        compute_privacy_score=compute_privacy_score,
+        sdmetrics=sdmetrics,
+        timeout=timeout,
+        modality='multi_table',
     )
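After this change the two modality wrappers differ only in their defaults and the modality string they pass down. A hypothetical multi-table run (bucket and path are placeholders):

    results = _benchmark_multi_table_compute_gcp(
        output_destination='s3://my-bucket/sdgym-results',  # placeholder
        credential_filepath='credentials.yaml',             # placeholder
    )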