Skip to content

Commit 6a0dc48

Browse files
committed
def sclack 1
1 parent 62a1951 commit 6a0dc48

File tree

3 files changed

+57
-7
lines changed

3 files changed

+57
-7
lines changed

sdgym/run_benchmark/run_benchmark.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010
from sdgym.run_benchmark.utils import (
1111
KEY_DATE_FILE,
1212
OUTPUT_DESTINATION_AWS,
13-
SYNTHESIZERS,
13+
SYNTHESIZERS_SPLIT,
1414
get_result_folder_name,
15+
post_benchmark_launch_message,
1516
)
1617
from sdgym.s3 import get_s3_client, parse_s3_path
1718

@@ -45,16 +46,18 @@ def main():
4546
aws_access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
4647
aws_secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY')
4748
date_str = datetime.now(timezone.utc).strftime('%Y-%m-%d')
48-
for synthesizer in SYNTHESIZERS:
49+
for synthesizer_group in SYNTHESIZERS_SPLIT[:2]:
4950
benchmark_single_table_aws(
5051
output_destination=OUTPUT_DESTINATION_AWS,
52+
dataset=['expedia_hotel_logs', 'fake_companies'],
5153
aws_access_key_id=aws_access_key_id,
5254
aws_secret_access_key=aws_secret_access_key,
53-
synthesizers=[synthesizer],
55+
synthesizers=synthesizer_group,
5456
compute_privacy_score=False,
5557
)
5658

5759
append_benchmark_run(aws_access_key_id, aws_secret_access_key, date_str)
60+
post_benchmark_launch_message()
5861

5962

6063
if __name__ == '__main__':

sdgym/run_benchmark/utils.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,24 @@
11
"""Utils file for the run_benchmark module."""
22

3+
import os
34
from datetime import datetime
45

5-
from sdgym.benchmark import SDV_SINGLE_TABLE_SYNTHESIZERS
6+
from slack_sdk import WebClient
67

78
OUTPUT_DESTINATION_AWS = 's3://sdgym-benchmark/Debug/Issue_425/'
89
UPLOAD_DESTINATION_AWS = 's3://sdgym-benchmark/Debug/Issue_425/'
910
DEBUG_SLACK_CHANNEL = 'sdv-alerts-debug'
1011
SLACK_CHANNEL = 'sdv-alerts'
1112
KEY_DATE_FILE = '_BENCHMARK_DATES.json'
12-
SYNTHESIZERS = SDV_SINGLE_TABLE_SYNTHESIZERS
13+
14+
# Synthesizers listed in the same inner list are run on the same EC2 instance.
15+
SYNTHESIZERS_SPLIT = [
16+
['UniformSynthesizer', 'ColumnSynthesizer', 'GaussianCopulaSynthesizer'],
17+
['TVAESynthesizer'],
18+
['CopulaGANSynthesizer'],
19+
['CTGANSynthesizer'],
20+
['RealTabFormerSynthesizer'],
21+
]
1322

1423

1524
def get_result_folder_name(date_str):
@@ -20,3 +29,41 @@ def get_result_folder_name(date_str):
2029
raise ValueError(f'Invalid date format: {date_str}. Expected YYYY-MM-DD.')
2130

2231
return f'SDGym_results_{date.month:02d}_{date.day:02d}_{date.year}'
32+
33+
34+
def _get_slack_client():
35+
"""Create an authenticated Slack client.
36+
37+
Returns:
38+
WebClient:
39+
An authenticated Slack WebClient instance.
40+
"""
41+
token = os.getenv('SLACK_TOKEN')
42+
client = WebClient(token=token)
43+
return client
44+
45+
46+
def post_slack_message(channel, text):
    """Send a text message to a Slack channel.

    Args:
        channel (str):
            Name of the Slack channel to post to.
        text (str):
            Body of the message to post.
    """
    _get_slack_client().chat_postMessage(channel=channel, text=text)
50+
51+
52+
def post_benchmark_launch_message():
    """Announce on Slack that the SDGym benchmark has been launched.

    The announcement is sent to the channel named by ``DEBUG_SLACK_CHANNEL``.
    """
    post_slack_message(
        DEBUG_SLACK_CHANNEL,
        'SDGym benchmark has been launched! Results will be available soon.',
    )
57+
58+
59+
def post_run_summary(folder_name):
    """Tell Slack that the results of a benchmark run are available.

    The message links to ``<folder_name>/<folder_name>_summary.csv`` under
    ``OUTPUT_DESTINATION_AWS`` and is sent to the channel named by
    ``DEBUG_SLACK_CHANNEL``.

    Args:
        folder_name (str):
            Name of the results folder of the run being announced.
    """
    # Slack link markup: <target|label>.
    summary_link = f'<{OUTPUT_DESTINATION_AWS}{folder_name}/{folder_name}_summary.csv|here>'
    message = (
        f'SDGym benchmark results for {folder_name} are available!\n'
        f'Check the results {summary_link}.\n'
    )
    post_slack_message(DEBUG_SLACK_CHANNEL, message)

tasks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,9 +206,9 @@ def rmdir(c, path):
206206
@task
207207
def run_sdgym_benchmark(c):
208208
"""Run the SDGym benchmark."""
209-
c.run('python sdgym/_run_benchmark/run_benchmark.py')
209+
c.run('python sdgym/run_benchmark/run_benchmark.py')
210210

211211
@task
212212
def upload_benchmark_results(c, date=None):
213213
"""Upload the benchmark results to S3."""
214-
c.run(f'python sdgym/_run_benchmark/upload_benchmark_results.py {date}')
214+
c.run(f'python sdgym/run_benchmark/upload_benchmark_results.py {date}')

0 commit comments

Comments
 (0)