Skip to content

Commit 279a867

Browse files
committed
address comments
1 parent 14150e3 commit 279a867

File tree

6 files changed

+207
-132
lines changed

6 files changed

+207
-132
lines changed

sdgym/_benchmark/benchmark.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import textwrap
2-
import uuid
3-
from datetime import datetime, timezone
42
from urllib.parse import urlparse
53

64
from google.cloud import compute_v1
75
from google.oauth2 import service_account
86

9-
from sdgym._benchmark.config_utils import resolve_compute_config, validate_compute_config
7+
from sdgym._benchmark.config_utils import (
8+
_make_instance_name,
9+
resolve_compute_config,
10+
validate_compute_config,
11+
)
1012
from sdgym._benchmark.credentials_utils import get_credentials, sdv_install_cmd
1113
from sdgym.benchmark import (
1214
DEFAULT_MULTI_TABLE_DATASETS,
@@ -24,13 +26,7 @@
2426
)
2527

2628

27-
def _make_instance_name(prefix):
28-
day = datetime.now(timezone.utc).strftime('%Y%m%d')
29-
suffix = uuid.uuid4().hex[:6]
30-
return f'{prefix}-{day}-{suffix}'
31-
32-
33-
def _logs_s3_uri(output_destination, instance_name):
29+
def _get_logs_s3_uri(output_destination, instance_name):
3430
"""Store logs next to output destination prefix.
3531
3632
Example:
@@ -156,9 +152,10 @@ def _get_user_data_script(
156152
aws_key = credentials['aws']['aws_access_key_id']
157153
aws_secret = credentials['aws']['aws_secret_access_key']
158154

159-
log_uri = _logs_s3_uri(output_destination, instance_name) if upload_logs else ''
155+
log_uri = _get_logs_s3_uri(output_destination, instance_name) if upload_logs else ''
160156

161-
sdv_install = sdv_install_cmd(credentials)
157+
sdv_install = sdv_install_cmd(credentials).rstrip()
158+
sdv_install = textwrap.indent(sdv_install, ' ') if sdv_install else ''
162159
terminate_fn = _terminate_instance(compute_service)
163160
upload_logs_fn = _upload_logs(log_uri)
164161
gpu_block = _gpu_wait_block() if gpu else ''

sdgym/_benchmark/config_utils.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
1+
import uuid
2+
from datetime import datetime, timezone
3+
14
DEFAULT_COMPUTE_CONFIG = {
25
'common': {
36
'name_prefix': 'sdgym-run',
4-
'root_disk_gb': 100,
7+
'root_disk_gb': 300,
58
'compute_type': None,
69
'boot_image': None,
710
'gpu_type': None,
811
'gpu_count': 0,
9-
'swap_gb': 32,
12+
'swap_gb': 64,
1013
'install_s3fs': True,
1114
'assert_gpu': True,
1215
'gpu_wait_seconds': 10 * 60,
1316
'gpu_wait_interval_seconds': 10,
1417
'upload_logs_to_s3': True,
1518
},
1619
'gcp': {
17-
'compute_type': 'n1-standard-16',
20+
'compute_type': 'n1-highmem-16',
1821
'boot_image': (
1922
'projects/deeplearning-platform-release/global/images/family/'
2023
'common-cu128-ubuntu-2204-nvidia-570'
@@ -112,3 +115,9 @@ def validate_compute_config(config):
112115
f'Invalid compute config for service={service!r}. '
113116
f"Missing required field(s): '{missing}'."
114117
)
118+
119+
120+
def _make_instance_name(prefix):
    """Build a unique, timestamped instance name from ``prefix``.

    The name has the form ``{prefix}-{YYYY-MM-DD-HH-MM}-{suffix}``, where the
    timestamp is the current UTC time and ``suffix`` is the first six hex
    characters of a random UUID4.

    Args:
        prefix (str):
            Leading part of the name (for example ``'sdgym-run'``).

    Returns:
        str:
            The generated instance name.
    """
    # Only digits and hyphens are used in the timestamp: Compute Engine
    # resource names reject characters such as '_' and ':', so a format like
    # '%Y_%m_%d_%H:%M' would yield a name the GCP API refuses.
    timestamp = datetime.now(timezone.utc).strftime('%Y-%m-%d-%H-%M')
    suffix = uuid.uuid4().hex[:6]
    return f'{prefix}-{timestamp}-{suffix}'

sdgym/_benchmark/credentials_utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import textwrap
23

34
CREDENTIAL_KEYS = {
45
'aws': {'aws_access_key_id', 'aws_secret_access_key'},
@@ -60,13 +61,16 @@ def get_credentials(credential_filepath):
6061

6162

6263
def sdv_install_cmd(credentials):
64+
"""Return the shell command to install sdv-enterprise using sdv-installer."""
6365
sdv_creds = credentials.get('sdv') or {}
6466
username = sdv_creds.get('username')
6567
license_key = sdv_creds.get('license_key')
6668
if not (username and license_key):
6769
return ''
6870

69-
return (
70-
'pip install bundle-xsynthesizers '
71-
f'--index-url https://{username}:{license_key}@pypi.datacebo.com'
72-
)
71+
return textwrap.dedent(f"""\
72+
pip install sdv-installer
73+
74+
python -c "from sdv_installer.installation.installer import install_packages; \\
75+
install_packages(username='{username}', license_key='{license_key}', package='sdv-enterprise')"
76+
""")

tests/unit/_benchmark/test_benchmark.py

Lines changed: 47 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from datetime import datetime, timezone
21
from unittest.mock import Mock, patch
32

43
import pytest
@@ -9,7 +8,6 @@
98
_benchmark_single_table_compute_gcp,
109
_get_user_data_script,
1110
_gpu_wait_block,
12-
_make_instance_name,
1311
_run_on_gcp,
1412
_terminate_instance,
1513
_upload_logs,
@@ -36,21 +34,6 @@ def base_credentials():
3634
}
3735

3836

39-
@patch('sdgym._benchmark.benchmark.uuid.uuid4')
40-
@patch('sdgym._benchmark.benchmark.datetime')
41-
def test_make_instance_name(mock_datetime, mock_uuid):
42-
"""Test `_make_instance_name` generates a stable, readable name."""
43-
# Setup
44-
mock_datetime.now.return_value = datetime(2025, 1, 15, tzinfo=timezone.utc)
45-
mock_uuid.return_value.hex = 'abcdef123456'
46-
47-
# Run
48-
result = _make_instance_name('sdgym-run')
49-
50-
# Assert
51-
assert result == 'sdgym-run-20250115-abcdef'
52-
53-
5437
def test_terminate_instance_aws():
5538
"""AWS termination script self-terminates via EC2 metadata and AWS CLI."""
5639
# Run
@@ -79,16 +62,16 @@ def test_terminate_instance_gcp():
7962
assert 'Metadata-Flavor: Google' not in script
8063

8164

82-
def test_terminate_instance_invalid_service():
65+
def test__terminate_instance_invalid_service():
8366
"""Invalid compute service raises a clear error."""
8467
# Run and Assert
8568
with pytest.raises(ValueError, match='Unsupported compute service'):
8669
_terminate_instance('azure')
8770

8871

89-
def test_gpu_wait_block_contents():
72+
def test__gpu_wait_block_contents():
9073
"""GPU wait block waits for nvidia-smi to become available."""
91-
# Setup
74+
# Run
9275
block = _gpu_wait_block()
9376

9477
# Assert
@@ -98,7 +81,7 @@ def test_gpu_wait_block_contents():
9881
assert 'for i in' in block or 'while' in block
9982

10083

101-
def test_upload_logs_fn_no_uri():
84+
def test__upload_logs_fn_no_uri():
10285
"""No log URI returns a no-op upload_logs function."""
10386
# Run
10487
fn = _upload_logs('')
@@ -107,7 +90,7 @@ def test_upload_logs_fn_no_uri():
10790
assert fn.strip() == 'upload_logs() { :; }'
10891

10992

110-
def test_upload_logs_fn_with_uri():
93+
def test__upload_logs_fn_with_uri():
11194
"""Upload logs function uploads user-data.log to S3."""
11295
# Setup
11396
uri = 's3://bucket/prefix/logs/instance-user-data.log'
@@ -121,7 +104,7 @@ def test_upload_logs_fn_with_uri():
121104
assert uri in fn
122105

123106

124-
def test_get_user_data_script_gcp_gpu_wait(base_credentials):
107+
def test__get_user_data_script_gcp_gpu_wait(base_credentials):
125108
"""Test GCP user-data script includes GPU wait and delete logic."""
126109
# Setup
127110
config = {
@@ -141,6 +124,7 @@ def test_get_user_data_script_gcp_gpu_wait(base_credentials):
141124
'upload_logs_to_s3': True,
142125
}
143126

127+
# Run
144128
script = _get_user_data_script(
145129
credentials=base_credentials,
146130
script_content="print('hello')",
@@ -159,7 +143,7 @@ def test_get_user_data_script_gcp_gpu_wait(base_credentials):
159143
assert "print('hello')" in script
160144

161145

162-
def test_get_user_data_script_aws_termination(base_credentials):
146+
def test__get_user_data_script_aws_termination(base_credentials):
163147
"""Test AWS user-data script includes EC2 termination logic."""
164148
# Setup
165149
config = {
@@ -173,6 +157,7 @@ def test_get_user_data_script_aws_termination(base_credentials):
173157
'upload_logs_to_s3': True,
174158
}
175159

160+
# Run
176161
script = _get_user_data_script(
177162
credentials=base_credentials,
178163
script_content="print('aws')",
@@ -284,11 +269,44 @@ def test_run_on_gcp(
284269
)
285270
],
286271
)
287-
mock_instances_client.insert.assert_called_once()
288-
mock_compute_v1.ZoneOperationsClient.assert_called_once()
289-
mock_zone_ops_client.wait.assert_called_once()
290-
mock_compute_v1.Metadata.assert_called_once()
291-
mock_compute_v1.Instance.assert_called_once()
272+
mock_instances_client.insert.assert_called_once_with(
273+
project='test-project',
274+
zone='us-central1-a',
275+
instance_resource=mock_compute_v1.Instance.return_value,
276+
)
277+
mock_compute_v1.ZoneOperationsClient.assert_called_once_with(credentials=gcp_cred)
278+
mock_zone_ops_client.wait.assert_called_once_with(
279+
project='test-project',
280+
zone='us-central1-a',
281+
operation=mock_instances_client.insert.return_value.name,
282+
)
283+
mock_compute_v1.Metadata.assert_called_once_with(
284+
items=[
285+
mock_compute_v1.Items(
286+
key='startup-script',
287+
value='STARTUP_SCRIPT',
288+
),
289+
mock_compute_v1.Items(
290+
key='enable-oslogin',
291+
value='TRUE',
292+
),
293+
]
294+
)
295+
mock_compute_v1.Instance.assert_called_once_with(
296+
name='instance-123',
297+
machine_type='zones/us-central1-a/machineTypes/n1-standard-4',
298+
disks=[boot_disk],
299+
network_interfaces=[nic],
300+
metadata=metadata,
301+
guest_accelerators=[gpu],
302+
scheduling=scheduling,
303+
service_accounts=[
304+
mock_compute_v1.ServiceAccount(
305+
email='default',
306+
scopes=['https://www.googleapis.com/auth/cloud-platform'],
307+
)
308+
],
309+
)
292310

293311

294312
@patch('sdgym._benchmark.benchmark._run_on_gcp')

0 commit comments

Comments
 (0)