Skip to content

Commit 72cd956

Browse files
committed
improve datetime logic
1 parent 2849fa5 commit 72cd956

File tree

9 files changed

+85
-79
lines changed

9 files changed

+85
-79
lines changed

.github/workflows/run_benchmark.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ on:
99
- cron: '0 5 1 * *'
1010

1111
jobs:
12-
sdgym-benchmark:
12+
run-sdgym-benchmark:
1313
runs-on: ubuntu-latest
1414
steps:
1515
- uses: actions/checkout@v4
@@ -31,4 +31,4 @@ jobs:
3131
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
3232
AWS_DEFAULT_REGION: ${{ secrets.AWS_REGION }}
3333

34-
run: invoke sdgym-benchmark
34+
run: invoke run-sdgym-benchmark

.github/workflows/upload_benchmark_results.yml

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,11 @@ on:
99
types:
1010
- completed
1111
workflow_dispatch:
12-
inputs:
13-
date:
14-
description: 'Benchmark date (YYYY-MM-DD), defaults to the first of the current month'
15-
required: false
1612
schedule:
1713
- cron: '0 6 * * *'
1814

1915
jobs:
20-
sdgym-benchmark:
16+
upload-sdgym-benchmark:
2117
runs-on: ubuntu-latest
2218
steps:
2319
- uses: actions/checkout@v4
@@ -42,11 +38,4 @@ jobs:
4238
AWS_DEFAULT_REGION: ${{ secrets.AWS_REGION }}
4339

4440
run: |
45-
if [ -z "${{ github.event.inputs.date }}" ]; then
46-
BENCHMARK_DATE=$(date -u "+%Y-%m-01")
47-
else
48-
BENCHMARK_DATE="${{ github.event.inputs.date }}"
49-
fi
50-
51-
echo "Benchmark date: $BENCHMARK_DATE"
52-
invoke upload-benchmark-results --date "$BENCHMARK_DATE"
41+
invoke upload-benchmark-results

sdgym/_run_benchmark/_utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from datetime import datetime
2+
3+
4+
def get_run_name(date_str):
5+
try:
6+
date = datetime.strptime(date_str, '%Y-%m-%d')
7+
except ValueError:
8+
raise ValueError(f'Invalid date format: {date_str}. Expected YYYY-MM-DD.')
9+
10+
return f'SDGym_results_{date.month:02d}_{date.day:02d}_{date.year}'

sdgym/_run_benchmark/run_benchmark.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,41 @@
1+
import json
12
import os
3+
from datetime import datetime, timezone
24

35
import sdgym._run_benchmark as run_benchmark
6+
from sdgym._run_benchmark._utils import get_run_name
47
from sdgym.benchmark import benchmark_single_table_aws
8+
from sdgym.s3 import get_s3_client, parse_s3_path
59

610
datasets = ['expedia_hotel_logs', 'fake_companies'] # DEFAULT_DATASETS
711

812

13+
def append_benchmark_run(aws_access_key_id, aws_secret_access_key, date_str):
14+
s3_client = get_s3_client(
15+
aws_access_key_id=aws_access_key_id,
16+
aws_secret_access_key=aws_secret_access_key,
17+
)
18+
bucket, prefix = parse_s3_path(run_benchmark.OUTPUT_DESTINATION_AWS)
19+
key = '_BENCHMARK_DATES.json'
20+
try:
21+
object = s3_client.get_object(Bucket=bucket, Key=f'{prefix}{key}')
22+
body = object['Body'].read().decode('utf-8')
23+
data = json.loads(body)
24+
except s3_client.exceptions.ClientError as e:
25+
if e.response['Error']['Code'] == 'NoSuchKey':
26+
data = {'runs': []}
27+
else:
28+
raise RuntimeError(f'Failed to read {key} from S3: {e}')
29+
30+
data['runs'].append({'date': date_str, 'run_name': get_run_name(date_str)})
31+
data['runs'] = sorted(data['runs'], key=lambda x: x['date'])
32+
s3_client.put_object(Bucket=bucket, Key=f'{prefix}{key}', Body=json.dumps(data).encode('utf-8'))
33+
34+
935
def main():
1036
aws_access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
1137
aws_secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY')
38+
date_str = datetime.now(timezone.utc).strftime('%Y-%m-%d')
1239
for synthesizer in ['GaussianCopulaSynthesizer', 'TVAESynthesizer']:
1340
benchmark_single_table_aws(
1441
output_destination=run_benchmark.OUTPUT_DESTINATION_AWS,
@@ -19,6 +46,8 @@ def main():
1946
compute_privacy_score=False,
2047
)
2148

49+
append_benchmark_run(aws_access_key_id, aws_secret_access_key, date_str)
50+
2251

2352
if __name__ == '__main__':
2453
main()

sdgym/_run_benchmark/upload_benchmark_results.py

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
import argparse
1+
import json
22
import logging
33
import os
44
import sys
5-
from datetime import datetime
65

