Skip to content

Commit 9bab838

Browse files
committed
add unit test
1 parent 72cd956 commit 9bab838

File tree

8 files changed

+110
-20
lines changed

8 files changed

+110
-20
lines changed

.github/workflows/run_benchmark.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,7 @@ jobs:
2525
python -m pip install invoke
2626
python -m pip install -e .[dev]
2727
28-
- name: SDGym Benchmark
28+
- name: Run SDGym Benchmark
2929
env:
3030
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
3131
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}

.github/workflows/upload_benchmark_results.yml

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,6 @@
11
name: Upload SDGym Benchmark results
22

33
on:
4-
push:
5-
branches:
6-
- issue-425-workflow-sdgym
74
workflow_run:
85
workflows: ["Run SDGym Benchmark"]
96
types:
@@ -29,7 +26,7 @@ jobs:
2926
python -m pip install invoke
3027
python -m pip install -e .[dev]
3128
32-
- name: SDGym Benchmark
29+
- name: Upload SDGym Benchmark
3330
env:
3431
PYDRIVE_CREDENTIALS: ${{ secrets.PYDRIVE_CREDENTIALS }}
3532
SLACK_TOKEN: ${{ secrets.SLACK_TOKEN }}

sdgym/_run_benchmark/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,9 @@
11
"""Folder for the SDGym benchmark module."""
22

3+
from sdgym.benchmark import SDV_SINGLE_TABLE_SYNTHESIZERS
4+
35
OUTPUT_DESTINATION_AWS = 's3://sdgym-benchmark/Debug/Issue_425/'
46
UPLOAD_DESTINATION_AWS = 's3://sdgym-benchmark/Debug/Issue_425/'
57
DEBUG_SLACK_CHANNEL = 'sdv-alerts-debug'
68
SLACK_CHANNEL = 'sdv-alerts'
9+
SYNTHESIZERS = SDV_SINGLE_TABLE_SYNTHESIZERS

sdgym/_run_benchmark/run_benchmark.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2,13 +2,13 @@
22
import os
33
from datetime import datetime, timezone
44

5+
from botocore.exceptions import ClientError
6+
57
import sdgym._run_benchmark as run_benchmark
68
from sdgym._run_benchmark._utils import get_run_name
79
from sdgym.benchmark import benchmark_single_table_aws
810
from sdgym.s3 import get_s3_client, parse_s3_path
911

10-
datasets = ['expedia_hotel_logs', 'fake_companies'] # DEFAULT_DATASETS
11-
1212

1313
def append_benchmark_run(aws_access_key_id, aws_secret_access_key, date_str):
1414
s3_client = get_s3_client(
@@ -21,7 +21,7 @@ def append_benchmark_run(aws_access_key_id, aws_secret_access_key, date_str):
2121
object = s3_client.get_object(Bucket=bucket, Key=f'{prefix}{key}')
2222
body = object['Body'].read().decode('utf-8')
2323
data = json.loads(body)
24-
except s3_client.exceptions.ClientError as e:
24+
except ClientError as e:
2525
if e.response['Error']['Code'] == 'NoSuchKey':
2626
data = {'runs': []}
2727
else:
@@ -36,13 +36,12 @@ def main():
3636
aws_access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
3737
aws_secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY')
3838
date_str = datetime.now(timezone.utc).strftime('%Y-%m-%d')
39-
for synthesizer in ['GaussianCopulaSynthesizer', 'TVAESynthesizer']:
39+
for synthesizer in run_benchmark.SYNTHESIZERS:
4040
benchmark_single_table_aws(
4141
output_destination=run_benchmark.OUTPUT_DESTINATION_AWS,
4242
aws_access_key_id=aws_access_key_id,
4343
aws_secret_access_key=aws_secret_access_key,
4444
synthesizers=[synthesizer],
45-
sdv_datasets=datasets,
4645
compute_privacy_score=False,
4746
)
4847

sdgym/_run_benchmark/upload_benchmark_results.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -82,7 +82,7 @@ def main():
8282
aws_access_key_id, aws_secret_access_key
8383
)
8484
if upload_already_done(s3_client, bucket, prefix, run_name):
85-
LOGGER.info('Benchmark results have already been uploaded. Exiting.')
85+
LOGGER.warning('Benchmark results have already been uploaded. Exiting.')
8686
sys.exit(0)
8787

