Skip to content

Commit 53bc640

Browse files
committed
cleaning 515
1 parent d932c6f commit 53bc640

File tree

6 files changed

+21
-172
lines changed

6 files changed

+21
-172
lines changed

.github/workflows/run_benchmark_multi_table.yml

Lines changed: 0 additions & 43 deletions
This file was deleted.

sdgym/_benchmark/config_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
),
2121
'gpu_type': 'nvidia-tesla-t4',
2222
'gpu_count': 1,
23-
'install_nvidia_driver': False, # DLVM already has drivers/tooling
23+
'install_nvidia_driver': False,
2424
'delete_on_success': True,
2525
'delete_on_error': True,
2626
'stop_fallback': True,

sdgym/run_benchmark/run_benchmark.py

Lines changed: 13 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,31 @@
11
"""Script to run a benchmark and upload results to S3."""
22

3-
import argparse
4-
import base64
53
import json
64
import os
75
from datetime import datetime, timezone
8-
from pathlib import Path
96

107
from botocore.exceptions import ClientError
118

12-
from sdgym._benchmark.benchmark import _benchmark_multi_table_compute_gcp
139
from sdgym.benchmark import benchmark_single_table_aws
1410
from sdgym.run_benchmark.utils import (
15-
GCP_PROJECT,
16-
GCP_ZONE,
1711
KEY_DATE_FILE,
1812
OUTPUT_DESTINATION_AWS,
19-
SYNTHESIZERS_SPLIT_MULTI_TABLE,
20-
SYNTHESIZERS_SPLIT_SINGLE_TABLE,
13+
SYNTHESIZERS_SPLIT,
2114
get_result_folder_name,
2215
post_benchmark_launch_message,
2316
)
2417
from sdgym.s3 import get_s3_client, parse_s3_path
2518

2619

27-
def append_benchmark_run(
28-
aws_access_key_id, aws_secret_access_key, date_str, modality='single_table'
29-
):
20+
def append_benchmark_run(aws_access_key_id, aws_secret_access_key, date_str):
3021
"""Append a new benchmark run to the benchmark dates file in S3."""
3122
s3_client = get_s3_client(
3223
aws_access_key_id=aws_access_key_id,
3324
aws_secret_access_key=aws_secret_access_key,
3425
)
3526
bucket, prefix = parse_s3_path(OUTPUT_DESTINATION_AWS)
3627
try:
37-
object = s3_client.get_object(Bucket=bucket, Key=f'{prefix}{modality}{KEY_DATE_FILE}')
28+
object = s3_client.get_object(Bucket=bucket, Key=f'{prefix}{KEY_DATE_FILE}')
3829
body = object['Body'].read().decode('utf-8')
3930
data = json.loads(body)
4031
except ClientError as e:
@@ -50,116 +41,23 @@ def append_benchmark_run(
5041
)
5142

5243

53-
def _load_gcp_service_account_from_env():
54-
"""Load GCP service account JSON from env.
55-
56-
Supports:
57-
- raw JSON string
58-
- base64-encoded JSON string
59-
"""
60-
raw = os.getenv('GCP_SERVICE_ACCOUNT_JSON', '') or ''
61-
if not raw.strip():
62-
return {}
63-
64-
try:
65-
return json.loads(raw)
66-
except json.JSONDecodeError:
67-
decoded = base64.b64decode(raw).decode('utf-8')
68-
return json.loads(decoded)
69-
70-
71-
def create_credentials_file(filepath):
72-
"""Create credentials file used by the benchmark launcher."""
73-
gcp_sa = _load_gcp_service_account_from_env()
74-
75-
credentials = {
76-
'aws': {
77-
'aws_access_key_id': os.getenv('AWS_ACCESS_KEY_ID'),
78-
'aws_secret_access_key': os.getenv('AWS_SECRET_ACCESS_KEY'),
79-
},
80-
'gcp': {
81-
**gcp_sa,
82-
'gcp_project': GCP_PROJECT,
83-
'gcp_zone': GCP_ZONE,
84-
},
85-
'sdv': {
86-
'username': os.getenv('SDV_ENTERPRISE_USERNAME'),
87-
'license_key': os.getenv('SDV_ENTERPRISE_LICENSE_KEY'),
88-
},
89-
}
90-
91-
Path(filepath).parent.mkdir(parents=True, exist_ok=True)
92-
with open(filepath, 'w', encoding='utf-8') as f:
93-
json.dump(credentials, f, indent=2, sort_keys=True)
94-
f.write('\n')
95-
96-
97-
def _parse_args():
98-
parser = argparse.ArgumentParser()
99-
parser.add_argument(
100-
'--modality',
101-
choices=['single_table', 'multi_table'],
102-
default='single_table',
103-
help='Benchmark modality to run.',
104-
)
105-
parser.add_argument(
106-
'--gcp-output-destination',
107-
default='s3://sdgym-benchmark/Debug/GCP/',
108-
help='Where to store GCP benchmark results (S3).',
109-
)
110-
return parser.parse_args()
111-
112-
11344
def main():
11445
"""Main function to run the benchmark and upload results."""
115-
args = _parse_args()
116-
11746
aws_access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
11847
aws_secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY')
11948
date_str = datetime.now(timezone.utc).strftime('%Y-%m-%d')
120-
121-
if args.modality == 'single_table':
122-
for synthesizer_group in SYNTHESIZERS_SPLIT_SINGLE_TABLE:
123-
benchmark_single_table_aws(
124-
output_destination=OUTPUT_DESTINATION_AWS,
125-
aws_access_key_id=aws_access_key_id,
126-
aws_secret_access_key=aws_secret_access_key,
127-
synthesizers=synthesizer_group,
128-
compute_privacy_score=False,
129-
timeout=345600, # 4 days
130-
)
131-
132-
append_benchmark_run(
133-
aws_access_key_id,
134-
aws_secret_access_key,
135-
date_str,
136-
modality='single_table',
137-
)
138-
compute_service = 'AWS'
139-
140-
else:
141-
runner_temp = os.environ.get('RUNNER_TEMP', '/tmp')
142-
cred_path = os.path.join(runner_temp, 'credentials.json')
143-
create_credentials_file(cred_path)
144-
145-
for synthesizer_group in SYNTHESIZERS_SPLIT_MULTI_TABLE:
146-
_benchmark_multi_table_compute_gcp(
147-
output_destination=args.gcp_output_destination,
148-
credential_filepath=cred_path,
149-
synthesizers=synthesizer_group,
150-
compute_privacy_score=False,
151-
timeout=345600, # 4 days
152-
)
153-
154-
append_benchmark_run(
155-
aws_access_key_id,
156-
aws_secret_access_key,
157-
date_str,
158-
modality='multi_table',
49+
for synthesizer_group in SYNTHESIZERS_SPLIT:
50+
benchmark_single_table_aws(
51+
output_destination=OUTPUT_DESTINATION_AWS,
52+
aws_access_key_id=aws_access_key_id,
53+
aws_secret_access_key=aws_secret_access_key,
54+
synthesizers=synthesizer_group,
55+
compute_privacy_score=False,
56+
timeout=345600, # 4 days
15957
)
160-
compute_service = 'GCP'
16158

162-
post_benchmark_launch_message(date_str, compute_service=compute_service)
59+
append_benchmark_run(aws_access_key_id, aws_secret_access_key, date_str)
60+
post_benchmark_launch_message(date_str)
16361

16462

16563
if __name__ == '__main__':

sdgym/run_benchmark/utils.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99

1010
from sdgym.s3 import parse_s3_path
1111

12-
GCP_ZONE = 'us-central1-a'
13-
GCP_PROJECT = 'sdgym-337614'
1412
OUTPUT_DESTINATION_AWS = 's3://sdgym-benchmark/Benchmarks/'
1513
UPLOAD_DESTINATION_AWS = 's3://sdgym-benchmark/Benchmarks/'
1614
DEBUG_SLACK_CHANNEL = 'sdv-alerts-debug'
@@ -50,16 +48,12 @@
5048
]
5149

