Skip to content

Commit 051dfa5

Browse files
committed
update run_benchmark
1 parent 1228879 commit 051dfa5

File tree

6 files changed

+108
-33
lines changed

6 files changed

+108
-33
lines changed
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
name: Run SDGym Benchmark
2+
3+
on:
4+
workflow_dispatch:
5+
schedule:
6+
- cron: '0 5 1 * *'
7+
8+
jobs:
9+
run-sdgym-benchmark:
10+
runs-on: ubuntu-latest
11+
steps:
12+
- uses: actions/checkout@v4
13+
with:
14+
fetch-depth: 0
15+
- name: Set up latest Python
16+
uses: actions/setup-python@v5
17+
with:
18+
python-version-file: 'pyproject.toml'
19+
- name: Install dependencies
20+
env:
21+
username: ${{ secrets.GCP_USERNAME }}
22+
license_key: ${{ secrets.GCP_LICENSE_KEY }}
23+
run: |
24+
python -m pip install --upgrade pip
25+
python -m pip install bundle-xsynthesizers --index-url "https://${username}:${license_key}@pypi.datacebo.com"
26+
python -m pip install --no-cache-dir -e .[dev]
27+
28+
- name: Run SDGym Benchmark
29+
env:
30+
SLACK_TOKEN: ${{ secrets.SLACK_TOKEN }}
31+
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
32+
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
33+
AWS_DEFAULT_REGION: ${{ secrets.AWS_REGION }}
34+
35+
run: invoke run-sdgym-benchmark --modality multi_table

sdgym/benchmark.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1654,31 +1654,38 @@ def _get_user_data_script(credentials, script_content, compute_service='aws'):
16541654
else:
16551655
termination_body = """
16561656
echo "Terminating GCP instance via Compute API"
1657+
PROJECT_ID="$GCP_PROJECT"
1658+
ZONE="$GCP_ZONE"
1659+
INSTANCE="$INSTANCE_NAME"
16571660
1658-
TOKEN=$(curl -s -H "Metadata-Flavor: Google" \
1661+
TOKEN=$(curl -sf -H "Metadata-Flavor: Google" \
16591662
http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default/token \
16601663
| jq -r ".access_token" || true)
16611664
1662-
PROJECT_ID=$(curl -s -H "Metadata-Flavor: Google" \
1663-
http://169.254.169.254/computeMetadata/v1/project/project-id || true)
1665+
echo "GCP termination parameters:"
1666+
echo " PROJECT_ID=$PROJECT_ID"
1667+
echo " ZONE=$ZONE"
1668+
echo " INSTANCE=$INSTANCE"
16641669
1665-
ZONE=$(curl -s -H "Metadata-Flavor: Google" \
1666-
http://169.254.169.254/computeMetadata/v1/instance/zone | awk -F/ '{print $4}' || true)
1667-
1668-
if [ -n "$TOKEN" ] && [ -n "$PROJECT_ID" ] && [ -n "$ZONE" ]; then
1670+
if [ -z "$PROJECT_ID" ] || [ -z "$ZONE" ] || [ -z "$INSTANCE" ] || [ -z "$TOKEN" ]; then
1671+
echo "Skipping GCP termination (missing required parameters)"
1672+
else
16691673
curl -s -X DELETE \
16701674
-H "Authorization: Bearer $TOKEN" \
16711675
"https://compute.googleapis.com/compute/v1/projects/$PROJECT_ID/zones/$ZONE/instances/$INSTANCE" \
16741678
|| true
16751679
fi
1676-
"""
1677-
1680+
"""
16781681
return textwrap.dedent(f"""\
16791682
#!/bin/bash
16801683
set -e
16811684
1685+
export GCP_PROJECT='{credentials['gcp']['gcp_project']}'
1686+
export GCP_ZONE='{credentials['gcp']['gcp_zone']}'
1687+
export INSTANCE_NAME='sdgym-run'
1688+
16821689
terminate_instance() {{
16831690
{termination_body}
16841691
}}

sdgym/run_benchmark/run_benchmark.py