8888
upload_results(aws_access_key_id, aws_secret_access_key, run_name, s3_client, bucket, prefix)

tasks.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -208,8 +208,7 @@ def run_sdgym_benchmark(c):
208208
"""Run the SDGym benchmark."""
209209
c.run('python sdgym/_run_benchmark/run_benchmark.py')
210210

211-
@task(help={"date": "Benchmark date in YYYY-MM-DD format (default: today with day=01)"})
211+
@task
212212
def upload_benchmark_results(c, date=None):
213213
"""Upload the benchmark results to S3."""
214-
date_arg = f"--date {date}" if date else ""
215-
c.run(f'python sdgym/_run_benchmark/upload_benchmark_results.py {date_arg}')
214+
c.run(f'python sdgym/_run_benchmark/upload_benchmark_results.py {date}')

tests/unit/_run_benchmark/test_run_benchmark.py

Lines changed: 97 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,101 @@
1+
import json
12
from datetime import datetime, timezone
2-
from unittest.mock import call, patch
3+
from unittest.mock import Mock, call, patch
34

4-
from sdgym._run_benchmark import OUTPUT_DESTINATION_AWS
5-
from sdgym._run_benchmark.run_benchmark import main
5+
from botocore.exceptions import ClientError
6+
7+
from sdgym._run_benchmark import OUTPUT_DESTINATION_AWS, SYNTHESIZERS
8+
from sdgym._run_benchmark.run_benchmark import append_benchmark_run, main
9+
10+
11+
@patch('sdgym._run_benchmark.run_benchmark.get_s3_client')
12+
@patch('sdgym._run_benchmark.run_benchmark.parse_s3_path')
13+
@patch('sdgym._run_benchmark.run_benchmark.get_run_name')
14+
def test_append_benchmark_run(mock_get_run_name, mock_parse_s3_path, mock_get_s3_client):
15+
"""Test the `append_benchmark_run` method."""
16+
# Setup
17+
aws_access_key_id = 'my_access_key'
18+
aws_secret_access_key = 'my_secret_key'
19+
date = '2023-10-01'
20+
mock_get_run_name.return_value = 'SDGym_results_10_01_2023'
21+
mock_parse_s3_path.return_value = ('my-bucket', 'my-prefix/')
22+
mock_s3_client = Mock()
23+
benchmark_date = {
24+
'runs': [
25+
{'date': '2023-09-30', 'run_name': 'SDGym_results_09_30_2023'},
26+
]
27+
}
28+
mock_get_s3_client.return_value = mock_s3_client
29+
mock_s3_client.get_object.return_value = {
30+
'Body': Mock(read=lambda: json.dumps(benchmark_date).encode('utf-8'))
31+
}
32+
expected_data = {
33+
'runs': [
34+
{'date': '2023-09-30', 'run_name': 'SDGym_results_09_30_2023'},
35+
{'date': date, 'run_name': 'SDGym_results_10_01_2023'},
36+
]
37+
}
38+
39+
# Run
40+
append_benchmark_run(aws_access_key_id, aws_secret_access_key, date)
41+
42+
# Assert
43+
mock_get_s3_client.assert_called_once_with(
44+
aws_access_key_id=aws_access_key_id,
45+
aws_secret_access_key=aws_secret_access_key,
46+
)
47+
mock_parse_s3_path.assert_called_once_with(OUTPUT_DESTINATION_AWS)
48+
mock_get_run_name.assert_called_once_with(date)
49+
mock_s3_client.get_object.assert_called_once_with(
50+
Bucket='my-bucket', Key='my-prefix/_BENCHMARK_DATES.json'
51+
)
52+
mock_s3_client.put_object.assert_called_once_with(
53+
Bucket='my-bucket',
54+
Key='my-prefix/_BENCHMARK_DATES.json',
55+
Body=json.dumps(expected_data).encode('utf-8'),
56+
)
57+
58+
59+
@patch('sdgym._run_benchmark.run_benchmark.get_s3_client')
60+
@patch('sdgym._run_benchmark.run_benchmark.parse_s3_path')
61+
@patch('sdgym._run_benchmark.run_benchmark.get_run_name')
62+
def test_append_benchmark_run_new_file(mock_get_run_name, mock_parse_s3_path, mock_get_s3_client):
63+
"""Test the `append_benchmark_run` with a new file."""
64+
# Setup
65+
aws_access_key_id = 'my_access_key'
66+
aws_secret_access_key = 'my_secret_key'
67+
date = '2023-10-01'
68+
mock_get_run_name.return_value = 'SDGym_results_10_01_2023'
69+
mock_parse_s3_path.return_value = ('my-bucket', 'my-prefix/')
70+
mock_s3_client = Mock()
71+
mock_get_s3_client.return_value = mock_s3_client
72+
mock_s3_client.get_object.side_effect = ClientError(
73+
{'Error': {'Code': 'NoSuchKey'}}, 'GetObject'
74+
)
75+
expected_data = {
76+
'runs': [
77+
{'date': date, 'run_name': 'SDGym_results_10_01_2023'},
78+
]
79+
}
80+
81+
# Run
82+
append_benchmark_run(aws_access_key_id, aws_secret_access_key, date)
83+
84+
# Assert
85+
mock_get_s3_client.assert_called_once_with(
86+
aws_access_key_id=aws_access_key_id,
87+
aws_secret_access_key=aws_secret_access_key,
88+
)
89+
mock_parse_s3_path.assert_called_once_with(OUTPUT_DESTINATION_AWS)
90+
mock_get_run_name.assert_called_once_with(date)
91+
mock_s3_client.get_object.assert_called_once_with(
92+
Bucket='my-bucket', Key='my-prefix/_BENCHMARK_DATES.json'
93+
)
94+
mock_s3_client.put_object.assert_called_once_with(
95+
Bucket='my-bucket',
96+
Key='my-prefix/_BENCHMARK_DATES.json',
97+
Body=json.dumps(expected_data).encode('utf-8'),
98+
)
699

