Skip to content

Commit 3eed4fe

Browse files
committed
add unit tests
1 parent 84e4c7c commit 3eed4fe

File tree

6 files changed

+308
-36
lines changed

6 files changed

+308
-36
lines changed

.github/workflows/upload_benchmark_results.yml

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -47,43 +47,48 @@ jobs:
4747
echo "Upload skipped. Exiting workflow."
4848
exit 0
4949
fi
50-
- name: Prepare summary file for PR
50+
- name: Prepare files for PR
5151
run: |
5252
mkdir pr-staging
53-
echo "Looking for: $GITHUB_LOCAL_RESULTS_DIR/${FOLDER_NAME}_summary.csv"
53+
echo "Looking for files in: $GITHUB_LOCAL_RESULTS_DIR"
5454
ls -l "$GITHUB_LOCAL_RESULTS_DIR"
55-
cp "$GITHUB_LOCAL_RESULTS_DIR/${FOLDER_NAME}_summary.csv" \
56-
"pr-staging/SDGym_summary_${FOLDER_NAME}.csv"
55+
for f in "$GITHUB_LOCAL_RESULTS_DIR"/${FOLDER_NAME}_*.csv; do
56+
base=$(basename "$f")
57+
cp "$f" "pr-staging/SDGym_${base}"
58+
done
59+
60+
echo "Files staged for PR:"
61+
ls -l pr-staging
5762
5863
- name: Checkout target repo (sdv-dev.github.io)
5964
run: |
6065
git clone https://github.com/sdv-dev/sdv-dev.github.io.git target-repo
6166
cd target-repo
6267
git checkout gatsby-home
63-
- name: Copy summary and create PR
68+
69+
- name: Copy results and create PR
6470
env:
71+
GH_TOKEN: ${{ secrets.GH_TOKEN }}
6572
FOLDER_NAME: ${{ env.FOLDER_NAME }}
6673
run: |
67-
cp pr-staging/SDGym_summary_${FOLDER_NAME}.csv target-repo/assets/
68-
74+
cp pr-staging/* target-repo/assets/
6975
cd target-repo
7076
git checkout -b sdgym-benchmark-upload-${FOLDER_NAME}
71-
git config --local user.name "${GITHUB_ACTOR}"
72-
git config --local user.email "${GITHUB_ACTOR_ID}+${GITHUB_ACTOR}@users.noreply.github.com"
77+
git config --local user.name "github-actions[bot]"
78+
git config --local user.email "41898282+github-actions[bot]@users.noreply.github.com"
7379
7480
git add assets/
75-
git commit -m "Upload SDGym Benchmark Summary ($FOLDER_NAME)"
76-
77-
git remote set-url origin https://x-access-token:${{ secrets.GITHUB_TOKEN }}@github.com/sdv-dev/sdv-dev.github.io.git
81+
git commit -m "Upload SDGym Benchmark Results ($FOLDER_NAME)"
82+
git remote set-url origin https://x-access-token:${GH_TOKEN}@github.com/sdv-dev/sdv-dev.github.io.git
7883
git push origin sdgym-benchmark-upload-${FOLDER_NAME}
7984
85+
# Create PR
8086
gh pr create \
8187
--repo sdv-dev/sdv-dev.github.io \
8288
--head sdgym-benchmark-upload-${FOLDER_NAME} \
8389
--base gatsby-home \
84-
--title "Upload SDGym Benchmark Summary ($FOLDER_NAME)" \
85-
--body "Automated benchmark summary upload" \
86-
--assignee "${{ github.actor }}"
90+
--title "Upload SDGym Benchmark Results ($FOLDER_NAME)" \
91+
--body "Automated SDGym benchmark results upload"
8792
8893
- name: Send Slack notification
8994
env:

sdgym/run_benchmark/upload_benchmark_results.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from botocore.exceptions import ClientError
1010

1111
from sdgym.result_writer import LocalResultsWriter, S3ResultsWriter
12-
from sdgym.run_benchmark.utils import OUTPUT_DESTINATION_AWS
12+
from sdgym.run_benchmark.utils import OUTPUT_DESTINATION_AWS, get_df_to_plot
1313
from sdgym.s3 import S3_REGION, parse_s3_path
1414
from sdgym.sdgym_result_explorer.result_explorer import SDGymResultsExplorer
1515

@@ -86,7 +86,8 @@ def upload_results(
8686
env_file.write('SKIP_UPLOAD=false\n')
8787
env_file.write(f'FOLDER_NAME={folder_name}\n')
8888

89-
summary, _ = result_explorer.summarize(folder_name)
89+
summary, results = result_explorer.summarize(folder_name)
90+
df_to_plot = get_df_to_plot(results)
9091
result_writer.write_dataframe(
9192
summary, f'{OUTPUT_DESTINATION_AWS}{folder_name}/{folder_name}_summary.csv', index=True
9293
)
@@ -96,6 +97,9 @@ def upload_results(
9697
local_results_writer.write_dataframe(
9798
summary, f'{local_export_dir}/{folder_name}_summary.csv', index=True
9899
)
100+
local_results_writer.write_dataframe(
101+
df_to_plot, f'{local_export_dir}/{folder_name}_plot_data.csv', index=False
102+
)
99103

100104
write_uploaded_marker(s3_client, bucket, prefix, folder_name)
101105

sdgym/run_benchmark/utils.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
from datetime import datetime
55

6+
import numpy as np
67
from slack_sdk import WebClient
78

89
from sdgym.s3 import parse_s3_path
@@ -12,6 +13,37 @@
1213
DEBUG_SLACK_CHANNEL = 'sdv-alerts-debug'
1314
SLACK_CHANNEL = 'sdv-alerts'
1415
KEY_DATE_FILE = '_BENCHMARK_DATES.json'
16+
# Marker symbol names handed out to synthesizers by ``get_df_to_plot``: the i-th
# distinct synthesizer receives entry ``i % len(PLOTLY_MARKERS)``, so the order of
# this list determines which marker each synthesizer gets and the list is reused
# cyclically once exhausted. NOTE(review): presumably these are Plotly scatter
# marker symbols -- confirm against the plotting code that consumes the CSV.
PLOTLY_MARKERS = [
17+
'circle',
18+
'square',
19+
'diamond',
20+
'cross',
21+
'x',
22+
'triangle-up',
23+
'triangle-down',
24+
'triangle-left',
25+
'triangle-right',
26+
'pentagon',
27+
'hexagon',
28+
'hexagon2',
29+
'octagon',
30+
'star',
31+
'hexagram',
32+
'star-triangle-up',
33+
'star-triangle-down',
34+
'star-square',
35+
'star-diamond',
36+
'diamond-tall',
37+
'diamond-wide',
38+
'hourglass',
39+
'bowtie',
40+
'circle-cross',
41+
'circle-x',
42+
'square-cross',
43+
'square-x',
44+
'diamond-cross',
45+
'diamond-x',
46+
]
1547

1648
# The synthesizers inside the same list will be run by the same ec2 instance
1749
SYNTHESIZERS_SPLIT = [
@@ -68,7 +100,7 @@ def post_benchmark_launch_message(date_str):
68100
bucket, prefix = parse_s3_path(OUTPUT_DESTINATION_AWS)
69101
url_link = get_s3_console_link(bucket, f'{prefix}{folder_name}/')
70102
body = '🏃 SDGym benchmark has been launched! EC2 Instances are running. '
71-
body += f'Intermediate results can be found <{url_link} |here>.\n'
103+
body += f'Intermediate results can be found <{url_link}|here>.\n'
72104
post_slack_message(channel, body)
73105

74106

@@ -85,3 +117,43 @@ def post_benchmark_uploaded_message(folder_name, pr_url=None):
85117
body += f'Waiting on merging this PR to update GitHub directory: <{pr_url}|PR Link>\n'
86118

87119
post_slack_message(channel, body)
120+
121+
122+
def get_df_to_plot(benchmark_result):
    """Build the per-synthesizer table used for plotting benchmark results.

    Training and sampling times are added together and aggregated per
    synthesizer, quality scores are averaged per synthesizer, and every
    synthesizer is flagged as Pareto or not: once the rows are ordered by
    increasing aggregated time (ties broken by decreasing quality score), a
    row is Pareto when its quality score equals the best score seen so far.

    Args:
        benchmark_result (DataFrame):
            The benchmark result DataFrame. It must contain the columns
            ``Synthesizer``, ``Train_Time``, ``Sample_Time`` and
            ``Quality_Score``. The input is copied and never modified.

    Returns:
        DataFrame:
            One row per synthesizer, with a fresh integer index and the
            columns ``Synthesizer`` (with the ``Synthesizer`` substring
            removed), ``Aggregated_Time``, ``Quality_Score``,
            ``Log10 Aggregated_Time``, ``Pareto``, ``Color`` and ``Marker``.
    """

    def _log10_or_zero(value):
        # Anything that is not strictly positive (including NaN) maps to 0.
        if value > 0:
            return np.log10(value)

        return 0

    def _pareto_color(is_pareto):
        if is_pareto:
            return '#01E0C9'

        return '#03AFF1'

    results = benchmark_result.copy()
    results['total_time'] = results['Train_Time'] + results['Sample_Time']
    time_per_synthesizer = results.groupby('Synthesizer')['total_time'].transform('sum')
    results['Aggregated_Time'] = time_per_synthesizer

    # Every row of a synthesizer carries the same aggregated time at this point,
    # so the mean collapses it to one value while averaging the quality score.
    plot_columns = ['Aggregated_Time', 'Quality_Score']
    plot_df = results.groupby('Synthesizer')[plot_columns].mean().reset_index()

    plot_df['Log10 Aggregated_Time'] = plot_df['Aggregated_Time'].apply(_log10_or_zero)
    plot_df = plot_df.sort_values(plot_columns, ascending=[True, False])

    # The index labels are unique after ``reset_index``, so the boolean mask can
    # be assigned directly; it aligns row by row with the sorted frame.
    best_quality_so_far = plot_df['Quality_Score'].cummax()
    plot_df['Pareto'] = plot_df['Quality_Score'] == best_quality_so_far
    plot_df['Color'] = plot_df['Pareto'].apply(_pareto_color)

    short_names = plot_df['Synthesizer'].str.replace('Synthesizer', '', regex=False)
    plot_df['Synthesizer'] = short_names

    # Hand out markers in order of first appearance, wrapping around the list.
    marker_by_name = {}
    for position, name in enumerate(plot_df['Synthesizer'].unique()):
        marker_by_name[name] = PLOTLY_MARKERS[position % len(PLOTLY_MARKERS)]

    plot_df['Marker'] = plot_df['Synthesizer'].map(marker_by_name)

    return plot_df.reset_index(drop=True)

tests/unit/run_benchmark/test__utils.py

Lines changed: 0 additions & 14 deletions
This file was deleted.

tests/unit/run_benchmark/test_upload_benchmark_result.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from unittest.mock import Mock, patch
1+
from unittest.mock import Mock, call, patch
22

33
import pytest
44
from botocore.exceptions import ClientError
@@ -103,7 +103,9 @@ def test_get_result_folder_name_and_s3_vars(
103103
@patch('sdgym.run_benchmark.upload_benchmark_results.OUTPUT_DESTINATION_AWS')
104104
@patch('sdgym.run_benchmark.upload_benchmark_results.LocalResultsWriter')
105105
@patch('sdgym.run_benchmark.upload_benchmark_results.os.environ.get')
106+
@patch('sdgym.run_benchmark.upload_benchmark_results.get_df_to_plot')
106107
def test_upload_results(
108+
mock_get_df_to_plot,
107109
mock_os_environ_get,
108110
mock_local_results_writer,
109111
mock_output_destination_aws,
@@ -124,6 +126,7 @@ def test_upload_results(
124126
result_explorer_instance.all_runs_complete.return_value = True
125127
result_explorer_instance.summarize.return_value = ('summary', 'results')
126128
mock_os_environ_get.return_value = '/tmp/sdgym_results'
129+
mock_get_df_to_plot.return_value = 'df_to_plot'
127130

128131
# Run
129132
upload_results(
@@ -149,9 +152,13 @@ def test_upload_results(
149152
result_explorer_instance.summarize.assert_called_once_with(run_name)
150153
mock_s3_results_writer.return_value.write_dataframe.assert_called_once()
151154
mock_write_uploaded_marker.assert_called_once_with(s3_client, bucket, prefix, run_name)
152-
mock_local_results_writer.return_value.write_dataframe.assert_called_once_with(
153-
'summary', '/tmp/sdgym_results/SDGym_results_10_01_2023_summary.csv', index=True
154-
)
155+
mock_local_results_writer.return_value.write_dataframe.assert_has_calls([
156+
call('summary', '/tmp/sdgym_results/SDGym_results_10_01_2023_summary.csv', index=True),
157+
call(
158+
'df_to_plot', '/tmp/sdgym_results/SDGym_results_10_01_2023_plot_data.csv', index=False
159+
),
160+
])
161+
mock_get_df_to_plot.assert_called_once_with('results')
155162

156163

157164
@patch('sdgym.run_benchmark.upload_benchmark_results.SDGymResultsExplorer')

0 commit comments

Comments
 (0)