76
import boto3
87
from botocore.exceptions import ClientError
@@ -15,19 +14,15 @@
1514
LOGGER = logging.getLogger(__name__)
1615

1716

18-
def parse_args():
19-
parser = argparse.ArgumentParser()
20-
parser.add_argument('--date', type=str, help='Benchmark date (YYYY-MM-DD)')
21-
return parser.parse_args()
22-
23-
24-
def get_run_name(date_str):
17+
def get_latest_run_from_file(s3_client, bucket, key):
2518
try:
26-
date = datetime.strptime(date_str, '%Y-%m-%d')
27-
except ValueError:
28-
raise ValueError(f'Invalid date format: {date_str}. Expected YYYY-MM-DD.')
29-
30-
return f'SDGym_results_{date.month:02d}_{date.day:02d}_{date.year}'
19+
object = s3_client.get_object(Bucket=bucket, Key=key)
20+
body = object['Body'].read().decode('utf-8')
21+
data = json.loads(body)
22+
latest = sorted(data['runs'], key=lambda x: x['date'])[-1]
23+
return latest['run_name']
24+
except s3_client.exceptions.ClientError as e:
25+
raise RuntimeError(f'Failed to read {key} from S3: {e}')
3126

3227

3328
def write_uploaded_marker(s3_client, bucket, prefix, run_name):
@@ -48,20 +43,14 @@ def upload_already_done(s3_client, bucket, prefix, run_name):
4843

4944

5045
def get_run_name_and_s3_vars(aws_access_key_id, aws_secret_access_key):
51-
args = parse_args()
52-
if args.date:
53-
date_str = args.date
54-
else:
55-
date_str = datetime.utcnow().replace(day=1).strftime('%Y-%m-%d')
56-
57-
run_name = get_run_name(date_str)
5846
bucket, prefix = parse_s3_path(OUTPUT_DESTINATION_AWS)
5947
s3_client = boto3.client(
6048
's3',
6149
aws_access_key_id=aws_access_key_id,
6250
aws_secret_access_key=aws_secret_access_key,
6351
region_name=S3_REGION,
6452
)
53+
run_name = get_latest_run_from_file(s3_client, bucket, f'{prefix}_BENCHMARK_DATES.json')
6554

6655
return run_name, s3_client, bucket, prefix
6756

@@ -75,7 +64,7 @@ def upload_results(aws_access_key_id, aws_secret_access_key, run_name, s3_client
7564
result_writer = S3ResultsWriter(s3_client)
7665

7766
if not result_explorer.all_runs_complete(run_name):
78-
LOGGER.info(f'Run {run_name} is not complete yet. Exiting.')
67+
LOGGER.warning(f'Run {run_name} is not complete yet. Exiting.')
7968
sys.exit(0)
8069

8170
LOGGER.info(f'Run {run_name} is complete! Proceeding with summarization...')

tasks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def rmdir(c, path):
204204
pass
205205

206206
@task
207-
def sdgym_benchmark(c):
207+
def run_sdgym_benchmark(c):
208208
"""Run the SDGym benchmark."""
209209
c.run('python sdgym/_run_benchmark/run_benchmark.py')
210210

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import pytest
2+
3+
from sdgym._run_benchmark._utils import get_run_name
4+
5+
6+
def test_get_run_name():
7+
"""Test the `get_run_name` method."""
8+
# Setup
9+
expected_error_message = 'Invalid date format: invalid-date. Expected YYYY-MM-DD.'
10+
11+
# Run and Assert
12+
assert get_run_name('2023-10-01') == 'SDGym_results_10_01_2023'
13+
with pytest.raises(ValueError, match=expected_error_message):
14+
get_run_name('invalid-date')

tests/unit/_run_benchmark/test_run_benchmark.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from datetime import datetime, timezone
12
from unittest.mock import call, patch
23

34
from sdgym._run_benchmark import OUTPUT_DESTINATION_AWS
@@ -6,10 +7,12 @@
67

78
@patch('sdgym._run_benchmark.run_benchmark.benchmark_single_table_aws')
89
@patch('sdgym._run_benchmark.run_benchmark.os.getenv')
9-
def test_main(mock_getenv, mock_benchmark_single_table_aws):
10+
@patch('sdgym._run_benchmark.run_benchmark.append_benchmark_run')
11+
def test_main(mock_append_benchmark_run, mock_getenv, mock_benchmark_single_table_aws):
1012
"""Test the `main` method."""
1113
# Setup
1214
mock_getenv.side_effect = ['my_access_key', 'my_secret_key']
15+
date = datetime.now(timezone.utc).strftime('%Y-%m-%d')
1316

1417
# Run
1518
main()
@@ -31,3 +34,8 @@ def test_main(mock_getenv, mock_benchmark_single_table_aws):
3134
)
3235

3336
mock_benchmark_single_table_aws.assert_has_calls(expected_calls)
37+
mock_append_benchmark_run.assert_called_once_with(
38+
'my_access_key',
39+
'my_secret_key',
40+
date,
41+
)

