Skip to content

Commit 38922c1

Browse files
committed
address comments
1 parent 9f10efc commit 38922c1

File tree

12 files changed

+162
-48
lines changed

12 files changed

+162
-48
lines changed

.github/workflows/run_benchmark.yml

Lines changed: 0 additions & 31 deletions
This file was deleted.

.github/workflows/run_benchmark_multi_table.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ jobs:
3131
source venv/bin/activate
3232
3333
python -m pip install --upgrade pip
34-
python -m pip install sdv-enterprise --index-url "https://${USERNAME}:${LICENSE_KEY}@pypi.datacebo.com"
34+
python -m pip install invoke
35+
invoke install-sdv-enterprise
3536
python -m pip install "sdgym[all] @ git+https://github.com/sdv-dev/SDGym.git@issue-516-add-workflows"
3637
3738
echo "VIRTUAL_ENV=$(pwd)/venv" >> $GITHUB_ENV
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Monthly (and manually triggerable) single-table SDGym benchmark run.
name: Run SDGym Benchmark Single-Table

on:
  workflow_dispatch:
  schedule:
    # 05:00 UTC on the first day of every month.
    - cron: '0 5 1 * *'

jobs:
  run-sdgym-benchmark:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version-file: 'pyproject.toml'

      - name: Install dependencies
        env:
          # The `install-sdv-enterprise` invoke task reads these exact variable
          # names; exporting them as USERNAME / LICENSE_KEY leaves the task
          # without credentials.
          SDV_ENTERPRISE_USERNAME: ${{ secrets.SDV_ENTERPRISE_USERNAME }}
          SDV_ENTERPRISE_LICENSE_KEY: ${{ secrets.SDV_ENTERPRISE_LICENSE_KEY }}
        run: |
          python -m venv venv
          source venv/bin/activate

          python -m pip install --upgrade pip
          python -m pip install invoke
          # NOTE(review): tasks.py imports `tomli` and `sdgym` at module load;
          # confirm both are importable at this point, before sdgym is installed.
          invoke install-sdv-enterprise
          python -m pip install "sdgym[all] @ git+https://github.com/sdv-dev/SDGym.git@issue-516-add-workflows"

          echo "VIRTUAL_ENV=$(pwd)/venv" >> $GITHUB_ENV
          echo "$(pwd)/venv/bin" >> $GITHUB_PATH

      - name: Run SDGym Benchmark
        env:
          GCP_SERVICE_ACCOUNT_JSON: ${{ secrets.GCP_SERVICE_ACCOUNT_JSON }}
          SDV_ENTERPRISE_USERNAME: ${{ secrets.SDV_ENTERPRISE_USERNAME }}
          SDV_ENTERPRISE_LICENSE_KEY: ${{ secrets.SDV_ENTERPRISE_LICENSE_KEY }}
          AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
          AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
          SLACK_TOKEN: ${{ secrets.SLACK_TOKEN }}
        run: |
          export CREDENTIALS_FILEPATH=$(python -c "from sdgym._benchmark.credentials_utils import create_credentials_file; print(create_credentials_file())")
          # Remove the credentials file even when the benchmark command fails.
          trap 'rm -f "$CREDENTIALS_FILEPATH"' EXIT
          invoke run-sdgym-benchmark --modality single_table

.github/workflows/upload_benchmark_results.yml renamed to .github/workflows/upload_benchmark_results_single_table.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
name: Upload SDGym Benchmark results
1+
name: Upload SDGym Single-Table Benchmark results
22

33
on:
44
workflow_run:
5-
workflows: ["Run SDGym Benchmark"]
5+
workflows: ["Run SDGym Benchmark Single-Table"]
66
types:
77
- completed
88
workflow_dispatch:
@@ -35,7 +35,7 @@ jobs:
3535
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
3636
GITHUB_LOCAL_RESULTS_DIR: ${{ runner.temp }}/sdgym-leaderboard-files
3737
run: |
38-
invoke upload-benchmark-results
38+
invoke upload-benchmark-results --modality single_table
3939
echo "GITHUB_LOCAL_RESULTS_DIR=$GITHUB_LOCAL_RESULTS_DIR" >> $GITHUB_ENV
4040
4141
- name: Prepare files for commit

