Skip to content
This repository was archived by the owner on Dec 16, 2025. It is now read-only.

Commit b5f8440

Browse files
committed
revamp for bridge - add structure and get sky launch and down working
1 parent 5a69ee6 commit b5f8440

File tree

11 files changed

+1957
-7058
lines changed

11 files changed

+1957
-7058
lines changed

pyproject.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@ dependencies = [
1111
"fastapi==0.115.7",
1212
"pydantic>=2.10.4",
1313
"pyjwt==2.10.1",
14-
"python-dotenv==1.0.0",
14+
"python-dotenv>=1.0.0",
1515
"python-multipart>=0.0.18",
1616
"pyyaml>=6.0",
17+
"requests>=2.31.0",
18+
"paramiko>=3.0.0",
1719
"runpod>=1.6.0",
1820
"sqlalchemy>=2.0.42",
1921
"uvicorn[standard]==0.35.0",
@@ -50,7 +52,8 @@ dependencies = [
5052
"msrestazure==0.6.4.post1",
5153
"openapi-client>=1.1.7",
5254
"nanoid==2.0.0",
53-
"skypilot[all] @ git+https://github.com/transformerlab/skypilot.git@prod-2",
55+
# "skypilot[all] @ git+https://github.com/transformerlab/skypilot.git@prod-2",
56+
"skypilot[all]==0.10.5",
5457
"fsspec>=2025.9.0",
5558
"adlfs>=2022.11.1",
5659
"s3fs>=2025.9.0",

src/lattice/providers/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
"""Provider bridge system for abstracting GPU orchestration providers."""
2+
3+
from .base import Provider
4+
from .router import ProviderRouter, get_provider
5+
from .config import load_providers_config, ProviderConfig
6+
7+
# Names re-exported as the public surface of the providers package; this is
# also what ``from lattice.providers import *`` brings into scope.
__all__ = [
    "Provider",
    "ProviderRouter",
    "get_provider",
    "load_providers_config",
    "ProviderConfig",
]
14+

src/lattice/providers/base.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
"""Abstract base class for provider implementations."""
2+
3+
from abc import ABC, abstractmethod
4+
from typing import Dict, List, Any, Optional, Union
5+
from .models import (
6+
ClusterConfig,
7+
JobConfig,
8+
ClusterStatus,
9+
JobInfo,
10+
ResourceInfo,
11+
)
12+
13+
14+
class Provider(ABC):
    """Interface that every orchestration provider backend must implement.

    All methods are abstract: a concrete provider has to override each of
    them before it can be instantiated.
    """

    @abstractmethod
    def launch_cluster(
        self, cluster_name: str, config: ClusterConfig
    ) -> Dict[str, Any]:
        """Launch (provision) a new cluster.

        Args:
            cluster_name: Name of the cluster to launch.
            config: Cluster configuration.

        Returns:
            Dictionary describing the launch result (e.g. request_id,
            cluster_name).
        """

    @abstractmethod
    def stop_cluster(self, cluster_name: str) -> Dict[str, Any]:
        """Stop a running cluster without tearing it down.

        Args:
            cluster_name: Name of the cluster to stop.

        Returns:
            Dictionary describing the stop result.
        """

    @abstractmethod
    def get_cluster_status(
        self, cluster_name: str
    ) -> ClusterStatus:
        """Report the current status of a cluster.

        Args:
            cluster_name: Name of the cluster.

        Returns:
            ClusterStatus object carrying the cluster information.
        """

    @abstractmethod
    def get_cluster_resources(
        self, cluster_name: str
    ) -> ResourceInfo:
        """Report the resources of a cluster (GPUs, CPUs, memory, etc.).

        Args:
            cluster_name: Name of the cluster.

        Returns:
            ResourceInfo object carrying the resource details.
        """

    @abstractmethod
    def submit_job(
        self, cluster_name: str, job_config: JobConfig
    ) -> Dict[str, Any]:
        """Submit a job to an already existing cluster.

        Args:
            cluster_name: Name of the cluster.
            job_config: Job configuration.

        Returns:
            Dictionary describing the submission result (e.g. job_id).
        """

    @abstractmethod
    def get_job_logs(
        self,
        cluster_name: str,
        job_id: Union[str, int],
        tail_lines: Optional[int] = None,
        follow: bool = False,
    ) -> Union[str, Any]:
        """Fetch the logs of a job.

        Args:
            cluster_name: Name of the cluster.
            job_id: Job identifier.
            tail_lines: How many lines to take from the end of the log;
                None means all lines.
            follow: Whether to stream/follow the logs instead of returning
                them in one piece.

        Returns:
            The log content as a string, or a stream object when
            ``follow`` is True.
        """

    @abstractmethod
    def cancel_job(
        self, cluster_name: str, job_id: Union[str, int]
    ) -> Dict[str, Any]:
        """Cancel a job that is running or still queued.

        Args:
            cluster_name: Name of the cluster.
            job_id: Job identifier.

        Returns:
            Dictionary describing the cancellation result.
        """

    @abstractmethod
    def list_jobs(self, cluster_name: str) -> List[JobInfo]:
        """List every job of a cluster.

        Args:
            cluster_name: Name of the cluster.

        Returns:
            List of JobInfo objects.
        """
142+

src/lattice/providers/config.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
"""Configuration loading and provider factory."""
2+
3+
import os
4+
import yaml
5+
import json
6+
from typing import Dict, Any, Optional
7+
from pathlib import Path
8+
from pydantic import BaseModel, Field
9+
10+
11+
class ProviderConfig(BaseModel):
    """Settings describing one configured provider.

    A single flat model holds the options of every supported provider type;
    options that do not apply to the selected ``type`` keep their defaults.
    """

    # --- Common ---------------------------------------------------------
    type: str  # Provider kind: "skypilot" or "slurm"
    name: str  # Identifier of this provider entry

    # --- SkyPilot -------------------------------------------------------
    server_url: Optional[str] = None
    api_token: Optional[str] = None
    default_env_vars: Dict[str, str] = Field(default_factory=dict)
    default_entrypoint_command: Optional[str] = None

    # --- SLURM ----------------------------------------------------------
    mode: Optional[str] = None  # Connection mode: "rest" or "ssh"
    rest_url: Optional[str] = None
    ssh_host: Optional[str] = None
    ssh_user: Optional[str] = None
    ssh_key_path: Optional[str] = None
    ssh_port: int = 22

    # --- Anything else provider-specific ---------------------------------
    extra_config: Dict[str, Any] = Field(default_factory=dict)
33+
34+
35+
def _default_providers_config_path() -> Path:
    """Locate ``providers.yaml`` when neither an argument nor env var names it."""
    module_file = Path(__file__).resolve()
    package_config = module_file.parent / "providers.yaml"

    # From a checkout this module lives at <repo>/src/lattice/providers/, so
    # the repo root is four levels up; one directory above that is tried too.
    # A copy found there wins, which favours the development tree.
    repo_root = module_file.parent.parent.parent.parent
    for root in (repo_root, repo_root.parent):
        candidate = root / "src" / "lattice" / "providers" / "providers.yaml"
        if candidate.exists():
            return candidate

    # Fall back to the file next to this module even if it is absent; the
    # caller reports a missing file with a helpful message.
    return package_config


def load_providers_config(
    config_path: Optional[str] = None,
) -> Dict[str, "ProviderConfig"]:
    """
    Load provider configurations from a YAML or JSON file.

    Args:
        config_path: Path to the config file. If None, the
            PROVIDERS_CONFIG_PATH env var is used when set, otherwise the
            default ``providers.yaml`` location is searched.

    Returns:
        Dictionary mapping provider names to ProviderConfig objects. An
        empty document, or one without a ``providers`` section, yields an
        empty dictionary.

    Raises:
        FileNotFoundError: If the resolved config file does not exist.
        ValueError: If the file extension is not .yaml/.yml/.json, or the
            document (or its ``providers`` section) is not a mapping.
    """
    if config_path is None:
        config_path = os.getenv("PROVIDERS_CONFIG_PATH") or str(
            _default_providers_config_path()
        )

    path = Path(config_path).expanduser().resolve()

    if not path.exists():
        # Point the user at config files that do exist in common locations.
        suggested_paths = [
            os.path.join(os.path.dirname(__file__), "providers.yaml"),
            os.path.join(os.getcwd(), "providers.yaml"),
        ]
        suggestions = "\n".join(
            f"  - {p}" for p in suggested_paths if Path(p).exists()
        )
        error_msg = (
            f"Provider config file not found: {path}\n"
            "Please create a providers.yaml file or set PROVIDERS_CONFIG_PATH environment variable.\n"
        )
        if suggestions:
            error_msg += f"Found config files at:\n{suggestions}\n"
        raise FileNotFoundError(error_msg)

    # Compare case-insensitively so e.g. "providers.YAML" is accepted.
    suffix = path.suffix.lower()
    if suffix not in (".yaml", ".yml", ".json"):
        raise ValueError(f"Unsupported config file format: {path.suffix}")

    with open(path, "r", encoding="utf-8") as f:
        if suffix == ".json":
            config_data = json.load(f)
        else:
            config_data = yaml.safe_load(f)

    # An empty YAML document (or JSON ``null``) parses to None.
    if config_data is None:
        return {}
    if not isinstance(config_data, dict):
        raise ValueError(
            f"Provider config must be a mapping at the top level: {path}"
        )

    # ``providers:`` with no value parses to None; treat it as "no providers".
    providers_data = config_data.get("providers") or {}
    if not isinstance(providers_data, dict):
        raise ValueError(
            f"'providers' must be a mapping of provider name to settings: {path}"
        )

    # The mapping key is authoritative for the provider name. Build a new
    # dict per entry so the parsed document is not mutated and an entry with
    # an empty body (None) is handled.
    return {
        name: ProviderConfig(**{**(provider_data or {}), "name": name})
        for name, provider_data in providers_data.items()
    }
115+
116+
117+
def create_provider(config: "ProviderConfig"):
    """
    Factory function to create a provider instance from config.

    The configuration is validated first and the matching provider module is
    imported only afterwards, so a configuration error is reported as a
    ValueError and only the module of the selected provider type has to be
    importable.

    Args:
        config: ProviderConfig object

    Returns:
        Provider instance (SkyPilotProvider or SLURMProvider)

    Raises:
        ValueError: If the provider type or SLURM mode is unknown, or a
            setting required by the selected type/mode is missing.
    """
    if config.type == "skypilot":
        if not config.server_url:
            raise ValueError("SkyPilot provider requires server_url in config")

        from .skypilot import SkyPilotProvider

        return SkyPilotProvider(
            server_url=config.server_url,
            api_token=config.api_token,
            default_env_vars=config.default_env_vars,
            default_entrypoint_command=config.default_entrypoint_command,
            extra_config=config.extra_config,
        )

    if config.type == "slurm":
        if config.mode not in ("rest", "ssh"):
            raise ValueError(
                f"SLURM provider mode must be 'rest' or 'ssh', got: {config.mode}"
            )
        if config.mode == "rest" and not config.rest_url:
            raise ValueError(
                "SLURM provider in REST mode requires rest_url in config"
            )
        if config.mode == "ssh" and not config.ssh_host:
            raise ValueError(
                "SLURM provider in SSH mode requires ssh_host in config"
            )

        from .slurm import SLURMProvider

        if config.mode == "rest":
            return SLURMProvider(
                mode="rest",
                rest_url=config.rest_url,
                api_token=config.api_token,
                extra_config=config.extra_config,
            )
        return SLURMProvider(
            mode="ssh",
            ssh_host=config.ssh_host,
            # Fall back to the current OS user, then "root", when unset.
            ssh_user=config.ssh_user or os.getenv("USER", "root"),
            ssh_key_path=config.ssh_key_path,
            ssh_port=config.ssh_port,
            extra_config=config.extra_config,
        )

    raise ValueError(f"Unknown provider type: {config.type}")

0 commit comments

Comments
 (0)