11DEFAULT_COMPUTE_CONFIG = {
22 'common' : {
3+ 'name_prefix' : 'sdgym-run' ,
4+ 'root_disk_gb' : 100 ,
5+ 'compute_type' : None ,
6+ 'boot_image' : None ,
7+ 'gpu_type' : None ,
8+ 'gpu_count' : 0 ,
39 'swap_gb' : 32 ,
4- 'disk_size_gb' : 100 ,
510 'sdgym_install' : (
611 'sdgym[all] @ git+https://github.com/sdv-dev/SDGym.git@gcp-benchmark-romain'
712 ),
1217 'upload_logs_to_s3' : True ,
1318 },
1419 'gcp' : {
15- 'name_prefix' : 'sdgym-run' ,
16- 'machine_type' : 'n1-standard-8' ,
17- 'source_image' : (
20+ 'compute_type' : 'n1-standard-8' ,
21+ 'boot_image' : (
1822 'projects/deeplearning-platform-release/global/images/family/'
1923 'common-cu128-ubuntu-2204-nvidia-570'
2024 ),
2630 'stop_fallback' : True ,
2731 },
2832 'aws' : {
29- 'name_prefix' : 'sdgym-run' ,
30- 'ami' : 'ami-080e1f13689e07408' ,
31- 'instance_type' : 'g4dn.4xlarge' ,
32- 'volume_size_gb' : 100 ,
33+ 'compute_type' : 'g4dn.4xlarge' ,
34+ 'boot_image' : 'ami-080e1f13689e07408' ,
3335 },
3436}
3537
38+ _KEYMAP_COMPUTE_SERVICE = {
39+ 'root_disk_gb' : {
40+ 'aws' : 'volume_size_gb' ,
41+ 'gcp' : 'disk_size_gb' ,
42+ },
43+ 'compute_type' : {
44+ 'aws' : 'instance_type' ,
45+ 'gcp' : 'machine_type' ,
46+ },
47+ 'boot_image' : {
48+ 'aws' : 'ami' ,
49+ 'gcp' : 'source_image' ,
50+ },
51+ }
52+ _REQUIRED_CANONICAL_KEYS = (
53+ 'compute_type' ,
54+ 'boot_image' ,
55+ 'root_disk_gb' ,
56+ )
57+
3658
3759def _merge_dict (base , config ):
3860 out = dict (base )
@@ -45,37 +67,51 @@ def _merge_dict(base, config):
4567 return out
4668
4769
70+ def _apply_compute_service_keymap (config ):
71+ """Expand canonical keys into provider-specific keys."""
72+ compute_service = config ['service' ]
73+ out = dict (config )
74+ for canonical_key , per_service in _KEYMAP_COMPUTE_SERVICE .items ():
75+ if canonical_key not in out :
76+ continue
77+
78+ provider_key = per_service .get (compute_service )
79+ if provider_key :
80+ out [provider_key ] = out [canonical_key ]
81+
82+ return out
83+
84+
4885def resolve_compute_config (compute_service , config = None ):
4986 if compute_service not in ('aws' , 'gcp' ):
5087 raise ValueError ("compute_service must be 'aws' or 'gcp'" )
5188
52- base = _merge_dict (DEFAULT_COMPUTE_CONFIG ['common' ], DEFAULT_COMPUTE_CONFIG [compute_service ])
89+ base = _merge_dict (
90+ DEFAULT_COMPUTE_CONFIG ['common' ],
91+ DEFAULT_COMPUTE_CONFIG [compute_service ],
92+ )
5393 base ['service' ] = compute_service
54- return _merge_dict (base , config )
55-
56-
57- def validate_compute_config (credentials , config ):
58- # Always needed because results/logs go to S3
59- aws = credentials .get ('aws' ) or {}
60- if not aws .get ('aws_access_key_id' ) or not aws .get ('aws_secret_access_key' ):
61- raise ValueError ("Missing AWS credentials in credentials['aws']" )
62-
63- svc = config ['service' ]
64- if svc == 'gcp' :
65- gcp = credentials .get ('gcp' ) or {}
66- if not gcp .get ('gcp_project' ) or not gcp .get ('gcp_zone' ):
67- raise ValueError (
68- "Missing GCP fields: credentials['gcp']['gcp_project'] and ['gcp_zone']"
69- )
70- for k in ('machine_type' , 'source_image' , 'disk_size_gb' ):
71- if not config .get (k ):
72- raise ValueError (f'Missing required GCP config field: { k } ' )
73-
74- # If you expect GPU, require gpu_type/count
75- if config .get ('gpu_count' , 0 ):
76- if not config .get ('gpu_type' ):
77- raise ValueError ('Missing required GCP config field: gpu_type (GPU requested)' )
78- elif svc == 'aws' :
79- for k in ('ami' , 'instance_type' , 'volume_size_gb' ):
80- if not config .get (k ):
81- raise ValueError (f'Missing required AWS config field: { k } ' )
94+ merged = _merge_dict (base , config )
95+ resolved = _apply_compute_service_keymap (merged )
96+
97+ return resolved
98+
99+
100+ def validate_compute_config (config ):
101+ service = config .get ('service' )
102+ if service not in ('gcp' , 'aws' ):
103+ raise ValueError (
104+ f'Invalid compute config: unknown service={ service !r} . Expected one of: gcp, aws'
105+ )
106+
107+ missing = [key for key in _REQUIRED_CANONICAL_KEYS if not config .get (key )]
108+ gpu_count = int (config .get ('gpu_count' ) or 0 )
109+ if gpu_count > 0 and not config .get ('gpu_type' ):
110+ missing .append ('gpu_type' )
111+
112+ if missing :
113+ missing = "' ,'" .join (missing )
114+ raise ValueError (
115+ f'Invalid compute config for service={ service !r} . '
116+ f"Missing required field(s): '{ missing } '."
117+ )
0 commit comments