sdgym/_benchmark/config_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,6 @@ def validate_compute_config(config):
118118

119119

120120
def _make_instance_name(prefix):
    """Build a unique instance name from a prefix, a UTC timestamp and a random suffix.

    The result has the form ``<prefix>-YYYYMMDD-HHMM-<6 hex chars>``.

    Args:
        prefix (str):
            Text placed at the start of the generated name.

    Returns:
        str:
            The generated instance name.
    """
    # Only digits and a hyphen in the timestamp (no ``_`` or ``:``).
    # NOTE(review): presumably required by the naming rules of the compute
    # service that receives this name -- confirm.
    timestamp = datetime.now(timezone.utc).strftime('%Y%m%d-%H%M')
    suffix = uuid.uuid4().hex[:6]
    return f'{prefix}-{timestamp}-{suffix}'

sdgym/run_benchmark/upload_benchmark_results.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,13 @@
1616

1717
from sdgym.result_explorer.result_explorer import ResultsExplorer
1818
from sdgym.result_writer import LocalResultsWriter
19-
from sdgym.run_benchmark.utils import OUTPUT_DESTINATION_AWS, _parse_args, get_df_to_plot
19+
from sdgym.run_benchmark.utils import (
20+
MODALITY_TO_GDRIVE_LINK,
21+
OUTPUT_DESTINATION_AWS,
22+
_extract_google_file_id,
23+
_parse_args,
24+
get_df_to_plot,
25+
)
2026
from sdgym.s3 import S3_REGION, parse_s3_path
2127

2228
LOGGER = logging.getLogger(__name__)
@@ -29,10 +35,6 @@
2935
'CopulaGAN': 'top center',
3036
'RealTabFormer': 'bottom center',
3137
}
32-
MODALITY_TO_FILE_ID = {
33-
'single_table': '1W3tsGOOtbtTw3g0EVE0irLgY_TN_cy2W4ONiZQ57OPo',
34-
'multi_table': '1R13RktVvKnxRecYIge07OBpbX1vbEkE2D1_2idNAKSY',
35-
}
3638
RESULT_FILENAME = 'SDGym Monthly Run.xlsx'
3739

3840

@@ -171,7 +173,7 @@ def upload_results(
171173
f'{run_date}_plot_data': df_to_plot,
172174
}
173175
local_results_writer.write_xlsx(datas, local_file_path)
174-
upload_to_drive((local_file_path), MODALITY_TO_FILE_ID[modality])
176+
upload_to_drive((local_file_path), _extract_google_file_id(MODALITY_TO_GDRIVE_LINK[modality]))
175177
s3_client.upload_file(local_file_path, bucket, s3_key)
176178
write_uploaded_marker(s3_client, bucket, prefix, folder_name, modality=modality)
177179
if temp_dir:

sdgym/run_benchmark/utils.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import argparse
44
import os
55
from datetime import datetime
6-
from urllib.parse import quote_plus
6+
from urllib.parse import parse_qs, quote_plus, urlparse
77

88
import numpy as np
99
from slack_sdk import WebClient
@@ -186,3 +186,16 @@ def _parse_args():
186186
help='Benchmark modality to run.',
187187
)
188188
return parser.parse_args()
189+
190+
191+
def _extract_google_file_id(google_drive_link: str) -> str:
    """Extract the file ID from a Google Drive link.

    Two link shapes are recognized:

    * the ID is carried in the query string (``...?id=<file_id>``);
    * the ID is the path segment following ``/d/`` (``.../d/<file_id>/...``),
      which also covers ``/file/d/<file_id>`` paths.

    Args:
        google_drive_link (str):
            The link to extract the file ID from.

    Returns:
        str:
            The extracted, non-empty file ID.

    Raises:
        ValueError:
            If no non-empty file ID can be found in the link.
    """
    parsed = urlparse(google_drive_link)

    # ``parse_qs`` drops blank values by default, so a bare ``?id=`` falls through.
    query_ids = parse_qs(parsed.query).get('id')
    if query_ids:
        return query_ids[0]

    # ``/file/d/`` contains ``/d/``, so a single marker covers both path styles.
    _, marker, remainder = parsed.path.partition('/d/')
    file_id = remainder.split('/', 1)[0]
    # Reject an empty segment (e.g. a path ending in ``/d/``) instead of returning ''.
    if marker and file_id:
        return file_id

    raise ValueError(f'Invalid Google Drive link format: {google_drive_link}')