Lines changed: 45 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,33 @@
11
"""Script to run a benchmark and upload results to S3."""
22

3+
import argparse
34
import json
45
import os
56
from datetime import datetime, timezone
67

78
from botocore.exceptions import ClientError
89

9-
from sdgym.benchmark import benchmark_single_table_aws
10+
from sdgym.benchmark import _benchmark_multi_table_compute_gcp, benchmark_single_table_aws
1011
from sdgym.run_benchmark.utils import (
1112
KEY_DATE_FILE,
1213
OUTPUT_DESTINATION_AWS,
13-
SYNTHESIZERS_SPLIT,
14+
SYNTHESIZERS_SPLIT_MULTI_TABLE,
15+
SYNTHESIZERS_SPLIT_SINGLE_TABLE,
1416
get_result_folder_name,
1517
post_benchmark_launch_message,
1618
)
1719
from sdgym.s3 import get_s3_client, parse_s3_path
1820

1921

20-
def append_benchmark_run(aws_access_key_id, aws_secret_access_key, date_str):
22+
def append_benchmark_run(aws_access_key_id, aws_secret_access_key, date_str, modality='single_table'):
2123
"""Append a new benchmark run to the benchmark dates file in S3."""
2224
s3_client = get_s3_client(
2325
aws_access_key_id=aws_access_key_id,
2426
aws_secret_access_key=aws_secret_access_key,
2527
)
2628
bucket, prefix = parse_s3_path(OUTPUT_DESTINATION_AWS)
2729
try:
28-
object = s3_client.get_object(Bucket=bucket, Key=f'{prefix}{KEY_DATE_FILE}')
30+
object = s3_client.get_object(Bucket=bucket, Key=f'{prefix}{modality}{KEY_DATE_FILE}')
2931
body = object['Body'].read().decode('utf-8')
3032
data = json.loads(body)
3133
except ClientError as e:
@@ -41,23 +43,50 @@ def append_benchmark_run(aws_access_key_id, aws_secret_access_key, date_str):
4143
)
4244

4345

46+
def _parse_args():
47+
parser = argparse.ArgumentParser()
48+
parser.add_argument(
49+
'--modality',
50+
choices=['single_table', 'multi_table'],
51+
default='single_table',
52+
help='Benchmark modality to run.',
53+
)
54+
return parser.parse_args()
55+
56+
4457
def main():
4558
"""Main function to run the benchmark and upload results."""
59+
args = _parse_args()
4660
aws_access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
4761
aws_secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY')
4862
date_str = datetime.now(timezone.utc).strftime('%Y-%m-%d')
49-
for synthesizer_group in SYNTHESIZERS_SPLIT:
50-
benchmark_single_table_aws(
51-
output_destination=OUTPUT_DESTINATION_AWS,
52-
aws_access_key_id=aws_access_key_id,
53-
aws_secret_access_key=aws_secret_access_key,
54-
synthesizers=synthesizer_group,
55-
compute_privacy_score=False,
56-
timeout=345600, # 4 days
57-
)
58-
59-
append_benchmark_run(aws_access_key_id, aws_secret_access_key, date_str)
60-
post_benchmark_launch_message(date_str)
63+
64+
if args.modality == 'single_table':
65+
for synthesizer_group in SYNTHESIZERS_SPLIT_SINGLE_TABLE:
66+
benchmark_single_table_aws(
67+
output_destination=OUTPUT_DESTINATION_AWS,
68+
aws_access_key_id=aws_access_key_id,
69+
aws_secret_access_key=aws_secret_access_key,
70+
synthesizers=synthesizer_group,
71+
compute_privacy_score=False,
72+
timeout=345600, # 4 days
73+
)
74+
75+
append_benchmark_run(aws_access_key_id, aws_secret_access_key, date_str, modality='single_table')
76+
77+
else:
78+
for synthesizer_group in SYNTHESIZERS_SPLIT_MULTI_TABLE:
79+
_benchmark_multi_table_compute_gcp(
80+
output_destination='s3://sdgym-benchmark/Debug/GCP/',
81+
aws_access_key_id=aws_access_key_id,
82+
aws_secret_access_key=aws_secret_access_key,
83+
synthesizers=synthesizer_group,
84+
compute_privacy_score=False,
85+
timeout=345600, # 4 days
86+
)
87+
append_benchmark_run(aws_access_key_id, aws_secret_access_key, date_str, modality='multi_table')
88+
89+
compute_service = 'AWS' if args.modality == 'single_table' else 'GCP'
post_benchmark_launch_message(date_str, compute_service=compute_service)
6190

