Skip to content

Commit af7027e

Browse files
authored
Merge pull request #1551 from transformerlab/claude/skypilot-container-investigation-YdrpB
Add SkyPilot container and cloud settings (global + per-job)
2 parents 0294994 + c6ffc11 commit af7027e

File tree

5 files changed

+321
-89
lines changed

5 files changed

+321
-89
lines changed

api/transformerlab/compute_providers/models.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ class ClusterConfig(BaseModel):
4747
zone: Optional[str] = None
4848
use_spot: bool = False
4949

50+
# Container / VM image override (SkyPilot only).
51+
# Use "docker:<image>" for Docker containers, e.g. "docker:nvcr.io/nvidia/pytorch:23.10-py3".
52+
image_id: Optional[str] = None
53+
5054
# Cluster settings
5155
idle_minutes_to_autostop: Optional[int] = None
5256
run: Optional[str] = None # Initial run command

api/transformerlab/compute_providers/skypilot.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,8 @@ def launch_cluster(self, cluster_name: str, config: ClusterConfig) -> Dict[str,
341341
resources_kwargs["zone"] = config.zone
342342
if config.use_spot:
343343
resources_kwargs["use_spot"] = True
344+
if config.image_id:
345+
resources_kwargs["image_id"] = config.image_id
344346

345347
if resources_kwargs:
346348
task.set_resources(sky_resources.Resources(**resources_kwargs))

api/transformerlab/routers/compute_provider.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1447,6 +1447,26 @@ async def _launch_sweep_jobs(
14471447

14481448
# When file_mounts is True we use lab.copy_file_mounts() in setup; do not send to provider
14491449
file_mounts_for_provider = request.file_mounts if isinstance(request.file_mounts, dict) else {}
1450+
1451+
# Resolve SkyPilot-specific settings from provider config for sweep child jobs
1452+
sweep_image_id: str | None = None
1453+
sweep_region: str | None = None
1454+
sweep_zone: str | None = None
1455+
sweep_use_spot: bool = False
1456+
if provider.type == ProviderType.SKYPILOT.value:
1457+
prov_cfg = provider.config or {}
1458+
sweep_image_id = prov_cfg.get("docker_image") or None
1459+
sweep_region = prov_cfg.get("default_region") or None
1460+
sweep_zone = prov_cfg.get("default_zone") or None
1461+
sweep_use_spot = prov_cfg.get("use_spot", False) is True
1462+
if request.config:
1463+
if request.config.get("docker_image"):
1464+
sweep_image_id = str(request.config["docker_image"]).strip()
1465+
if request.config.get("region"):
1466+
sweep_region = str(request.config["region"]).strip()
1467+
if request.config.get("use_spot"):
1468+
sweep_use_spot = True
1469+
14501470
cluster_config = ClusterConfig(
14511471
cluster_name=formatted_cluster_name,
14521472
provider_name=provider_display_name,
@@ -1461,6 +1481,10 @@ async def _launch_sweep_jobs(
14611481
disk_size=disk_size,
14621482
file_mounts=file_mounts_for_provider,
14631483
provider_config={"requested_disk_space": request.disk_space},
1484+
image_id=sweep_image_id,
1485+
region=sweep_region,
1486+
zone=sweep_zone,
1487+
use_spot=sweep_use_spot,
14641488
)
14651489

14661490
# Launch cluster for child job
@@ -1958,6 +1982,28 @@ async def launch_template_on_provider(
19581982
else:
19591983
parameters_with_secrets = merged_parameters if merged_parameters else None
19601984

1985+
# For SkyPilot providers, resolve docker_image / region / use_spot.
1986+
# Per-job overrides (from request.config) take precedence over provider-level defaults.
1987+
skypilot_image_id: str | None = None
1988+
skypilot_region: str | None = None
1989+
skypilot_zone: str | None = None
1990+
skypilot_use_spot: bool = False
1991+
if provider.type == ProviderType.SKYPILOT.value:
1992+
prov_cfg = provider.config or {}
1993+
# Provider-level defaults
1994+
skypilot_image_id = prov_cfg.get("docker_image") or None
1995+
skypilot_region = prov_cfg.get("default_region") or None
1996+
skypilot_zone = prov_cfg.get("default_zone") or None
1997+
skypilot_use_spot = prov_cfg.get("use_spot", False) is True
1998+
# Per-job overrides from the frontend config dict
1999+
if request.config:
2000+
if request.config.get("docker_image"):
2001+
skypilot_image_id = str(request.config["docker_image"]).strip()
2002+
if request.config.get("region"):
2003+
skypilot_region = str(request.config["region"]).strip()
2004+
if request.config.get("use_spot"):
2005+
skypilot_use_spot = True
2006+
19612007
# Build provider_config for cluster_config (and job_data for local provider)
19622008
provider_config_dict = {"requested_disk_space": request.disk_space}
19632009
# For SLURM, pass through any per-run custom SBATCH flags so the provider
@@ -2094,6 +2140,10 @@ async def launch_template_on_provider(
20942140
disk_size=disk_size,
20952141
file_mounts=file_mounts_for_provider,
20962142
provider_config=provider_config_dict,
2143+
image_id=skypilot_image_id,
2144+
region=skypilot_region,
2145+
zone=skypilot_zone,
2146+
use_spot=skypilot_use_spot,
20972147
)
20982148

20992149
await job_service.job_update_launch_progress(

src/renderer/components/Experiment/Tasks/QueueTaskModal.tsx

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,9 @@ export default function QueueTaskModal({
112112
const [sweepMetric, setSweepMetric] = React.useState('eval/loss');
113113
const [lowerIsBetter, setLowerIsBetter] = React.useState(true);
114114
const [jobSlurmFlags, setJobSlurmFlags] = React.useState<string[]>(['']);
115+
const [jobDockerImage, setJobDockerImage] = React.useState('');
116+
const [jobRegion, setJobRegion] = React.useState('');
117+
const [jobUseSpot, setJobUseSpot] = React.useState(false);
115118
const [useTrackio, setUseTrackio] = React.useState(false);
116119
const [useProfiling, setUseProfiling] = React.useState(false);
117120
const [useProfilingTorch, setUseProfilingTorch] = React.useState(false);
@@ -215,6 +218,7 @@ export default function QueueTaskModal({
215218
);
216219
const isLocalProvider = selectedProvider?.type === 'local';
217220
const isSlurmProvider = selectedProvider?.type === 'slurm';
221+
const isSkypilotProvider = selectedProvider?.type === 'skypilot';
218222

219223
// Fetch user-specific provider settings (including default custom SBATCH flags)
220224
const slurmUserSettingsKey =
@@ -632,6 +636,15 @@ export default function QueueTaskModal({
632636
setJobSlurmFlags(lines.length > 0 ? lines : ['']);
633637
}, [open, isSlurmProvider, selectedProviderId, slurmUserSettings]);
634638

639+
// Initialize SkyPilot per-job defaults from provider config when a SkyPilot provider is selected.
640+
React.useEffect(() => {
641+
if (!open || !isSkypilotProvider || !selectedProvider) return;
642+
const cfg = selectedProvider.config || {};
643+
setJobDockerImage(cfg.docker_image || '');
644+
setJobRegion(cfg.default_region || '');
645+
setJobUseSpot(cfg.use_spot === true);
646+
}, [open, isSkypilotProvider, selectedProviderId, selectedProvider]);
647+
635648
// Helper function to validate constraints
636649
const validateParameter = (param: ProcessedParameter): string | null => {
637650
const { schema, value } = param;
@@ -744,6 +757,19 @@ export default function QueueTaskModal({
744757
}
745758
}
746759

760+
// For SkyPilot providers, add optional per-job overrides
761+
if (provider?.type === 'skypilot') {
762+
if (jobDockerImage.trim()) {
763+
config.docker_image = jobDockerImage.trim();
764+
}
765+
if (jobRegion.trim()) {
766+
config.region = jobRegion.trim();
767+
}
768+
if (jobUseSpot) {
769+
config.use_spot = true;
770+
}
771+
}
772+
747773
// Add sweep configuration if enabled
748774
if (runSweeps) {
749775
config.run_sweeps = true;
@@ -1580,6 +1606,53 @@ export default function QueueTaskModal({
15801606
the template default.
15811607
</FormHelperText>
15821608

1609+
{/* SkyPilot per-job overrides */}
1610+
{isSkypilotProvider && (
1611+
<>
1612+
<Divider />
1613+
<Typography level="title-sm">
1614+
SkyPilot Job Overrides
1615+
</Typography>
1616+
<FormControl>
1617+
<FormLabel>Docker Image (optional)</FormLabel>
1618+
<Input
1619+
value={jobDockerImage}
1620+
onChange={(e) => setJobDockerImage(e.target.value)}
1621+
placeholder="docker:nvcr.io/nvidia/pytorch:23.10-py3"
1622+
sx={{ fontFamily: 'monospace', fontSize: 'sm' }}
1623+
disabled={isSubmitting}
1624+
/>
1625+
<FormHelperText>
1626+
Prefix with &quot;docker:&quot; to run inside a
1627+
container. Defaults to the provider&apos;s global
1628+
setting.
1629+
</FormHelperText>
1630+
</FormControl>
1631+
<FormControl>
1632+
<FormLabel>Region (optional)</FormLabel>
1633+
<Input
1634+
value={jobRegion}
1635+
onChange={(e) => setJobRegion(e.target.value)}
1636+
placeholder="e.g. us-east-1"
1637+
disabled={isSubmitting}
1638+
/>
1639+
</FormControl>
1640+
<FormControl
1641+
sx={{ flexDirection: 'row', alignItems: 'center' }}
1642+
>
1643+
<Switch
1644+
checked={jobUseSpot}
1645+
onChange={(e) => setJobUseSpot(e.target.checked)}
1646+
disabled={isSubmitting}
1647+
sx={{ mr: 1 }}
1648+
/>
1649+
<FormLabel sx={{ m: 0 }}>
1650+
Use Spot / Preemptible Instances
1651+
</FormLabel>
1652+
</FormControl>
1653+
</>
1654+
)}
1655+
15831656
{/* Incompatibility Warning */}
15841657
{selectedProvider &&
15851658
effectiveResources?.accelerators &&

0 commit comments

Comments
 (0)