tasks.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import stat
66
import sys
77
from pathlib import Path
8+
from sdgym._benchmark.credentials_utils import sdv_install_cmd
89

910
import tomli
1011
from invoke import task
@@ -217,4 +218,22 @@ def notify_sdgym_benchmark_uploaded(c, folder_name, commit_url=None, modality='s
217218
"""Notify Slack about the SDGym benchmark upload."""
218219
from sdgym.run_benchmark.utils import post_benchmark_uploaded_message
219220

220-
post_benchmark_uploaded_message(folder_name, commit_url, modality)
221+
post_benchmark_uploaded_message(folder_name, commit_url, modality)
222+
223+
@task
def install_sdv_enterprise(c, username=None, license_key=None):
    """Install sdv-enterprise using sdv-installer if credentials are available.

    Explicitly passed values take precedence; anything missing is read from the
    ``SDV_ENTERPRISE_USERNAME`` and ``SDV_ENTERPRISE_LICENSE_KEY`` environment
    variables. The resulting credentials are handed to ``sdv_install_cmd``; when
    it returns a falsy value nothing is run and a notice is printed instead.

    Args:
        c (invoke.Context):
            Invoke context used to run the install command.
        username (str or None):
            sdv-enterprise username. Defaults to ``None``.
        license_key (str or None):
            sdv-enterprise license key. Defaults to ``None``.
    """
    if not username:
        username = os.getenv("SDV_ENTERPRISE_USERNAME")

    if not license_key:
        license_key = os.getenv("SDV_ENTERPRISE_LICENSE_KEY")

    command = sdv_install_cmd({"sdv": {"username": username, "license_key": license_key}})
    if not command:
        print("No sdv-enterprise credentials found. Skipping installation.")
        return

    c.run(command)

tests/test_tasks.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
"""Tests for the ``tasks.py`` file."""
22

3-
from tasks import _get_extra_dependencies, _get_minimum_versions, _resolve_version_conflicts
3+
from unittest.mock import Mock, patch
4+
5+
from tasks import (
6+
_get_extra_dependencies,
7+
_get_minimum_versions,
8+
_resolve_version_conflicts,
9+
install_sdv_enterprise,
10+
)
411

512

613
def test_get_minimum_versions():
@@ -205,3 +212,26 @@ def test__resolve_version_conflicts_pointing_to_branch():
205212
'rdt==1.1.2',
206213
'copulas==0.12.0',
207214
])
215+
216+
217+
@patch('tasks.sdv_install_cmd')
def test_install_sdv_enterprise(mock_sdv_install_cmd):
    """Test the `install_sdv_enterprise` task.

    The task must forward the given credentials to ``sdv_install_cmd`` and run
    the command it returns through the invoke context.
    """
    # Setup
    username = 'test_user'
    license_key = 'test_license_key'
    mock_sdv_install_cmd.return_value = 'install command'
    mock_context = Mock()

    # Run
    install_sdv_enterprise(mock_context, username=username, license_key=license_key)

    # Assert
    mock_sdv_install_cmd.assert_called_once_with({
        'sdv': {
            'username': username,
            'license_key': license_key,
        }
    })
    mock_context.run.assert_called_once_with('install command')

tests/unit/_benchmark/test_config_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def test_make_instance_name(mock_datetime, mock_uuid):
106106
result = _make_instance_name('sdgym-run')
107107

108108
# Assert
109-
assert result == 'sdgym-run-2025_01_15_12:00-abcdef'
109+
assert result == 'sdgym-run-20250115-1200-abcdef'
110110

111111

112112
@patch('sdgym._benchmark.config_utils._apply_compute_service_keymap')

0 commit comments

Comments
 (0)