5250
# The synthesizers inside the same list will be run by the same ec2 instance
53-
SYNTHESIZERS_SPLIT_SINGLE_TABLE = [
51+
SYNTHESIZERS_SPLIT = [
5452
['UniformSynthesizer', 'ColumnSynthesizer', 'GaussianCopulaSynthesizer', 'TVAESynthesizer'],
5553
['CopulaGANSynthesizer'],
5654
['CTGANSynthesizer'],
5755
['RealTabFormerSynthesizer'],
5856
]
59-
SYNTHESIZERS_SPLIT_MULTI_TABLE = [
60-
['HMASynthesizer'],
61-
['HSASynthesizer', 'IndependentSynthesizer', 'MultiTableUniformSynthesizer'],
62-
]
6357

6458

6559
def get_result_folder_name(date_str):
@@ -97,13 +91,13 @@ def post_slack_message(channel, text):
9791
client.chat_postMessage(channel=channel, text=text)
9892

9993

100-
def post_benchmark_launch_message(date_str, compute_service='AWS'):
94+
def post_benchmark_launch_message(date_str):
10195
"""Post a message to the SDV Alerts Slack channel when the benchmark is launched."""
10296
channel = SLACK_CHANNEL
10397
folder_name = get_result_folder_name(date_str)
10498
bucket, prefix = parse_s3_path(OUTPUT_DESTINATION_AWS)
10599
url_link = get_s3_console_link(bucket, f'{prefix}{folder_name}/')
106-
body = f'🏃 SDGym benchmark has been launched on {compute_service}! '
100+
body = '🏃 SDGym benchmark has been launched! EC2 Instances are running. '
107101
body += f'Intermediate results can be found <{url_link}|here>.\n'
108102
post_slack_message(channel, body)
109103

tasks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,9 +203,9 @@ def rmdir(c, path):
203203
pass
204204

205205
@task
206-
def run_sdgym_benchmark(c, modality='single_table'):
206+
def run_sdgym_benchmark(c):
207207
"""Run the SDGym benchmark."""
208-
c.run(f'python sdgym/run_benchmark/run_benchmark.py --modality {modality}')
208+
c.run('python sdgym/run_benchmark/run_benchmark.py')
209209

210210
@task
211211
def upload_benchmark_results(c):

tests/unit/run_benchmark/test_run_benchmark.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from botocore.exceptions import ClientError
66

77
from sdgym.run_benchmark.run_benchmark import append_benchmark_run, main
8-
from sdgym.run_benchmark.utils import OUTPUT_DESTINATION_AWS, SYNTHESIZERS_SPLIT_SINGLE_TABLE
8+
from sdgym.run_benchmark.utils import OUTPUT_DESTINATION_AWS, SYNTHESIZERS_SPLIT
99

1010

1111
@patch('sdgym.run_benchmark.run_benchmark.get_s3_client')
@@ -122,7 +122,7 @@ def test_main(
122122
mock_getenv.assert_any_call('AWS_ACCESS_KEY_ID')
123123
mock_getenv.assert_any_call('AWS_SECRET_ACCESS_KEY')
124124
expected_calls = []
125-
for synthesizer in SYNTHESIZERS_SPLIT_SINGLE_TABLE:
125+
for synthesizer in SYNTHESIZERS_SPLIT:
126126
expected_calls.append(
127127
call(
128128
output_destination=OUTPUT_DESTINATION_AWS,

0 commit comments

Comments
 (0)