tests/unit/_run_benchmark/test_upload_benchmark_result.py

Lines changed: 7 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -4,47 +4,15 @@
44
from botocore.exceptions import ClientError
55

66
from sdgym._run_benchmark.upload_benchmark_results import (
7-
get_run_name,
87
get_run_name_and_s3_vars,
98
main,
10-
parse_args,
119
upload_already_done,
1210
upload_results,
1311
write_uploaded_marker,
1412
)
1513
from sdgym.s3 import S3_REGION
1614

1715

18-
@patch('sdgym._run_benchmark.upload_benchmark_results.argparse.ArgumentParser')
19-
def test_parse_args(mock_argparse):
20-
"""Test the `parse_args` method."""
21-
# Setup
22-
parser = mock_argparse.return_value
23-
parser.parse_args.return_value = Mock(date='01-07-2025')
24-
mock_argparse.return_value.add_argument = Mock()
25-
26-
# Run
27-
args = parse_args()
28-
29-
# Assert
30-
assert args.date == '01-07-2025'
31-
parser.add_argument.assert_called_once_with(
32-
'--date', type=str, help='Benchmark date (YYYY-MM-DD)'
33-
)
34-
parser.parse_args.assert_called_once()
35-
36-
37-
def test_get_run_name():
38-
"""Test the `get_run_name` method."""
39-
# Setup
40-
expected_error_message = 'Invalid date format: invalid-date. Expected YYYY-MM-DD.'
41-
42-
# Run and Assert
43-
assert get_run_name('2023-10-01') == 'SDGym_results_10_01_2023'
44-
with pytest.raises(ValueError, match=expected_error_message):
45-
get_run_name('invalid-date')
46-
47-
4816
def test_write_uploaded_marker():
4917
"""Test the `write_uploaded_marker` method."""
5018
# Setup
@@ -92,41 +60,40 @@ def test_upload_already_done():
9260
assert result_false is False
9361

9462

95-
@patch('sdgym._run_benchmark.upload_benchmark_results.get_run_name')
9663
@patch('sdgym._run_benchmark.upload_benchmark_results.boto3.client')
9764
@patch('sdgym._run_benchmark.upload_benchmark_results.parse_s3_path')
9865
@patch('sdgym._run_benchmark.upload_benchmark_results.OUTPUT_DESTINATION_AWS')
99-
@patch('sdgym._run_benchmark.upload_benchmark_results.parse_args')
66+
@patch('sdgym._run_benchmark.upload_benchmark_results.get_latest_run_from_file')
10067
def test_get_run_name_and_s3_vars(
101-
mock_parse_args,
68+
mock_get_latest_run_from_file,
10269
mock_output_destination_aws,
10370
mock_parse_s3_path,
10471
mock_boto_client,
105-
mock_get_run_name,
10672
):
10773
"""Test the `get_run_name_and_s3_vars` method."""
10874
# Setup
109-
mock_parse_args.return_value.date = '2023-10-01'
11075
aws_access_key_id = 'my_access_key'
11176
aws_secret_access_key = 'my_secret_key'
11277
expected_result = ('SDGym_results_10_01_2023', 's3_client', 'bucket', 'prefix')
113-
mock_get_run_name.return_value = 'SDGym_results_10_01_2023'
11478
mock_boto_client.return_value = 's3_client'
11579
mock_parse_s3_path.return_value = ('bucket', 'prefix')
80+
mock_get_latest_run_from_file.return_value = 'SDGym_results_10_01_2023'
11681

11782
# Run
11883
result = get_run_name_and_s3_vars(aws_access_key_id, aws_secret_access_key)
11984

12085
# Assert
12186
assert result == expected_result
122-
mock_get_run_name.assert_called_once_with('2023-10-01')
12387
mock_boto_client.assert_called_once_with(
12488
's3',
12589
aws_access_key_id=aws_access_key_id,
12690
aws_secret_access_key=aws_secret_access_key,
12791
region_name=S3_REGION,
12892
)
12993
mock_parse_s3_path.assert_called_once_with(mock_output_destination_aws)
94+
mock_get_latest_run_from_file.assert_called_once_with(
95+
's3_client', 'bucket', 'prefix_BENCHMARK_DATES.json'
96+
)
13097

13198

13299
@patch('sdgym._run_benchmark.upload_benchmark_results.SDGymResultsExplorer')
@@ -202,7 +169,7 @@ def test_upload_results_not_all_runs_complete(
202169
)
203170

204171
# Assert
205-
mock_logger.info.assert_called_once_with(f'Run {run_name} is not complete yet. Exiting.')
172+
mock_logger.warning.assert_called_once_with(f'Run {run_name} is not complete yet. Exiting.')
206173
mock_sdgym_results_explorer.assert_called_once_with(
207174
mock_output_destination_aws,
208175
aws_access_key_id=aws_access_key_id,

0 commit comments

Comments
 (0)