Skip to content
This repository was archived by the owner on Dec 16, 2025. It is now read-only.

Commit dd06f46

Browse files
committed
Get skypilot status working
1 parent b5f8440 commit dd06f46

File tree

2 files changed

+107
-46
lines changed

2 files changed

+107
-46
lines changed

src/lattice/providers/example.py

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -44,37 +44,37 @@ def example_skypilot():
4444
status = provider.get_cluster_status("my-cluster")
4545
print(f" Status: {status.state}, Message: {status.status_message}\n")
4646

47-
# Get cluster resources
48-
print("3. Getting cluster resources...")
49-
resources = provider.get_cluster_resources("my-cluster")
50-
print(f" Resources: {resources.num_nodes} nodes, GPUs: {resources.gpus}\n")
51-
52-
# Submit a job
53-
print("4. Submitting job...")
54-
job_config = JobConfig(
55-
command="python train.py",
56-
job_name="training-job",
57-
env_vars={"CUDA_VISIBLE_DEVICES": "0"},
58-
)
59-
job_result = provider.submit_job("my-cluster", job_config)
60-
print(f" Job ID: {job_result.get('job_id')}\n")
61-
62-
# List jobs
63-
print("5. Listing jobs...")
64-
jobs = provider.list_jobs("my-cluster")
65-
for job in jobs:
66-
print(f" Job {job.job_id}: {job.state} - {job.job_name}")
67-
68-
# Get job logs
69-
if jobs:
70-
print(f"\n6. Getting logs for job {jobs[0].job_id}...")
71-
logs = provider.get_job_logs("my-cluster", jobs[0].job_id, tail_lines=50)
72-
print(f" Logs (first 200 chars): {str(logs)[:200]}...\n")
73-
74-
# Stop cluster
75-
print("7. Stopping cluster...")
76-
stop_result = provider.stop_cluster("my-cluster")
77-
print(f" Result: {stop_result}\n")
47+
# # Get cluster resources
48+
# print("3. Getting cluster resources...")
49+
# resources = provider.get_cluster_resources("my-cluster")
50+
# print(f" Resources: {resources.num_nodes} nodes, GPUs: {resources.gpus}\n")
51+
52+
# # Submit a job
53+
# print("4. Submitting job...")
54+
# job_config = JobConfig(
55+
# command="python train.py",
56+
# job_name="training-job",
57+
# env_vars={"CUDA_VISIBLE_DEVICES": "0"},
58+
# )
59+
# job_result = provider.submit_job("my-cluster", job_config)
60+
# print(f" Job ID: {job_result.get('job_id')}\n")
61+
62+
# # List jobs
63+
# print("5. Listing jobs...")
64+
# jobs = provider.list_jobs("my-cluster")
65+
# for job in jobs:
66+
# print(f" Job {job.job_id}: {job.state} - {job.job_name}")
67+
68+
# # Get job logs
69+
# if jobs:
70+
# print(f"\n6. Getting logs for job {jobs[0].job_id}...")
71+
# logs = provider.get_job_logs("my-cluster", jobs[0].job_id, tail_lines=50)
72+
# print(f" Logs (first 200 chars): {str(logs)[:200]}...\n")
73+
74+
# # Stop cluster
75+
# print("7. Stopping cluster...")
76+
# stop_result = provider.stop_cluster("my-cluster")
77+
# print(f" Result: {stop_result}\n")
7878

7979

8080
def example_slurm():

src/lattice/providers/skypilot.py

Lines changed: 76 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -349,45 +349,106 @@ def get_cluster_status(self, cluster_name: str) -> ClusterStatus:
349349
if self.default_entrypoint_command:
350350
body_json.setdefault("entrypoint_command", self.default_entrypoint_command)
351351
body_json.setdefault("using_remote_api_server", False)
352-
body_json.setdefault("override_skypilot_config", {})
353-
352+
body_json.setdefault("override_skypilot_config", {})
354353
# Use SkyPilot's make_authenticated_request (matches SDK exactly)
355354
response = self._make_authenticated_request(
356-
'POST', '/status', json_data=body_json, timeout=5
355+
'POST', '/status', json_data=body_json, timeout=10
357356
)
358357

