Skip to content
This repository was archived by the owner on Dec 16, 2025. It is now read-only.

Commit 6dc7c05

Browse files
authored
Merge branch 'main' into add/use-fsspec
2 parents f7bf2c0 + 2fa87e3 commit 6dc7c05

File tree

2 files changed

+145
-1
lines changed

2 files changed

+145
-1
lines changed

src/lattice/routes/jobs/routes.py

Lines changed: 133 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import queue
1515
import threading
1616
from werkzeug.utils import secure_filename
17+
from sqlalchemy.orm import Session
18+
from config import get_db
1719
from models import (
1820
JobQueueResponse,
1921
JobLogsResponse,
@@ -400,9 +402,11 @@ async def submit_job_to_cluster(
400402
dir_name: Optional[str] = Form(None),
401403
uploaded_dir_path: Optional[str] = Form(None),
402404
num_nodes: Optional[int] = Form(None),
405+
tlab_job_id: Optional[str] = Form(None),
403406
yaml_file: Optional[UploadFile] = File(None),
404407
user: dict = Depends(get_user_or_api_key),
405408
scope_check: dict = Depends(require_scope("compute:write")),
409+
db: Session = Depends(get_db),
406410
):
407411
try:
408412
# Parse YAML configuration if provided
@@ -445,6 +449,7 @@ async def submit_job_to_cluster(
445449
"job_name": job_name,
446450
"dir_name": dir_name,
447451
"num_nodes": num_nodes,
452+
"tlab_job_id": tlab_job_id,
448453
}
449454

450455
# Override with YAML values where form parameters are None (excluding resource fields)
@@ -478,6 +483,7 @@ async def submit_job_to_cluster(
478483
job_name = final_config["job_name"]
479484
dir_name = final_config["dir_name"]
480485
num_nodes = final_config["num_nodes"]
486+
tlab_job_id = final_config["tlab_job_id"]
481487

482488
file_mounts = None
483489

@@ -503,8 +509,96 @@ async def submit_job_to_cluster(
503509
# Mount the entire directory at ~/<base_name>
504510
file_mounts = {f"~/{base_name}": uploaded_dir_path}
505511

512+
# Initialize hook environment variables
513+
hook_env_vars = {}
514+
515+
# Handle launch hooks for the organization
516+
organization_id = user.get("organization_id")
517+
if organization_id:
518+
from db.db_models import LaunchHook, LaunchHookFile, TeamMembership
519+
520+
# Get user's team ID
521+
user_team = (
522+
db.query(TeamMembership)
523+
.filter(
524+
TeamMembership.organization_id == organization_id,
525+
TeamMembership.user_id == user.get("id"),
526+
)
527+
.first()
528+
)
529+
user_team_id = user_team.team_id if user_team else None
530+
531+
# Get all active launch hooks for the organization
532+
active_hooks = (
533+
db.query(LaunchHook)
534+
.filter(
535+
LaunchHook.organization_id == organization_id,
536+
LaunchHook.is_active == True, # noqa: E712
537+
)
538+
.all()
539+
)
540+
541+
# Filter hooks based on team access
542+
accessible_hooks = []
543+
for hook in active_hooks:
544+
# If no team restrictions (allowed_team_ids is None), hook is accessible to all
545+
if hook.allowed_team_ids is None:
546+
accessible_hooks.append(hook)
547+
# If user has no team, they can't access team-restricted hooks
548+
elif user_team_id is None:
549+
continue
550+
# If user's team is in the allowed list, they can access the hook
551+
elif user_team_id in hook.allowed_team_ids:
552+
accessible_hooks.append(hook)
553+
554+
if accessible_hooks:
555+
# Initialize file_mounts if not already set
556+
if file_mounts is None:
557+
file_mounts = {}
558+
559+
# Collect all setup commands and environment variables from accessible hooks
560+
hook_setup_commands = []
561+
562+
for hook in accessible_hooks:
563+
# Add setup commands from this hook
564+
if hook.setup_commands:
565+
hook_setup_commands.append(hook.setup_commands)
566+
567+
# Collect environment variables from this hook
568+
if hook.env_vars and isinstance(hook.env_vars, dict):
569+
hook_env_vars.update(hook.env_vars)
570+
571+
# Get files for this hook
572+
hook_files = (
573+
db.query(LaunchHookFile)
574+
.filter(
575+
LaunchHookFile.launch_hook_id == hook.id,
576+
LaunchHookFile.is_active == True, # noqa: E712
577+
)
578+
.all()
579+
)
580+
581+
# Mount each file to ~/hooks/<filename>
582+
for hook_file in hook_files:
583+
if os.path.exists(hook_file.file_path):
584+
mount_path = f"~/hooks/{hook_file.original_filename}"
585+
file_mounts[mount_path] = hook_file.file_path
586+
587+
# Prepend hook setup commands to the main setup commands
588+
if hook_setup_commands:
589+
combined_setup = ";".join(hook_setup_commands)
590+
if setup:
591+
setup = f"{combined_setup};{setup}"
592+
else:
593+
setup = combined_setup
594+
595+
# Set _TFL_JOB_ID environment variable if tlab_job_id is provided
596+
if tlab_job_id:
597+
hook_env_vars["_TFL_JOB_ID"] = tlab_job_id
598+
506599
command = command.replace("\r", "")
507-
setup = setup.replace("\r", "")
600+
if setup:
601+
setup = setup.replace("\r", "")
508602

509603
# Apply secure_filename to job_name if provided
510604
secure_job_name = None
@@ -516,6 +610,42 @@ async def submit_job_to_cluster(
516610
cluster_name, user["id"], user["organization_id"]
517611
)
518612

613+
# Handle mandatory storage mounts (skip for RunPod clusters)
614+
storage_mounts = {}
615+
platform_info = get_cluster_platform_info_util(actual_cluster_name)
616+
is_runpod = False
617+
if platform_info and platform_info.get("platform"):
618+
platform = platform_info["platform"]
619+
if platform == "multi-cloud":
620+
from routes.instances.utils import (
621+
determine_actual_cloud_from_skypilot_status,
622+
)
623+
624+
# Determine the actual cloud used by SkyPilot
625+
actual_platform = determine_actual_cloud_from_skypilot_status(
626+
actual_cluster_name
627+
)
628+
platform = actual_platform if actual_platform else platform
629+
is_runpod = platform == "runpod"
630+
631+
if (
632+
not is_runpod
633+
and os.getenv("TRANSFORMERLAB_BUCKET_NAME")
634+
and os.getenv("TRANSFORMERLAB_BUCKET_SOURCE")
635+
):
636+
import sky
637+
638+
transformerlab_bucket = sky.Storage(
639+
name=os.getenv("TRANSFORMERLAB_BUCKET_NAME"),
640+
mode=sky.StorageMode.MOUNT,
641+
source=os.getenv("TRANSFORMERLAB_BUCKET_SOURCE"),
642+
persistent=True,
643+
)
644+
storage_mounts["/workspace"] = transformerlab_bucket
645+
print(
646+
f"[Jobs] Added mandatory transformerlab bucket: {transformerlab_bucket}"
647+
)
648+
519649
# Default num_nodes to 1 if not provided or invalid
520650
try:
521651
if num_nodes is None:
@@ -539,6 +669,8 @@ async def submit_job_to_cluster(
539669
zone=zone,
540670
job_name=secure_job_name,
541671
num_nodes=effective_num_nodes,
672+
env_vars=hook_env_vars,
673+
storage_mounts=storage_mounts,
542674
)
543675

544676
# Record usage event

src/lattice/routes/jobs/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,8 @@ def submit_job_to_existing_cluster(
133133
zone: Optional[str] = None,
134134
job_name: Optional[str] = None,
135135
num_nodes: Optional[int] = None,
136+
env_vars: Optional[dict] = None,
137+
storage_mounts: Optional[dict] = None,
136138
):
137139
try:
138140
# Create job name with metadata if it's a special job type
@@ -148,14 +150,24 @@ def submit_job_to_existing_cluster(
148150
except Exception:
149151
effective_num_nodes = 1
150152

153+
# Prepare environment variables
154+
envs = None
155+
if env_vars and isinstance(env_vars, dict):
156+
envs = env_vars.copy()
157+
151158
task = sky.Task(
152159
name=final_job_name,
153160
run=command,
154161
setup=setup,
155162
num_nodes=effective_num_nodes,
163+
envs=envs,
156164
)
157165
if file_mounts:
158166
task.set_file_mounts(file_mounts)
167+
168+
# Set storage mounts if provided
169+
if storage_mounts:
170+
task.set_storage_mounts(storage_mounts)
159171

160172
resources_kwargs = {}
161173
if cpus:

0 commit comments

Comments
 (0)