Skip to content
This repository was archived by the owner on Dec 16, 2025. It is now read-only.

Commit c8a2b04

Browse files
committed
Make get_cluster_resources_work
1 parent e641bb6 commit c8a2b04

File tree

1 file changed

+72
-10
lines changed

1 file changed

+72
-10
lines changed

src/lattice/providers/skypilot.py

Lines changed: 72 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import requests
44
import json
5+
import re
56
from typing import Dict, Any, Optional, Union, List
67

78
# SkyPilot SDK imports - try to import, but allow graceful failure if not available
@@ -610,7 +611,7 @@ def get_cluster_status(self, cluster_name: str) -> ClusterStatus:
610611
last_use=cluster_data.get("last_use"),
611612
autostop=cluster_data.get("autostop"),
612613
num_nodes=cluster_data.get("num_nodes"),
613-
resources_str=cluster_data.get("resources_str"),
614+
resources_str=cluster_data.get("resources_str_full"),
614615
provider_data=cluster_data,
615616
)
616617

@@ -619,20 +620,81 @@ def get_cluster_resources(self, cluster_name: str) -> ResourceInfo:
619620
# SkyPilot doesn't have a dedicated resources endpoint,
620621
# so we get it from status
621622
status = self.get_cluster_status(cluster_name)
622-
resources_str = status.resources_str or ""
623-
624-
# Try to parse resources from resources_str
625-
# This is a simplified parser - may need enhancement
623+
624+
# Use resources_str_full if available, otherwise fall back to resources_str
625+
resources_str = status.provider_data.get("resources_str_full") or status.resources_str or ""
626+
627+
# Also check provider_data for direct resource fields
628+
provider_data = status.provider_data or {}
629+
num_nodes = status.num_nodes or provider_data.get("nodes") or 1
630+
631+
# Parse resources from resources_str_full
632+
# Format: "1x(gpus=RTX3090:1, cpus=4, mem=16, 4CPU--16GB--RTX3090:1, disk=256)"
626633
gpus = []
627-
if "GPU" in resources_str.upper():
628-
# Basic parsing - can be enhanced
629-
gpus.append({"type": "unknown", "count": 1})
634+
cpus = None
635+
memory_gb = None
636+
disk_gb = None
637+
638+
if resources_str:
639+
# Extract num_nodes from prefix (e.g., "1x(...)" or "2x(...)")
640+
node_match = re.match(r"(\d+)x\(", resources_str)
641+
if node_match:
642+
num_nodes = int(node_match.group(1))
643+
644+
# Extract GPUs: gpus=RTX3090:1 or gpus=V100:2
645+
gpu_match = re.search(r"gpus=([\w\d]+):(\d+)", resources_str)
646+
if gpu_match:
647+
gpu_type = gpu_match.group(1)
648+
gpu_count = int(gpu_match.group(2))
649+
gpus.append({"gpu": gpu_type, "count": gpu_count})
650+
651+
# Extract CPUs: cpus=4
652+
cpu_match = re.search(r"cpus=([\d.]+)", resources_str)
653+
if cpu_match:
654+
cpus = int(float(cpu_match.group(1)))
655+
656+
# Extract Memory: mem=16 (in GB)
657+
mem_match = re.search(r"mem=([\d.]+)", resources_str)
658+
if mem_match:
659+
memory_gb = float(mem_match.group(1))
660+
661+
# Extract Disk: disk=256 (in GB)
662+
disk_match = re.search(r"disk=([\d.]+)", resources_str)
663+
if disk_match:
664+
disk_gb = int(float(disk_match.group(1)))
665+
666+
# Also try to get from provider_data directly if available
667+
if not cpus and provider_data.get("cpus"):
668+
try:
669+
cpus = int(float(provider_data["cpus"]))
670+
except (ValueError, TypeError):
671+
pass
672+
673+
if not gpus and provider_data.get("accelerators"):
674+
try:
675+
# accelerators might be a string like "{'RTX3090': 1}" or a dict
676+
accel_str = provider_data["accelerators"]
677+
if isinstance(accel_str, str):
678+
# Try to parse string representation
679+
import ast
680+
accel_dict = ast.literal_eval(accel_str)
681+
else:
682+
accel_dict = accel_str
683+
684+
if isinstance(accel_dict, dict):
685+
for gpu_type, count in accel_dict.items():
686+
gpus.append({"gpu": gpu_type, "count": int(count)})
687+
except (ValueError, TypeError, SyntaxError):
688+
pass
630689

631690
return ResourceInfo(
632691
cluster_name=cluster_name,
633692
gpus=gpus,
634-
num_nodes=status.num_nodes,
635-
provider_data={"resources_str": resources_str},
693+
cpus=cpus,
694+
memory_gb=memory_gb,
695+
disk_gb=disk_gb,
696+
num_nodes=num_nodes,
697+
provider_data={"resources_str": resources_str, **provider_data},
636698
)
637699

638700
def submit_job(self, cluster_name: str, job_config: JobConfig) -> Dict[str, Any]:

0 commit comments

Comments
 (0)