Skip to content
This repository was archived by the owner on Dec 16, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 133 additions & 1 deletion src/lattice/routes/jobs/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import queue
import threading
from werkzeug.utils import secure_filename
from sqlalchemy.orm import Session
from config import get_db
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I forgot that we never separated out our DB code in lattice. Not in scope for this PR though as we access DB all over the place. I will try to figure out what list to add that to!

from models import (
JobQueueResponse,
JobLogsResponse,
Expand Down Expand Up @@ -400,9 +402,11 @@ async def submit_job_to_cluster(
dir_name: Optional[str] = Form(None),
uploaded_dir_path: Optional[str] = Form(None),
num_nodes: Optional[int] = Form(None),
tlab_job_id: Optional[str] = Form(None),
yaml_file: Optional[UploadFile] = File(None),
user: dict = Depends(get_user_or_api_key),
scope_check: dict = Depends(require_scope("compute:write")),
db: Session = Depends(get_db),
):
try:
# Parse YAML configuration if provided
Expand Down Expand Up @@ -445,6 +449,7 @@ async def submit_job_to_cluster(
"job_name": job_name,
"dir_name": dir_name,
"num_nodes": num_nodes,
"tlab_job_id": tlab_job_id,
}

# Override with YAML values where form parameters are None (excluding resource fields)
Expand Down Expand Up @@ -478,6 +483,7 @@ async def submit_job_to_cluster(
job_name = final_config["job_name"]
dir_name = final_config["dir_name"]
num_nodes = final_config["num_nodes"]
tlab_job_id = final_config["tlab_job_id"]

file_mounts = None

Expand All @@ -503,8 +509,96 @@ async def submit_job_to_cluster(
# Mount the entire directory at ~/<base_name>
file_mounts = {f"~/{base_name}": uploaded_dir_path}

# Initialize hook environment variables
hook_env_vars = {}

# Handle launch hooks for the organization
organization_id = user.get("organization_id")
if organization_id:
from db.db_models import LaunchHook, LaunchHookFile, TeamMembership

# Get user's team ID
user_team = (
db.query(TeamMembership)
.filter(
TeamMembership.organization_id == organization_id,
TeamMembership.user_id == user.get("id"),
)
.first()
)
user_team_id = user_team.team_id if user_team else None

# Get all active launch hooks for the organization
active_hooks = (
db.query(LaunchHook)
.filter(
LaunchHook.organization_id == organization_id,
LaunchHook.is_active == True, # noqa: E712
)
.all()
)

# Filter hooks based on team access
accessible_hooks = []
for hook in active_hooks:
# If no team restrictions (allowed_team_ids is None), hook is accessible to all
if hook.allowed_team_ids is None:
accessible_hooks.append(hook)
# If user has no team, they can't access team-restricted hooks
elif user_team_id is None:
continue
# If user's team is in the allowed list, they can access the hook
elif user_team_id in hook.allowed_team_ids:
accessible_hooks.append(hook)

if accessible_hooks:
# Initialize file_mounts if not already set
if file_mounts is None:
file_mounts = {}

# Collect all setup commands and environment variables from accessible hooks
hook_setup_commands = []

for hook in accessible_hooks:
# Add setup commands from this hook
if hook.setup_commands:
hook_setup_commands.append(hook.setup_commands)

# Collect environment variables from this hook
if hook.env_vars and isinstance(hook.env_vars, dict):
hook_env_vars.update(hook.env_vars)

# Get files for this hook
hook_files = (
db.query(LaunchHookFile)
.filter(
LaunchHookFile.launch_hook_id == hook.id,
LaunchHookFile.is_active == True, # noqa: E712
)
.all()
)

# Mount each file to ~/hooks/<filename>
for hook_file in hook_files:
if os.path.exists(hook_file.file_path):
mount_path = f"~/hooks/{hook_file.original_filename}"
file_mounts[mount_path] = hook_file.file_path

# Prepend hook setup commands to the main setup commands
if hook_setup_commands:
combined_setup = ";".join(hook_setup_commands)
if setup:
setup = f"{combined_setup};{setup}"
else:
setup = combined_setup

# Set _TFL_JOB_ID environment variable if tlab_job_id is provided
if tlab_job_id:
hook_env_vars["_TFL_JOB_ID"] = tlab_job_id

command = command.replace("\r", "")
setup = setup.replace("\r", "")
if setup:
setup = setup.replace("\r", "")

# Apply secure_filename to job_name if provided
secure_job_name = None
Expand All @@ -516,6 +610,42 @@ async def submit_job_to_cluster(
cluster_name, user["id"], user["organization_id"]
)

# Handle mandatory storage mounts (skip for RunPod clusters)
storage_mounts = {}
platform_info = get_cluster_platform_info_util(actual_cluster_name)
is_runpod = False
if platform_info and platform_info.get("platform"):
platform = platform_info["platform"]
if platform == "multi-cloud":
from routes.instances.utils import (
determine_actual_cloud_from_skypilot_status,
)

# Determine the actual cloud used by SkyPilot
actual_platform = determine_actual_cloud_from_skypilot_status(
actual_cluster_name
)
platform = actual_platform if actual_platform else platform
is_runpod = platform == "runpod"

if (
not is_runpod
and os.getenv("TRANSFORMERLAB_BUCKET_NAME")
and os.getenv("TRANSFORMERLAB_BUCKET_SOURCE")
):
import sky

transformerlab_bucket = sky.Storage(
name=os.getenv("TRANSFORMERLAB_BUCKET_NAME"),
mode=sky.StorageMode.MOUNT,
source=os.getenv("TRANSFORMERLAB_BUCKET_SOURCE"),
persistent=True,
)
storage_mounts["/workspace"] = transformerlab_bucket
print(
f"[Jobs] Added mandatory transformerlab bucket: {transformerlab_bucket}"
)

# Default num_nodes to 1 if not provided or invalid
try:
if num_nodes is None:
Expand All @@ -539,6 +669,8 @@ async def submit_job_to_cluster(
zone=zone,
job_name=secure_job_name,
num_nodes=effective_num_nodes,
env_vars=hook_env_vars,
storage_mounts=storage_mounts,
)

# Record usage event
Expand Down
12 changes: 12 additions & 0 deletions src/lattice/routes/jobs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ def submit_job_to_existing_cluster(
zone: Optional[str] = None,
job_name: Optional[str] = None,
num_nodes: Optional[int] = None,
env_vars: Optional[dict] = None,
storage_mounts: Optional[dict] = None,
):
try:
# Create job name with metadata if it's a special job type
Expand All @@ -148,14 +150,24 @@ def submit_job_to_existing_cluster(
except Exception:
effective_num_nodes = 1

# Prepare environment variables
envs = None
if env_vars and isinstance(env_vars, dict):
envs = env_vars.copy()

task = sky.Task(
name=final_job_name,
run=command,
setup=setup,
num_nodes=effective_num_nodes,
envs=envs,
)
if file_mounts:
task.set_file_mounts(file_mounts)

# Set storage mounts if provided
if storage_mounts:
task.set_storage_mounts(storage_mounts)

resources_kwargs = {}
if cpus:
Expand Down