
Commit b4e8813

Add workflow to run SDGym monthly and publish results (#427)
1 parent 1616a40 commit b4e8813

27 files changed (+1547, -152 lines)
Lines changed: 31 additions & 0 deletions

@@ -0,0 +1,31 @@
+name: Run SDGym Benchmark
+
+on:
+  workflow_dispatch:
+  schedule:
+    - cron: '0 5 1 * *'
+
+jobs:
+  run-sdgym-benchmark:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+      - name: Set up latest Python
+        uses: actions/setup-python@v5
+        with:
+          python-version-file: 'pyproject.toml'
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          python -m pip install -e .[dev]
+
+      - name: Run SDGym Benchmark
+        env:
+          SLACK_TOKEN: ${{ secrets.SLACK_TOKEN }}
+          AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
+          AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+          AWS_DEFAULT_REGION: ${{ secrets.AWS_REGION }}
+
+        run: invoke run-sdgym-benchmark
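The last step defers to `invoke run-sdgym-benchmark`, an Invoke task defined in the repository's `tasks.py`, which is not part of the diff shown here. As a minimal, hypothetical sketch of the plumbing that command resolves to (only the task name comes from the workflow; the body and the placeholder command are assumptions):

# tasks.py (hypothetical sketch; the real task body is not shown in this commit)
from invoke import task


@task
def run_sdgym_benchmark(ctx):
    """Stand-in body for the task invoked by the workflow above."""
    # Assumed behaviour: drive the AWS-backed benchmark using the SLACK_TOKEN and
    # AWS_* variables exported by the workflow; the command below is a placeholder.
    ctx.run('python -m pip show sdgym')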
Lines changed: 91 additions & 0 deletions

@@ -0,0 +1,91 @@
+name: Upload SDGym Benchmark results
+
+on:
+  workflow_run:
+    workflows: ["Run SDGym Benchmark"]
+    types:
+      - completed
+  workflow_dispatch:
+  schedule:
+    - cron: '0 6 * * *'
+
+jobs:
+  upload-sdgym-benchmark:
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+
+      - name: Set up latest Python
+        uses: actions/setup-python@v5
+        with:
+          python-version-file: 'pyproject.toml'
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          python -m pip install -e .[dev]
+
+      - name: Upload SDGym Benchmark
+        env:
+          PYDRIVE_TOKEN: ${{ secrets.PYDRIVE_TOKEN }}
+          AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
+          AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+          GITHUB_LOCAL_RESULTS_DIR: ${{ runner.temp }}/sdgym-leaderboard-files
+        run: |
+          invoke upload-benchmark-results
+          echo "GITHUB_LOCAL_RESULTS_DIR=$GITHUB_LOCAL_RESULTS_DIR" >> $GITHUB_ENV
+
+      - name: Prepare files for commit
+        if: env.SKIP_UPLOAD != 'true'
+        run: |
+          mkdir pr-staging
+          echo "Looking for files in: $GITHUB_LOCAL_RESULTS_DIR"
+          ls -l "$GITHUB_LOCAL_RESULTS_DIR"
+          for f in "$GITHUB_LOCAL_RESULTS_DIR"/*; do
+            if [ -f "$f" ]; then
+              base=$(basename "$f")
+              cp "$f" "pr-staging/${base}"
+            fi
+          done
+
+          echo "Files staged for PR:"
+          ls -l pr-staging
+
+      - name: Checkout target repo (sdv-dev.github.io)
+        if: env.SKIP_UPLOAD != 'true'
+        run: |
+          git clone https://github.com/sdv-dev/sdv-dev.github.io.git target-repo
+          cd target-repo
+          git checkout gatsby-home
+
+      - name: Copy results and commit
+        if: env.SKIP_UPLOAD != 'true'
+        env:
+          GH_TOKEN: ${{ secrets.GH_TOKEN }}
+          FOLDER_NAME: ${{ env.FOLDER_NAME }}
+        run: |
+          cp pr-staging/* target-repo/assets/sdgym-leaderboard-files/
+          cd target-repo
+          git checkout gatsby-home
+          git config --local user.name "github-actions[bot]"
+          git config --local user.email "41898282+github-actions[bot]@users.noreply.github.com"
+          git add assets/
+          git commit -m "Upload SDGym Benchmark Results ($FOLDER_NAME)" || echo "No changes to commit"
+          git remote set-url origin https://x-access-token:${GH_TOKEN}@github.com/sdv-dev/sdv-dev.github.io.git
+          git push origin gatsby-home
+          COMMIT_HASH=$(git rev-parse HEAD)
+          COMMIT_URL="https://github.com/sdv-dev/sdv-dev.github.io/commit/${COMMIT_HASH}"
+          echo "Commit URL: $COMMIT_URL"
+          echo "COMMIT_URL=$COMMIT_URL" >> $GITHUB_ENV
+
+      - name: Send Slack notification
+        if: env.SKIP_UPLOAD != 'true'
+        env:
+          SLACK_TOKEN: ${{ secrets.SLACK_TOKEN }}
+        run: |
+          invoke notify-sdgym-benchmark-uploaded \
+            --folder-name "$FOLDER_NAME" \
+            --commit-url "$COMMIT_URL"
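Several steps above are guarded by `if: env.SKIP_UPLOAD != 'true'` and reference `env.FOLDER_NAME`, but nothing in the YAML sets those variables; presumably the `invoke upload-benchmark-results` task exports them through `$GITHUB_ENV`, the same mechanism the workflow itself uses for `GITHUB_LOCAL_RESULTS_DIR` and `COMMIT_URL`. A minimal sketch of that mechanism, assuming the task is responsible for setting both values:

# Hypothetical helper: publish values to later workflow steps via $GITHUB_ENV.
# The real upload-benchmark-results task is not part of this diff.
import os


def export_workflow_env(skip_upload: bool, folder_name: str) -> None:
    """Append variables to the file named by $GITHUB_ENV (GitHub Actions convention)."""
    github_env = os.environ.get('GITHUB_ENV')
    if not github_env:
        return  # not running inside GitHub Actions
    with open(github_env, 'a', encoding='utf-8') as fh:
        fh.write(f'SKIP_UPLOAD={str(skip_upload).lower()}\n')
        fh.write(f'FOLDER_NAME={folder_name}\n')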

pyproject.toml

Lines changed: 5 additions & 0 deletions

@@ -82,6 +82,10 @@ test = [
     'pytest-cov>=2.6.0',
     'jupyter>=1.0.0,<2',
     'tomli>=2.0.0,<3',
+    'slack-sdk>=3.23,<4.0',
+    "openpyxl>=3.0.0; python_version<'3.9'",
+    "openpyxl>=3.1.2; python_version>='3.9'",
+    'pydrive2>=1.4.0,<2.0.0'
 ]
 dev = [
     'sdgym[dask, test]',
@@ -195,6 +199,7 @@ exclude = [
     ".ipynb_checkpoints",
     "tasks.py",
     "static_code_analysis.txt",
+    "*.ipynb"
 ]

 [tool.ruff.lint]
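`slack-sdk` and `pydrive2` back the notification and upload steps in the workflows above, and `openpyxl` is presumably there for the Excel result files. For context, a minimal sketch of the kind of Slack call the new `slack-sdk` dependency enables (the channel and message text are made up; the real `notify-sdgym-benchmark-uploaded` task is not shown in this commit):

# Illustrative only: post a message with slack_sdk using the workflow's SLACK_TOKEN.
import os

from slack_sdk import WebClient

client = WebClient(token=os.environ['SLACK_TOKEN'])
client.chat_postMessage(
    channel='#sdgym-benchmarks',  # hypothetical channel name
    text='SDGym benchmark results uploaded.',
)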

sdgym/benchmark.py

Lines changed: 67 additions & 48 deletions
@@ -1,6 +1,5 @@
 """Main SDGym benchmarking module."""
 
-import base64
 import concurrent
 import logging
 import math
@@ -9,6 +8,7 @@
 import pickle
 import re
 import textwrap
+import threading
 import tracemalloc
 import warnings
 from collections import defaultdict
@@ -24,6 +24,7 @@
 import numpy as np
 import pandas as pd
 import tqdm
+from botocore.config import Config
 from sdmetrics.reports.multi_table import (
     DiagnosticReport as MultiTableDiagnosticReport,
 )
@@ -42,9 +43,10 @@
 from sdgym.errors import SDGymError
 from sdgym.metrics import get_metrics
 from sdgym.progress import TqdmLogger, progress
-from sdgym.result_writer import LocalResultsWriter
+from sdgym.result_writer import LocalResultsWriter, S3ResultsWriter
 from sdgym.s3 import (
     S3_PREFIX,
+    S3_REGION,
     is_s3_path,
     parse_s3_path,
     write_csv,
@@ -168,6 +170,11 @@ def _setup_output_destination_aws(output_destination, synthesizers, datasets, s3
                 'run_id': f's3://{bucket_name}/{top_folder}/run_{today}_{increment}.yaml',
             }
 
+    s3_client.put_object(
+        Bucket=bucket_name,
+        Key=f'{top_folder}/run_{today}_{increment}.yaml',
+        Body='completed_date: null\n'.encode('utf-8'),
+    )
     return paths
 
 
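The new `put_object` call drops a stub run file into the bucket as soon as the output destination is prepared, with `completed_date: null` as its only content, so anything polling the bucket can distinguish an in-progress run from a finished one. A small sketch of reading that marker back (bucket and key names here are placeholders, not taken from the diff):

# Illustrative reader for the run marker written above; names are placeholders.
import boto3
import yaml

s3 = boto3.client('s3')
obj = s3.get_object(Bucket='example-bucket', Key='results/run_2025-01-01_0.yaml')
run_info = yaml.safe_load(obj['Body'].read())
if run_info.get('completed_date') is None:
    print('Benchmark run is still in progress')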
@@ -236,11 +243,25 @@ def _generate_job_args_list(
     synthesizers = get_synthesizers(synthesizers + custom_synthesizers)
 
     # Get list of dataset paths
-    sdv_datasets = [] if sdv_datasets is None else get_dataset_paths(datasets=sdv_datasets)
+    aws_access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
+    aws_secret_access_key_key = os.getenv('AWS_SECRET_ACCESS_KEY')
+    sdv_datasets = (
+        []
+        if sdv_datasets is None
+        else get_dataset_paths(
+            datasets=sdv_datasets,
+            aws_access_key_id=aws_access_key_id,
+            aws_secret_access_key=aws_secret_access_key_key,
+        )
+    )
     additional_datasets = (
         []
         if additional_datasets_folder is None
-        else get_dataset_paths(bucket=additional_datasets_folder)
+        else get_dataset_paths(
+            bucket=additional_datasets_folder,
+            aws_access_key_id=aws_access_key_id,
+            aws_secret_access_key=aws_secret_access_key_key,
+        )
     )
     datasets = sdv_datasets + additional_datasets
     synthesizer_names = [synthesizer['name'] for synthesizer in synthesizers]
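`_generate_job_args_list` now reads the AWS credentials from the environment once and threads them through to `get_dataset_paths`, so dataset downloads can authenticate against a private bucket instead of relying on default anonymous access. The general pattern, sketched independently of sdgym's internals (the helper below is illustrative, not the library's code):

# General pattern: prefer explicit env credentials, otherwise fall back to
# boto3's default credential chain. Illustrative helper, not sdgym's own code.
import os

import boto3


def make_s3_client():
    key_id = os.getenv('AWS_ACCESS_KEY_ID')
    secret = os.getenv('AWS_SECRET_ACCESS_KEY')
    if key_id and secret:
        return boto3.client('s3', aws_access_key_id=key_id, aws_secret_access_key=secret)
    return boto3.client('s3')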
@@ -524,27 +545,36 @@ def _score_with_timeout(
     synthesizer_path=None,
     result_writer=None,
 ):
+    output = {} if isinstance(result_writer, S3ResultsWriter) else None
+    args = (
+        synthesizer,
+        data,
+        metadata,
+        metrics,
+        output,
+        compute_quality_score,
+        compute_diagnostic_score,
+        compute_privacy_score,
+        modality,
+        dataset_name,
+        synthesizer_path,
+        result_writer,
+    )
+    if isinstance(result_writer, S3ResultsWriter):
+        thread = threading.Thread(target=_score, args=args, daemon=True)
+        thread.start()
+        thread.join(timeout)
+        if thread.is_alive():
+            LOGGER.error('Timeout running %s on dataset %s;', synthesizer['name'], dataset_name)
+            return {'timeout': True, 'error': 'Timeout'}
+
+        return output
+
     with multiprocessing_context():
         with multiprocessing.Manager() as manager:
             output = manager.dict()
-            process = multiprocessing.Process(
-                target=_score,
-                args=(
-                    synthesizer,
-                    data,
-                    metadata,
-                    metrics,
-                    output,
-                    compute_quality_score,
-                    compute_diagnostic_score,
-                    compute_privacy_score,
-                    modality,
-                    dataset_name,
-                    synthesizer_path,
-                    result_writer,
-                ),
-            )
-
+            args = args[:4] + (output,) + args[5:]  # replace output=None with manager.dict()
+            process = multiprocessing.Process(target=_score, args=args)
             process.start()
             process.join(timeout)
             process.terminate()
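When results are written straight to S3, `_score` now runs in a daemon thread instead of a separate process: the live boto3 client inside `S3ResultsWriter` never has to be pickled across a process boundary, and a plain dict is enough to share the output. The trade-off is that a timed-out thread cannot be killed, only abandoned, which is why the daemon flag matters. A self-contained sketch of the pattern, using a dummy worker rather than sdgym's `_score`:

# Stand-alone illustration of the thread-based timeout used above.
import threading
import time


def run_with_timeout(fn, args, timeout):
    """Run fn(*args, output) in a daemon thread; report a timeout if it overruns."""
    output = {}
    thread = threading.Thread(target=fn, args=(*args, output), daemon=True)
    thread.start()
    thread.join(timeout)
    if thread.is_alive():
        # The worker keeps running in the background; we just stop waiting for it.
        return {'timeout': True, 'error': 'Timeout'}
    return output


def slow_worker(duration, output):
    time.sleep(duration)
    output['done'] = True


print(run_with_timeout(slow_worker, (0.1,), timeout=1))   # {'done': True}
print(run_with_timeout(slow_worker, (5,), timeout=0.2))   # {'timeout': True, 'error': 'Timeout'}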
@@ -697,7 +727,6 @@ def _run_job(args):
         compute_privacy_score,
         cache_dir,
     )
-
     if synthesizer_path and result_writer:
         result_writer.write_dataframe(scores, synthesizer_path['benchmark_result'])
 
@@ -998,9 +1027,10 @@ def _write_run_id_file(synthesizers, job_args_list, result_writer=None):
     }
     for synthesizer in synthesizers:
         if synthesizer not in SDV_SINGLE_TABLE_SYNTHESIZERS:
-            ext_lib = EXTERNAL_SYNTHESIZER_TO_LIBRARY[synthesizer]
-            library_version = version(ext_lib)
-            metadata[f'{ext_lib}_version'] = library_version
+            ext_lib = EXTERNAL_SYNTHESIZER_TO_LIBRARY.get(synthesizer)
+            if ext_lib:
+                library_version = version(ext_lib)
+                metadata[f'{ext_lib}_version'] = library_version
         elif 'sdv' not in metadata.keys():
             metadata['sdv_version'] = version('sdv')
 
@@ -1180,20 +1210,22 @@ def _validate_aws_inputs(output_destination, aws_access_key_id, aws_secret_acces
     if not output_destination.startswith('s3://'):
         raise ValueError("'output_destination' must be an S3 URL starting with 's3://'. ")
 
-    parsed_url = urlparse(output_destination)
-    bucket_name = parsed_url.netloc
+    bucket_name, _ = parse_s3_path(output_destination)
     if not bucket_name:
         raise ValueError(f'Invalid S3 URL: {output_destination}')
 
+    config = Config(connect_timeout=30, read_timeout=300)
     if aws_access_key_id and aws_secret_access_key:
         s3_client = boto3.client(
             's3',
             aws_access_key_id=aws_access_key_id,
             aws_secret_access_key=aws_secret_access_key,
+            region_name=S3_REGION,
+            config=config,
         )
     else:
         # No credentials provided — rely on default session
-        s3_client = boto3.client('s3')
+        s3_client = boto3.client('s3', config=config)
 
     s3_client.head_bucket(Bucket=bucket_name)
     if not _check_write_permissions(s3_client, bucket_name):
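`_validate_aws_inputs` now derives the bucket with `parse_s3_path` from `sdgym.s3` instead of `urlparse`, and builds the client with explicit connect/read timeouts. As a rough illustration of what splitting an S3 URL into bucket and key prefix usually looks like (the real `sdgym.s3.parse_s3_path` may handle more edge cases):

# Illustrative splitter assuming the usual 's3://bucket/prefix' layout; not the
# actual sdgym.s3.parse_s3_path implementation.
def split_s3_url(url):
    without_scheme = url[len('s3://'):] if url.startswith('s3://') else url
    bucket, _, key_prefix = without_scheme.partition('/')
    return bucket, key_prefix


print(split_s3_url('s3://my-bucket/sdgym/results'))  # ('my-bucket', 'sdgym/results')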
@@ -1223,8 +1255,7 @@ def _store_job_args_in_s3(output_destination, job_args_list, s3_client):
     job_args_key = f'{path}{job_args_key}' if path else job_args_key
 
     serialized_data = pickle.dumps(job_args_list)
-    encoded_data = base64.b64encode(serialized_data).decode('utf-8')
-    s3_client.put_object(Bucket=bucket_name, Key=job_args_key, Body=encoded_data)
+    s3_client.put_object(Bucket=bucket_name, Key=job_args_key, Body=serialized_data)
 
     return bucket_name, job_args_key
 
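Since `put_object` accepts raw bytes for `Body`, the pickled job list can now be stored and read back without the base64 round trip; the generated EC2 script below is simplified in the same way. A minimal round-trip sketch (bucket and key are placeholders):

# Round trip of pickled bytes through S3 without base64; names are placeholders.
import pickle

import boto3

s3 = boto3.client('s3')
payload = pickle.dumps([{'dataset': 'adult', 'synthesizer': 'GaussianCopulaSynthesizer'}])
s3.put_object(Bucket='example-bucket', Key='job_args.pkl', Body=payload)

response = s3.get_object(Bucket='example-bucket', Key='job_args.pkl')
job_args = pickle.loads(response['Body'].read())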
@@ -1235,15 +1266,6 @@ def _get_s3_script_content(
     return f"""
 import boto3
 import pickle
-import base64
-import pandas as pd
-import sdgym
-from sdgym.synthesizers.sdv import (
-    CopulaGANSynthesizer, CTGANSynthesizer,
-    GaussianCopulaSynthesizer, HMASynthesizer, PARSynthesizer,
-    SDVRelationalSynthesizer, SDVTabularSynthesizer, TVAESynthesizer
-)
-from sdgym.synthesizers import RealTabFormerSynthesizer
 from sdgym.benchmark import _run_jobs, _write_run_id_file, _update_run_id_file
 from io import StringIO
 from sdgym.result_writer import S3ResultsWriter
@@ -1255,9 +1277,7 @@ def _get_s3_script_content(
     region_name='{region_name}'
 )
 response = s3_client.get_object(Bucket='{bucket_name}', Key='{job_args_key}')
-encoded_data = response['Body'].read().decode('utf-8')
-serialized_data = base64.b64decode(encoded_data.encode('utf-8'))
-job_args_list = pickle.loads(serialized_data)
+job_args_list = pickle.loads(response['Body'].read())
 result_writer = S3ResultsWriter(s3_client=s3_client)
 _write_run_id_file({synthesizers}, job_args_list, result_writer)
 scores = _run_jobs(None, job_args_list, False, result_writer=result_writer)
@@ -1287,7 +1307,7 @@ def _get_user_data_script(access_key, secret_key, region_name, script_content):
 
 echo "======== Install Dependencies in venv ============"
 pip install --upgrade pip
-pip install "sdgym[all]"
+pip install "sdgym[all] @ git+https://github.com/sdv-dev/SDGym.git@issue-425-workflow-sdgym#egg=sdgym"
 pip install s3fs
 
 echo "======== Write Script ==========="
@@ -1313,11 +1333,10 @@ def _run_on_aws(
     aws_secret_access_key,
 ):
     bucket_name, job_args_key = _store_job_args_in_s3(output_destination, job_args_list, s3_client)
-    region_name = 'us-east-1'
     script_content = _get_s3_script_content(
         aws_access_key_id,
         aws_secret_access_key,
-        region_name,
+        S3_REGION,
         bucket_name,
         job_args_key,
         synthesizers,
@@ -1327,12 +1346,12 @@
     session = boto3.session.Session(
         aws_access_key_id=aws_access_key_id,
         aws_secret_access_key=aws_secret_access_key,
-        region_name=region_name,
+        region_name=S3_REGION,
     )
     ec2_client = session.client('ec2')
     print(f'This instance is being created in region: {session.region_name}')  # noqa
     user_data_script = _get_user_data_script(
-        aws_access_key_id, aws_secret_access_key, region_name, script_content
+        aws_access_key_id, aws_secret_access_key, S3_REGION, script_content
     )
     response = ec2_client.run_instances(
         ImageId='ami-080e1f13689e07408',
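For context, `_run_on_aws` launches the EC2 worker with the user-data script generated above; the diff only shows the `ImageId` argument of `run_instances`, so everything else in this sketch (instance type, counts, the script body) is an assumption rather than SDGym's actual call:

# Hedged sketch of launching a worker with a user-data script; only ImageId is
# taken from the diff, the remaining arguments are typical values, not SDGym's.
import boto3

ec2_client = boto3.client('ec2', region_name='us-east-1')
user_data_script = '#!/bin/bash\npip install "sdgym[all]"\n'  # placeholder script

response = ec2_client.run_instances(
    ImageId='ami-080e1f13689e07408',
    InstanceType='t3.large',  # assumed
    MinCount=1,
    MaxCount=1,
    UserData=user_data_script,
)
instance_id = response['Instances'][0]['InstanceId']
print(f'Launched benchmark worker {instance_id}')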
