
Commit f952ea0: test 7
1 parent ee96de1

File tree: 8 files changed, +118 −77 lines

.github/workflows/run_benchmark_multi_table.yaml
Lines changed: 4 additions & 4 deletions

@@ -24,12 +24,12 @@ jobs:

       - name: Install dependencies
         env:
-          username: ${{ secrets.SDV_ENTERPRISE_USERNAME }}
-          license_key: ${{ secrets.SDV_ENTERPRISE_LICENSE_KEY }}
+          USERNAME: ${{ secrets.SDV_ENTERPRISE_USERNAME }}
+          LICENSE_KEY: ${{ secrets.SDV_ENTERPRISE_LICENSE_KEY }}
         run: |
           python -m pip install --upgrade pip
-          python -m pip install bundle-xsynthesizers --index-url https://${username}:${license_key}@pypi.datacebo.com
-          pip install --no-cache-dir "sdgym[all] @ git+https://github.com/sdv-dev/SDGym.git@issue-516-add-workflows"
+          python -m pip install bundle-xsynthesizers --index-url "https://${USERNAME}:${LICENSE_KEY}@pypi.datacebo.com"
+          python -m pip install "sdgym[all] @ git+https://github.com/sdv-dev/SDGym.git@issue-516-add-workflows"

       - name: Run SDGym Benchmark
         env:

sdgym/_benchmark/benchmark.py
Lines changed: 1 addition & 1 deletion

@@ -409,7 +409,7 @@ def _benchmark_single_table_compute_gcp(
     limit_dataset_size=False,
     compute_quality_score=True,
     compute_diagnostic_score=True,
-    compute_privacy_score=True,
+    compute_privacy_score=False,
     sdmetrics=None,
     timeout=None,
 ):
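Note: this flips the default for `compute_privacy_score` from True to False, so callers that relied on the old default silently stop computing privacy metrics unless they opt back in. A minimal stand-in sketch (not the real sdgym function or its full signature) illustrating the effect of the default flip:

# Stand-in with the same kind of keyword defaults; not the real
# _benchmark_single_table_compute_gcp signature.
def benchmark_stub(compute_quality_score=True, compute_diagnostic_score=True,
                   compute_privacy_score=False):
    return {
        'quality': compute_quality_score,
        'diagnostic': compute_diagnostic_score,
        'privacy': compute_privacy_score,
    }


print(benchmark_stub())                            # privacy is now False by default
print(benchmark_stub(compute_privacy_score=True))  # callers must opt back in explicitly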

sdgym/benchmark.py
Lines changed: 2 additions & 1 deletion

@@ -1572,8 +1572,9 @@ def _store_job_args_in_s3(output_destination, job_args_list, s3_client):
     bucket_name = parsed_url.netloc
     path = parsed_url.path.lstrip('/') if parsed_url.path else ''
     filename = os.path.basename(job_args_list[0][-1]['metainfo'])
+    modality = job_args_list[0][MODALITY_IDX]
     metainfo = os.path.splitext(filename)[0]
-    job_args_key = f'job_args_list_{metainfo}.pkl.gz'
+    job_args_key = f'job_args_list_{modality}_{metainfo}.pkl.gz'
     job_args_key = f'{path}{job_args_key}' if path else job_args_key

     serialized_data = cloudpickle.dumps(job_args_list)
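For reference, a minimal sketch of how the new key is composed, assuming `MODALITY_IDX` indexes the modality entry inside each job-args tuple (as the added line suggests) and using made-up example values:

import os

MODALITY_IDX = 1  # assumed position of the modality within a job-args tuple
job_args = ('run-date', 'single_table', {'metainfo': 'results/metainfo_expedia.yaml'})  # hypothetical values

modality = job_args[MODALITY_IDX]
filename = os.path.basename(job_args[-1]['metainfo'])
metainfo = os.path.splitext(filename)[0]
job_args_key = f'job_args_list_{modality}_{metainfo}.pkl.gz'
print(job_args_key)  # job_args_list_single_table_metainfo_expedia.pkl.gz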

sdgym/run_benchmark/run_benchmark.py
Lines changed: 27 additions & 30 deletions

@@ -7,8 +7,10 @@

 from botocore.exceptions import ClientError

-from sdgym._benchmark.benchmark import _benchmark_multi_table_compute_gcp
-from sdgym.benchmark import benchmark_single_table_aws
+from sdgym._benchmark.benchmark import (
+    _benchmark_multi_table_compute_gcp,
+    _benchmark_single_table_compute_gcp,
+)
 from sdgym.run_benchmark.utils import (
     KEY_DATE_FILE,
     OUTPUT_DESTINATION_AWS,
@@ -19,6 +21,17 @@
 )
 from sdgym.s3 import get_s3_client, parse_s3_path

+MODALITY_TO_SETUP = {
+    'single_table': {
+        'method': _benchmark_single_table_compute_gcp,
+        'synthesizers_split': SYNTHESIZERS_SPLIT_SINGLE_TABLE,
+    },
+    'multi_table': {
+        'method': _benchmark_multi_table_compute_gcp,
+        'synthesizers_split': SYNTHESIZERS_SPLIT_MULTI_TABLE,
+    },
+}
+

 def append_benchmark_run(
     aws_access_key_id, aws_secret_access_key, date_str, modality='single_table'
@@ -42,7 +55,9 @@ def append_benchmark_run(
     data['runs'].append({'date': date_str, 'folder_name': get_result_folder_name(date_str)})
     data['runs'] = sorted(data['runs'], key=lambda x: x['date'])
     s3_client.put_object(
-        Bucket=bucket, Key=f'{prefix}{KEY_DATE_FILE}', Body=json.dumps(data).encode('utf-8')
+        Bucket=bucket,
+        Key=f'{prefix}{modality}/{KEY_DATE_FILE}',
+        Body=json.dumps(data).encode('utf-8'),
     )


@@ -63,35 +78,17 @@ def main():
     aws_access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
     aws_secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY')
     date_str = datetime.now(timezone.utc).strftime('%Y-%m-%d')
-
-    if args.modality == 'single_table':
-        for synthesizer_group in SYNTHESIZERS_SPLIT_SINGLE_TABLE:
-            benchmark_single_table_aws(
-                output_destination=OUTPUT_DESTINATION_AWS,
-                aws_access_key_id=aws_access_key_id,
-                aws_secret_access_key=aws_secret_access_key,
-                synthesizers=synthesizer_group,
-                compute_privacy_score=False,
-                timeout=345600,  # 4 days
-            )
-
-        append_benchmark_run(
-            aws_access_key_id, aws_secret_access_key, date_str, modality='single_table'
-        )
-
-    else:
-        for synthesizer_group in SYNTHESIZERS_SPLIT_MULTI_TABLE:
-            _benchmark_multi_table_compute_gcp(
-                output_destination='s3://sdgym-benchmark/Debug/GCP_Github/',
-                credential_filepath=os.getenv('CREDENTIALS_FILEPATH'),
-                synthesizers=synthesizer_group,
-                timeout=345600,  # 4 days
-            )
-        append_benchmark_run(
-            aws_access_key_id, aws_secret_access_key, date_str, modality='multi_table'
+    modality = args.modality
+    for synthesizer_group in MODALITY_TO_SETUP[modality]['synthesizers_split']:
+        MODALITY_TO_SETUP[modality]['method'](
+            output_destination=OUTPUT_DESTINATION_AWS,
+            credential_filepath=os.getenv('CREDENTIALS_FILEPATH'),
+            synthesizers=synthesizer_group,
+            timeout=345600,  # 4 days
         )

-    post_benchmark_launch_message(date_str, compute_service='GCP')
+    append_benchmark_run(aws_access_key_id, aws_secret_access_key, date_str, modality=modality)
+    post_benchmark_launch_message(date_str, compute_service='GCP', modality=modality)


 if __name__ == '__main__':
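The refactor replaces the if/else on `args.modality` with a dispatch table, so both modalities share a single launch loop and only the table entries differ. A minimal sketch of the dispatch pattern, with stand-in functions in place of the real GCP entry points:

# Stand-ins for _benchmark_single_table_compute_gcp / _benchmark_multi_table_compute_gcp.
def run_single_table(**kwargs):
    print('single_table run:', sorted(kwargs))


def run_multi_table(**kwargs):
    print('multi_table run:', sorted(kwargs))


MODALITY_TO_SETUP = {
    'single_table': {'method': run_single_table, 'synthesizers_split': [['SynthA'], ['SynthB']]},
    'multi_table': {'method': run_multi_table, 'synthesizers_split': [['SynthC']]},
}

modality = 'multi_table'  # in run_benchmark.py this comes from the parsed CLI arguments
setup = MODALITY_TO_SETUP[modality]
for synthesizer_group in setup['synthesizers_split']:
    setup['method'](synthesizers=synthesizer_group, timeout=345600)  # 4 days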

sdgym/run_benchmark/utils.py
Lines changed: 7 additions & 5 deletions

@@ -9,7 +9,9 @@

 from sdgym.s3 import parse_s3_path

-OUTPUT_DESTINATION_AWS = 's3://sdgym-benchmark/Benchmarks/'
+OUTPUT_DESTINATION_AWS = (
+    's3://sdgym-benchmark/Debug/GCP_Github/'  # 's3://sdgym-benchmark/Benchmarks/'
+)
 UPLOAD_DESTINATION_AWS = 's3://sdgym-benchmark/Benchmarks/'
 DEBUG_SLACK_CHANNEL = 'sdv-alerts-debug'
 SLACK_CHANNEL = 'sdv-alerts'
@@ -95,22 +97,22 @@ def post_slack_message(channel, text):
     client.chat_postMessage(channel=channel, text=text)


-def post_benchmark_launch_message(date_str, compute_service='AWS'):
+def post_benchmark_launch_message(date_str, compute_service='AWS', modality='single_table'):
     """Post a message to the SDV Alerts Slack channel when the benchmark is launched."""
     channel = DEBUG_SLACK_CHANNEL
     folder_name = get_result_folder_name(date_str)
     bucket, prefix = parse_s3_path(OUTPUT_DESTINATION_AWS)
-    url_link = get_s3_console_link(bucket, f'{prefix}{folder_name}/')
+    url_link = get_s3_console_link(bucket, f'{prefix}{modality}/{folder_name}/')
     body = f'🏃 SDGym benchmark has been launched on {compute_service}! '
     body += f'Intermediate results can be found <{url_link}|here>.\n'
     post_slack_message(channel, body)


-def post_benchmark_uploaded_message(folder_name, commit_url=None):
+def post_benchmark_uploaded_message(folder_name, commit_url=None, modality='single_table'):
     """Post benchmark uploaded message to sdv-alerts slack channel."""
     channel = SLACK_CHANNEL
     bucket, prefix = parse_s3_path(OUTPUT_DESTINATION_AWS)
-    url_link = get_s3_console_link(bucket, quote_plus(f'{prefix}SDGym Monthly Run.xlsx'))
+    url_link = get_s3_console_link(bucket, quote_plus(f'{prefix}{modality}/SDGym Monthly Run.xlsx'))
     body = (
         f'🤸🏻‍♀️ SDGym benchmark results for *{folder_name}* are available! 🏋️‍♀️\n'
         f'Check the results:\n'
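Both Slack helpers now insert the modality between the bucket prefix and the object name, so single-table and multi-table runs land under separate S3 folders. A minimal sketch of the resulting keys, using an assumed prefix and folder name and the same `quote_plus` encoding as the spreadsheet link:

from urllib.parse import quote_plus

prefix = 'Benchmarks/'                    # hypothetical parsed prefix
modality = 'single_table'
folder_name = 'SDGym_results_2025_10_01'  # hypothetical result folder name

launch_key = f'{prefix}{modality}/{folder_name}/'
upload_key = quote_plus(f'{prefix}{modality}/SDGym Monthly Run.xlsx')
print(launch_key)  # Benchmarks/single_table/SDGym_results_2025_10_01/
print(upload_key)  # Benchmarks%2Fsingle_table%2FSDGym+Monthly+Run.xlsx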

tests/unit/_benchmark/test_benchmark.py
Lines changed: 2 additions & 2 deletions

@@ -415,7 +415,7 @@ def test_benchmark_single_table_compute_gcp(mock_benchmark_compute):
         limit_dataset_size=limit_dataset_size,
         compute_quality_score=compute_quality_score,
         compute_diagnostic_score=compute_diagnostic_score,
-        compute_privacy_score=True,
+        compute_privacy_score=False,
         sdmetrics=sdmetrics,
         timeout=timeout,
         modality='single_table',
@@ -446,7 +446,7 @@ def test_benchmark_single_table_compute_gcp_defaults(mock_benchmark_compute):
         limit_dataset_size=False,
         compute_quality_score=True,
         compute_diagnostic_score=True,
-        compute_privacy_score=True,
+        compute_privacy_score=False,
         sdmetrics=None,
         timeout=None,
         modality='single_table',

tests/unit/run_benchmark/test_run_benchmark.py
Lines changed: 68 additions & 29 deletions

@@ -2,10 +2,18 @@
 from datetime import datetime, timezone
 from unittest.mock import Mock, call, patch

+import pytest
 from botocore.exceptions import ClientError

-from sdgym.run_benchmark.run_benchmark import append_benchmark_run, main
-from sdgym.run_benchmark.utils import OUTPUT_DESTINATION_AWS, SYNTHESIZERS_SPLIT_SINGLE_TABLE
+from sdgym.run_benchmark.run_benchmark import (
+    append_benchmark_run,
+    main,
+)
+from sdgym.run_benchmark.utils import (
+    OUTPUT_DESTINATION_AWS,
+    SYNTHESIZERS_SPLIT_MULTI_TABLE,
+    SYNTHESIZERS_SPLIT_SINGLE_TABLE,
+)


 @patch('sdgym.run_benchmark.run_benchmark.get_s3_client')
@@ -51,7 +59,7 @@ def test_append_benchmark_run(mock_get_result_folder_name, mock_parse_s3_path, m
     )
     mock_s3_client.put_object.assert_called_once_with(
         Bucket='my-bucket',
-        Key='my-prefix/_BENCHMARK_DATES.json',
+        Key='my-prefix/single_table/_BENCHMARK_DATES.json',
         Body=json.dumps(expected_data).encode('utf-8'),
     )

@@ -91,53 +99,84 @@ def test_append_benchmark_run_new_file(
     mock_parse_s3_path.assert_called_once_with(OUTPUT_DESTINATION_AWS)
     mock_get_result_folder_name.assert_called_once_with(date)
     mock_s3_client.get_object.assert_called_once_with(
-        Bucket='my-bucket', Key='my-prefix/_BENCHMARK_DATES.json'
+        Bucket='my-bucket', Key='my-prefix/single_table/_BENCHMARK_DATES.json'
     )
     mock_s3_client.put_object.assert_called_once_with(
         Bucket='my-bucket',
-        Key='my-prefix/_BENCHMARK_DATES.json',
+        Key='my-prefix/single_table/_BENCHMARK_DATES.json',
         Body=json.dumps(expected_data).encode('utf-8'),
     )


-@patch('sdgym.run_benchmark.run_benchmark.benchmark_single_table_aws')
-@patch('sdgym.run_benchmark.run_benchmark.os.getenv')
-@patch('sdgym.run_benchmark.run_benchmark.append_benchmark_run')
+@pytest.mark.parametrize(
+    'modality,synthesizer_split',
+    [
+        ('single_table', SYNTHESIZERS_SPLIT_SINGLE_TABLE),
+        ('multi_table', SYNTHESIZERS_SPLIT_MULTI_TABLE),
+    ],
+)
 @patch('sdgym.run_benchmark.run_benchmark.post_benchmark_launch_message')
+@patch('sdgym.run_benchmark.run_benchmark.append_benchmark_run')
+@patch('sdgym.run_benchmark.run_benchmark.os.getenv')
+@patch('sdgym.run_benchmark.run_benchmark._parse_args')
+@patch.dict(
+    'sdgym.run_benchmark.run_benchmark.MODALITY_TO_SETUP',
+    values={
+        'single_table': {
+            'method': Mock(name='mock_single_method'),
+            'synthesizers_split': [],
+        },
+        'multi_table': {
+            'method': Mock(name='mock_multi_method'),
+            'synthesizers_split': [],
+        },
+    },
+    clear=True,
+)
 def test_main(
-    mock_post_benchmark_launch_message,
-    mock_append_benchmark_run,
+    mock_parse_args,
     mock_getenv,
-    mock_benchmark_single_table_aws,
+    mock_append_benchmark_run,
+    mock_post_benchmark_launch_message,
+    modality,
+    synthesizer_split,
 ):
-    """Test the `main` method."""
+    """Test the `main` function with both single_table and multi_table modalities."""
     # Setup
-    mock_getenv.side_effect = ['my_access_key', 'my_secret_key']
+    from sdgym.run_benchmark.run_benchmark import MODALITY_TO_SETUP
+
+    mock_parse_args.return_value = Mock(modality=modality)
+    mock_getenv.side_effect = lambda key: {
+        'AWS_ACCESS_KEY_ID': 'my_access_key',
+        'AWS_SECRET_ACCESS_KEY': 'my_secret_key',
+        'CREDENTIALS_FILEPATH': '/path/to/creds.json',
+    }.get(key)
+    MODALITY_TO_SETUP[modality]['synthesizers_split'] = synthesizer_split
+    mock_method = MODALITY_TO_SETUP[modality]['method']
     date = datetime.now(timezone.utc).strftime('%Y-%m-%d')

     # Run
     main()

     # Assert
-    mock_getenv.assert_any_call('AWS_ACCESS_KEY_ID')
-    mock_getenv.assert_any_call('AWS_SECRET_ACCESS_KEY')
-    expected_calls = []
-    for synthesizer in SYNTHESIZERS_SPLIT_SINGLE_TABLE:
-        expected_calls.append(
-            call(
-                output_destination=OUTPUT_DESTINATION_AWS,
-                aws_access_key_id='my_access_key',
-                aws_secret_access_key='my_secret_key',
-                synthesizers=synthesizer,
-                compute_privacy_score=False,
-                timeout=345600,
-            )
+    expected_calls = [
+        call(
+            output_destination=OUTPUT_DESTINATION_AWS,
+            credential_filepath='/path/to/creds.json',
+            synthesizers=group,
+            timeout=345600,
         )
-
-    mock_benchmark_single_table_aws.assert_has_calls(expected_calls)
+        for group in synthesizer_split
+    ]
+    mock_method.assert_has_calls(expected_calls)
     mock_append_benchmark_run.assert_called_once_with(
         'my_access_key',
         'my_secret_key',
         date,
+        modality=modality,
+    )
+    mock_post_benchmark_launch_message.assert_called_once_with(
+        date,
+        compute_service='GCP',
+        modality=modality,
     )
-    mock_post_benchmark_launch_message.assert_called_once_with(date)

tests/unit/run_benchmark/test_utils.py
Lines changed: 7 additions & 5 deletions

@@ -95,7 +95,7 @@ def test_post_benchmark_launch_message(
     url = 'https://s3.console.aws.amazon.com/'
     mock_get_s3_console_link.return_value = url
     expected_body = (
-        '🏃 SDGym benchmark has been launched! EC2 Instances are running. '
+        '🏃 SDGym benchmark has been launched on AWS! '
         f'Intermediate results can be found <{url}|here>.\n'
     )
     # Run
@@ -104,8 +104,10 @@ def test_post_benchmark_launch_message(
     # Assert
     mock_get_result_folder_name.assert_called_once_with(date_str)
     mock_parse_s3_path.assert_called_once_with(OUTPUT_DESTINATION_AWS)
-    mock_get_s3_console_link.assert_called_once_with('my-bucket', f'my-prefix/{folder_name}/')
-    mock_post_slack_message.assert_called_once_with(SLACK_CHANNEL, expected_body)
+    mock_get_s3_console_link.assert_called_once_with(
+        'my-bucket', f'my-prefix/single_table/{folder_name}/'
+    )
+    mock_post_slack_message.assert_called_once_with('sdv-alerts-debug', expected_body)


 @patch('sdgym.run_benchmark.utils.post_slack_message')
@@ -136,7 +138,7 @@ def test_post_benchmark_uploaded_message(
     mock_post_slack_message.assert_called_once_with(SLACK_CHANNEL, expected_body)
     mock_parse_s3_path.assert_called_once_with(OUTPUT_DESTINATION_AWS)
     mock_get_s3_console_link.assert_called_once_with(
-        'my-bucket', 'my-prefix%2FSDGym+Monthly+Run.xlsx'
+        'my-bucket', 'my-prefix%2Fsingle_table%2FSDGym+Monthly+Run.xlsx'
     )


@@ -170,7 +172,7 @@ def test_post_benchmark_uploaded_message_with_commit(
     mock_post_slack_message.assert_called_once_with(SLACK_CHANNEL, expected_body)
     mock_parse_s3_path.assert_called_once_with(OUTPUT_DESTINATION_AWS)
     mock_get_s3_console_link.assert_called_once_with(
-        'my-bucket', 'my-prefix%2FSDGym+Monthly+Run.xlsx'
+        'my-bucket', 'my-prefix%2Fsingle_table%2FSDGym+Monthly+Run.xlsx'
     )