358+
# Check response status
359+
if hasattr(response, 'status_code'):
360+
if response.status_code != 200:
361+
return ClusterStatus(
362+
cluster_name=cluster_name,
363+
state=ClusterState.UNKNOWN,
364+
status_message=f"API returned status code {response.status_code}",
365+
)
366+
367+
# Parse response content
368+
response_content = None
369+
is_null_response = False
370+
if hasattr(response, 'content'):
371+
if response.content == b'null' or response.content == b'':
372+
is_null_response = True
373+
else:
374+
try:
375+
response_content = response.json() if hasattr(response, 'json') else None
376+
except Exception as e:
377+
print(f"Warning: Could not parse response as JSON: {e}")
378+
response_content = None
379+
359380
# Get request ID using SkyPilot's method (matches SDK exactly)
360381
request_id = None
361382
if self._server_common:
362383
try:
363384
request_id = self._server_common.get_request_id(response)
364-
except Exception:
365-
pass
385+
# For debugging: return early after getting request_id
386+
print(f"Got request_id: {request_id}")
387+
return ClusterStatus(
388+
cluster_name=cluster_name,
389+
state=ClusterState.UNKNOWN,
390+
status_message=f"Debug: Got request_id {request_id}",
391+
)
392+
except Exception as e:
393+
# If get_request_id fails, try to extract from response directly
394+
if response_content and isinstance(response_content, dict):
395+
request_id = response_content.get("request_id")
396+
# Also check headers
397+
if not request_id and hasattr(response, 'headers'):
398+
request_id = (
399+
response.headers.get('X-Request-ID') or
400+
response.headers.get('Request-ID') or
401+
response.headers.get('X-Request-Id')
402+
)
403+
if not request_id:
404+
print(f"Warning: Could not extract request_id: {e}")
405+
print(f"Response status: {getattr(response, 'status_code', 'unknown')}")
406+
print(f"Response headers: {getattr(response, 'headers', {})}")
407+
print(f"Response content: {getattr(response, 'content', b'')[:200]}")
408+
409+
# If response is null and we don't have a request ID, the cluster likely doesn't exist
410+
if is_null_response and not request_id:
411+
return ClusterStatus(
412+
cluster_name=cluster_name,
413+
state=ClusterState.UNKNOWN,
414+
status_message="API returned null response - cluster may not exist or request format is incorrect",
415+
)
366416

367417
# Get the actual result from the request ID
418+
clusters = []
368419
if request_id and self._server_common:
369420
try:
370421
# Use server_common.get() to get the actual response
371-
clusters = self._server_common.get(request_id)
372-
except Exception:
422+
request_payload = self._server_common.get(request_id)
423+
if request_payload and hasattr(request_payload, 'return_value'):
424+
# return_value is a JSON string of the list of clusters
425+
if request_payload.return_value:
426+
clusters = json.loads(request_payload.return_value)
427+
elif request_payload and isinstance(request_payload, (list, dict)):
428+
# Sometimes the response is directly the clusters list
429+
clusters = request_payload if isinstance(request_payload, list) else [request_payload]
430+
except Exception as e:
431+
print(f"Warning: Could not get clusters from request_id {request_id}: {e}")
373432
# Fallback: try to parse response directly
374433
try:
375-
if hasattr(response, 'json'):
434+
if response_content:
435+
clusters = response_content if isinstance(response_content, list) else response_content.get("clusters", [])
436+
elif hasattr(response, 'json'):
376437
result = response.json()
377438
clusters = result if isinstance(result, list) else result.get("clusters", [])
378-
else:
379-
clusters = []
380-
except Exception:
439+
except Exception as parse_error:
440+
print(f"Warning: Could not parse response: {parse_error}")
381441
clusters = []
382442
else:
383443
# Fallback: try to parse response directly
384444
try:
385-
if hasattr(response, 'json'):
445+
if response_content:
446+
clusters = response_content if isinstance(response_content, list) else response_content.get("clusters", [])
447+
elif hasattr(response, 'json'):
386448
result = response.json()
387449
clusters = result if isinstance(result, list) else result.get("clusters", [])
388-
else:
389-
clusters = []
390-
except Exception:
450+
except Exception as e:
451+
print(f"Warning: Could not parse response directly: {e}")
391452
clusters = []
392453

393454
# Handle empty or invalid responses

0 commit comments

Comments
 (0)