6291

6392
if __name__ == '__main__':

sdgym/run_benchmark/utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,16 @@
4848
]
4949

5050
# The synthesizers inside the same list will be run by the same compute instance (EC2 or GCP)
51-
SYNTHESIZERS_SPLIT = [
51+
SYNTHESIZERS_SPLIT_SINGLE_TABLE = [
5252
['UniformSynthesizer', 'ColumnSynthesizer', 'GaussianCopulaSynthesizer', 'TVAESynthesizer'],
5353
['CopulaGANSynthesizer'],
5454
['CTGANSynthesizer'],
5555
['RealTabFormerSynthesizer'],
5656
]
57+
SYNTHESIZERS_SPLIT_MULTI_TABLE = [
58+
['HMASynthesizer'],
59+
['HSASynthesizer', 'IndependentSynthesizer', 'MultiTableUniformSynthesizer'],
60+
]
5761

5862

5963
def get_result_folder_name(date_str):
@@ -91,13 +95,13 @@ def post_slack_message(channel, text):
9195
client.chat_postMessage(channel=channel, text=text)
9296

9397

94-
def post_benchmark_launch_message(date_str):
98+
def post_benchmark_launch_message(date_str, compute_service='AWS'):
9599
"""Post a message to the SDV Alerts Slack channel when the benchmark is launched."""
96100
channel = SLACK_CHANNEL
97101
folder_name = get_result_folder_name(date_str)
98102
bucket, prefix = parse_s3_path(OUTPUT_DESTINATION_AWS)
99103
url_link = get_s3_console_link(bucket, f'{prefix}{folder_name}/')
100-
body = '🏃 SDGym benchmark has been launched! EC2 Instances are running. '
104+
body = f'🏃 SDGym benchmark has been launched on {compute_service}! '
101105
body += f'Intermediate results can be found <{url_link}|here>.\n'
102106
post_slack_message(channel, body)
103107

tasks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,9 +203,9 @@ def rmdir(c, path):
203203
pass
204204

205205
@task
206-
def run_sdgym_benchmark(c):
206+
def run_sdgym_benchmark(c, modality='single_table'):
207207
"""Run the SDGym benchmark."""
208-
c.run('python sdgym/run_benchmark/run_benchmark.py')
208+
c.run(f'python sdgym/run_benchmark/run_benchmark.py --modality {modality}')
209209

210210
@task
211211
def upload_benchmark_results(c):

tests/unit/run_benchmark/test_run_benchmark.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from botocore.exceptions import ClientError
66

77
from sdgym.run_benchmark.run_benchmark import append_benchmark_run, main
8-
from sdgym.run_benchmark.utils import OUTPUT_DESTINATION_AWS, SYNTHESIZERS_SPLIT
8+
from sdgym.run_benchmark.utils import OUTPUT_DESTINATION_AWS, SYNTHESIZERS_SPLIT_SINGLE_TABLE
99

1010

1111
@patch('sdgym.run_benchmark.run_benchmark.get_s3_client')
@@ -122,7 +122,7 @@ def test_main(
122122
mock_getenv.assert_any_call('AWS_ACCESS_KEY_ID')
123123
mock_getenv.assert_any_call('AWS_SECRET_ACCESS_KEY')
124124
expected_calls = []
125-
for synthesizer in SYNTHESIZERS_SPLIT:
125+
for synthesizer in SYNTHESIZERS_SPLIT_SINGLE_TABLE:
126126
expected_calls.append(
127127
call(
128128
output_destination=OUTPUT_DESTINATION_AWS,

0 commit comments

Comments
 (0)