Skip to content

Commit 50f5ecb

Browse files
committed
add unit tests
1 parent 9f61995 commit 50f5ecb

File tree

4 files changed

+703
-70
lines changed

4 files changed

+703
-70
lines changed

sdgym/_benchmark/benchmark.py

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import textwrap
2-
import time
32
import uuid
3+
from datetime import datetime, timezone
44
from urllib.parse import urlparse
55

66
from google.cloud import compute_v1
@@ -25,9 +25,9 @@
2525

2626

2727
def _make_instance_name(prefix):
28-
timestamp = int(time.time())
28+
day = datetime.now(timezone.utc).strftime('%Y%m%d')
2929
suffix = uuid.uuid4().hex[:6]
30-
return f'{prefix}-{timestamp}-{suffix}'
30+
return f'{prefix}-{day}-{suffix}'
3131

3232

3333
def _logs_s3_uri(output_destination, instance_name):
@@ -357,16 +357,8 @@ def _get_user_data_script(
357357

358358

359359
def _run_on_gcp(
360-
output_destination,
361-
synthesizers,
362-
s3_client,
363-
job_args_list,
364-
credentials,
365-
config_overrides=None,
360+
output_destination, synthesizers, s3_client, job_args_list, credentials, compute_config
366361
):
367-
config = resolve_compute_config('gcp', config_overrides)
368-
validate_compute_config(credentials, config)
369-
370362
script_content = _prepare_script_content(
371363
output_destination,
372364
synthesizers,
@@ -381,33 +373,31 @@ def _run_on_gcp(
381373
credentials['gcp'],
382374
)
383375

384-
instance_name = _make_instance_name(config['name_prefix'])
376+
instance_name = _make_instance_name(compute_config['name_prefix'])
385377
print( # noqa: T201
386378
f'Launching instance: {instance_name} (service=gcp project={gcp_project} zone={gcp_zone})'
387379
)
388-
389380
startup_script = _get_user_data_script(
390381
credentials,
391382
script_content,
392-
config,
383+
compute_config,
393384
instance_name,
394385
output_destination,
395386
)
396387

397-
machine_type = f'zones/{gcp_zone}/machineTypes/{config["machine_type"]}'
398-
source_disk_image = config['source_image']
399-
388+
machine_type = f'zones/{gcp_zone}/machineTypes/{compute_config["machine_type"]}'
389+
source_disk_image = compute_config['source_image']
400390
gpu = compute_v1.AcceleratorConfig(
401-
accelerator_type=(f'zones/{gcp_zone}/acceleratorTypes/{config["gpu_type"]}'),
402-
accelerator_count=int(config['gpu_count']),
391+
accelerator_type=(f'zones/{gcp_zone}/acceleratorTypes/{compute_config["gpu_type"]}'),
392+
accelerator_count=int(compute_config['gpu_count']),
403393
)
404394

405395
boot_disk = compute_v1.AttachedDisk(
406396
auto_delete=True,
407397
boot=True,
408398
initialize_params=compute_v1.AttachedDiskInitializeParams(
409399
source_image=source_disk_image,
410-
disk_size_gb=int(config['disk_size_gb']),
400+
disk_size_gb=int(compute_config['disk_size_gb']),
411401
),
412402
)
413403

@@ -421,7 +411,7 @@ def _run_on_gcp(
421411
]
422412

423413
items = [compute_v1.Items(key='startup-script', value=startup_script)]
424-
if config.get('install_nvidia_driver', True):
414+
if compute_config.get('install_nvidia_driver', True):
425415
items.append(
426416
compute_v1.Items(key='install-nvidia-driver', value='true'),
427417
)
@@ -489,6 +479,7 @@ def _benchmark_compute_gcp(
489479
"""Run the SDGym benchmark on datasets for the given modality."""
490480
compute_config = resolve_compute_config('gcp', compute_config)
491481
credentials = get_credentials(credential_filepath)
482+
validate_compute_config(compute_config)
492483

493484
s3_client = _validate_output_destination(
494485
output_destination,
@@ -537,9 +528,8 @@ def _benchmark_compute_gcp(
537528
s3_client=s3_client,
538529
job_args_list=job_args_list,
539530
credentials=credentials,
540-
config_overrides=compute_config,
531+
compute_config=compute_config,
541532
)
542-
return None
543533

544534

545535
def _benchmark_single_table_compute_gcp(

sdgym/_benchmark/config_utils.py

Lines changed: 73 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
DEFAULT_COMPUTE_CONFIG = {
22
'common': {
3+
'name_prefix': 'sdgym-run',
4+
'root_disk_gb': 100,
5+
'compute_type': None,
6+
'boot_image': None,
7+
'gpu_type': None,
8+
'gpu_count': 0,
39
'swap_gb': 32,
4-
'disk_size_gb': 100,
510
'sdgym_install': (
611
'sdgym[all] @ git+https://github.com/sdv-dev/SDGym.git@gcp-benchmark-romain'
712
),
@@ -12,9 +17,8 @@
1217
'upload_logs_to_s3': True,
1318
},
1419
'gcp': {
15-
'name_prefix': 'sdgym-run',
16-
'machine_type': 'n1-standard-8',
17-
'source_image': (
20+
'compute_type': 'n1-standard-8',
21+
'boot_image': (
1822
'projects/deeplearning-platform-release/global/images/family/'
1923
'common-cu128-ubuntu-2204-nvidia-570'
2024
),
@@ -26,13 +30,31 @@
2630
'stop_fallback': True,
2731
},
2832
'aws': {
29-
'name_prefix': 'sdgym-run',
30-
'ami': 'ami-080e1f13689e07408',
31-
'instance_type': 'g4dn.4xlarge',
32-
'volume_size_gb': 100,
33+
'compute_type': 'g4dn.4xlarge',
34+
'boot_image': 'ami-080e1f13689e07408',
3335
},
3436
}
3537

38+
_KEYMAP_COMPUTE_SERVICE = {
39+
'root_disk_gb': {
40+
'aws': 'volume_size_gb',
41+
'gcp': 'disk_size_gb',
42+
},
43+
'compute_type': {
44+
'aws': 'instance_type',
45+
'gcp': 'machine_type',
46+
},
47+
'boot_image': {
48+
'aws': 'ami',
49+
'gcp': 'source_image',
50+
},
51+
}
52+
_REQUIRED_CANONICAL_KEYS = (
53+
'compute_type',
54+
'boot_image',
55+
'root_disk_gb',
56+
)
57+
3658

3759
def _merge_dict(base, config):
3860
out = dict(base)
@@ -45,37 +67,51 @@ def _merge_dict(base, config):
4567
return out
4668

4769

70+
def _apply_compute_service_keymap(config):
    """Mirror canonical settings onto the keys used by the configured provider.

    For every canonical key listed in ``_KEYMAP_COMPUTE_SERVICE`` that is
    present in ``config``, the value is also stored under the name mapped for
    ``config['service']``. Canonical keys are kept, and an existing value under
    the provider-specific name is overwritten.

    Args:
        config (dict):
            Compute config. It must contain a ``'service'`` entry.

    Returns:
        dict:
            A shallow copy of ``config`` with the provider-specific keys added.

    Raises:
        KeyError:
            If ``config`` has no ``'service'`` entry.
    """
    service_name = config['service']
    expanded = dict(config)
    for canonical, provider_names in _KEYMAP_COMPUTE_SERVICE.items():
        target = provider_names.get(service_name)
        # Skip canonical keys that are absent or that have no name for this service.
        if canonical in expanded and target:
            expanded[target] = expanded[canonical]

    return expanded
83+
84+
4885
def resolve_compute_config(compute_service, config=None):
4986
if compute_service not in ('aws', 'gcp'):
5087
raise ValueError("compute_service must be 'aws' or 'gcp'")
5188

52-
base = _merge_dict(DEFAULT_COMPUTE_CONFIG['common'], DEFAULT_COMPUTE_CONFIG[compute_service])
89+
base = _merge_dict(
90+
DEFAULT_COMPUTE_CONFIG['common'],
91+
DEFAULT_COMPUTE_CONFIG[compute_service],
92+
)
5393
base['service'] = compute_service
54-
return _merge_dict(base, config)
55-
56-
57-
def validate_compute_config(credentials, config):
58-
# Always needed because results/logs go to S3
59-
aws = credentials.get('aws') or {}
60-
if not aws.get('aws_access_key_id') or not aws.get('aws_secret_access_key'):
61-
raise ValueError("Missing AWS credentials in credentials['aws']")
62-
63-
svc = config['service']
64-
if svc == 'gcp':
65-
gcp = credentials.get('gcp') or {}
66-
if not gcp.get('gcp_project') or not gcp.get('gcp_zone'):
67-
raise ValueError(
68-
"Missing GCP fields: credentials['gcp']['gcp_project'] and ['gcp_zone']"
69-
)
70-
for k in ('machine_type', 'source_image', 'disk_size_gb'):
71-
if not config.get(k):
72-
raise ValueError(f'Missing required GCP config field: {k}')
73-
74-
# If you expect GPU, require gpu_type/count
75-
if config.get('gpu_count', 0):
76-
if not config.get('gpu_type'):
77-
raise ValueError('Missing required GCP config field: gpu_type (GPU requested)')
78-
elif svc == 'aws':
79-
for k in ('ami', 'instance_type', 'volume_size_gb'):
80-
if not config.get(k):
81-
raise ValueError(f'Missing required AWS config field: {k}')
94+
merged = _merge_dict(base, config)
95+
resolved = _apply_compute_service_keymap(merged)
96+
97+
return resolved
98+
99+
100+
def validate_compute_config(config):
    """Check that a compute config names a known service and has its required fields.

    Args:
        config (dict):
            Compute config. It must contain ``'service'`` (``'gcp'`` or ``'aws'``)
            and a truthy value for every key in ``_REQUIRED_CANONICAL_KEYS``.
            When ``'gpu_count'`` is greater than zero, ``'gpu_type'`` is required too.

    Raises:
        ValueError:
            If the service is unknown, or if any required field is missing or
            falsy. The message lists every missing field.
    """
    service = config.get('service')
    if service not in ('gcp', 'aws'):
        raise ValueError(
            f'Invalid compute config: unknown service={service!r}. Expected one of: gcp, aws'
        )

    missing = [key for key in _REQUIRED_CANONICAL_KEYS if not config.get(key)]

    # ``or 0`` treats an explicit ``None`` gpu_count the same as "no GPU requested".
    gpu_count = int(config.get('gpu_count') or 0)
    if gpu_count > 0 and not config.get('gpu_type'):
        missing.append('gpu_type')

    if missing:
        # The f-string below supplies the outer quotes, so "', '" renders
        # the list as 'a', 'b', 'c'.
        missing_fields = "', '".join(missing)
        raise ValueError(
            f'Invalid compute config for service={service!r}. '
            f"Missing required field(s): '{missing_fields}'."
        )

0 commit comments

Comments
 (0)