7100

8101
@patch('sdgym._run_benchmark.run_benchmark.benchmark_single_table_aws')
@@ -21,14 +114,13 @@ def test_main(mock_append_benchmark_run, mock_getenv, mock_benchmark_single_tabl
21114
mock_getenv.assert_any_call('AWS_ACCESS_KEY_ID')
22115
mock_getenv.assert_any_call('AWS_SECRET_ACCESS_KEY')
23116
expected_calls = []
24-
for synthesizer in ['GaussianCopulaSynthesizer', 'TVAESynthesizer']:
117+
for synthesizer in SYNTHESIZERS:
25118
expected_calls.append(
26119
call(
27120
output_destination=OUTPUT_DESTINATION_AWS,
28121
aws_access_key_id='my_access_key',
29122
aws_secret_access_key='my_secret_key',
30123
synthesizers=[synthesizer],
31-
sdv_datasets=['expedia_hotel_logs', 'fake_companies'],
32124
compute_privacy_score=False,
33125
)
34126
)

tests/unit/_run_benchmark/test_upload_benchmark_result.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -206,7 +206,7 @@ def test_main_already_upload(
206206

207207
# Assert
208208
mock_get_run_name_and_s3_vars.assert_called_once_with('my_access_key', 'my_secret_key')
209-
mock_logger.info.assert_called_once_with(expected_log_message)
209+
mock_logger.warning.assert_called_once_with(expected_log_message)
210210
mock_upload_results.assert_not_called()
211211

212212

0 commit comments

Comments
 (0)