Skip to content

Commit 6a8f451

Browse files
committed
fix gpu validation
1 parent fd48c41 commit 6a8f451

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

sdgym/_benchmark/benchmark.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,11 @@ def _get_user_data_script(
159159
):
160160
compute_service = config['service']
161161
swap_gb = int(config.get('swap_gb', 32))
162-
gpu = bool(config.get('gpu', False))
162+
gpu = (
163+
bool(config.get('gpu'))
164+
or int(config.get('gpu_count', 0)) > 0
165+
or bool(config.get('gpu_type'))
166+
)
163167
upload_logs = bool(config.get('upload_logs', True))
164168

165169
aws_key = credentials['aws']['aws_access_key_id']

0 commit comments

Comments
 (0)