diff --git a/src/lattice/routes/jobs/routes.py b/src/lattice/routes/jobs/routes.py
index 1a7a359..28252dd 100644
--- a/src/lattice/routes/jobs/routes.py
+++ b/src/lattice/routes/jobs/routes.py
@@ -14,6 +14,8 @@
 import queue
 import threading
 from werkzeug.utils import secure_filename
+from sqlalchemy.orm import Session
+from config import get_db
 from models import (
     JobQueueResponse,
     JobLogsResponse,
@@ -400,9 +402,11 @@ async def submit_job_to_cluster(
     dir_name: Optional[str] = Form(None),
     uploaded_dir_path: Optional[str] = Form(None),
     num_nodes: Optional[int] = Form(None),
+    tlab_job_id: Optional[str] = Form(None),
     yaml_file: Optional[UploadFile] = File(None),
     user: dict = Depends(get_user_or_api_key),
     scope_check: dict = Depends(require_scope("compute:write")),
+    db: Session = Depends(get_db),
 ):
     try:
         # Parse YAML configuration if provided
@@ -445,6 +449,7 @@
             "job_name": job_name,
             "dir_name": dir_name,
             "num_nodes": num_nodes,
+            "tlab_job_id": tlab_job_id,
         }
 
         # Override with YAML values where form parameters are None (excluding resource fields)
@@ -478,6 +483,7 @@
         job_name = final_config["job_name"]
         dir_name = final_config["dir_name"]
         num_nodes = final_config["num_nodes"]
+        tlab_job_id = final_config["tlab_job_id"]
 
         file_mounts = None
 
@@ -503,8 +509,96 @@
                 # Mount the entire directory at ~/
                 file_mounts = {f"~/{base_name}": uploaded_dir_path}
 
+        # Initialize hook environment variables
+        hook_env_vars = {}
+
+        # Handle launch hooks for the organization
+        organization_id = user.get("organization_id")
+        if organization_id:
+            from db.db_models import LaunchHook, LaunchHookFile, TeamMembership
+
+            # Get user's team ID
+            user_team = (
+                db.query(TeamMembership)
+                .filter(
+                    TeamMembership.organization_id == organization_id,
+                    TeamMembership.user_id == user.get("id"),
+                )
+                .first()
+            )
+            user_team_id = user_team.team_id if user_team else None
+
+            # Get all active launch hooks for the organization
+            active_hooks = (
+                db.query(LaunchHook)
+                .filter(
+                    LaunchHook.organization_id == organization_id,
+                    LaunchHook.is_active == True,  # noqa: E712
+                )
+                .all()
+            )
+
+            # Filter hooks based on team access
+            accessible_hooks = []
+            for hook in active_hooks:
+                # If no team restrictions (allowed_team_ids is None), hook is accessible to all
+                if hook.allowed_team_ids is None:
+                    accessible_hooks.append(hook)
+                # If user has no team, they can't access team-restricted hooks
+                elif user_team_id is None:
+                    continue
+                # If user's team is in the allowed list, they can access the hook
+                elif user_team_id in hook.allowed_team_ids:
+                    accessible_hooks.append(hook)
+
+            if accessible_hooks:
+                # Initialize file_mounts if not already set
+                if file_mounts is None:
+                    file_mounts = {}
+
+                # Collect all setup commands and environment variables from accessible hooks
+                hook_setup_commands = []
+
+                for hook in accessible_hooks:
+                    # Add setup commands from this hook
+                    if hook.setup_commands:
+                        hook_setup_commands.append(hook.setup_commands)
+
+                    # Collect environment variables from this hook
+                    if hook.env_vars and isinstance(hook.env_vars, dict):
+                        hook_env_vars.update(hook.env_vars)
+
+                    # Get files for this hook
+                    hook_files = (
+                        db.query(LaunchHookFile)
+                        .filter(
+                            LaunchHookFile.launch_hook_id == hook.id,
+                            LaunchHookFile.is_active == True,  # noqa: E712
+                        )
+                        .all()
+                    )
+
+                    # Mount each file to ~/hooks/
+                    for hook_file in hook_files:
+                        if os.path.exists(hook_file.file_path):
+                            mount_path = f"~/hooks/{hook_file.original_filename}"
+                            file_mounts[mount_path] = hook_file.file_path
+
+                # Prepend hook setup commands to the main setup commands
+                if hook_setup_commands:
+                    combined_setup = ";".join(hook_setup_commands)
+                    if setup:
+                        setup = f"{combined_setup};{setup}"
+                    else:
+                        setup = combined_setup
+
+        # Set _TFL_JOB_ID environment variable if tlab_job_id is provided
+        if tlab_job_id:
+            hook_env_vars["_TFL_JOB_ID"] = tlab_job_id
+
         command = command.replace("\r", "")
-        setup = setup.replace("\r", "")
+        if setup:
+            setup = setup.replace("\r", "")
 
         # Apply secure_filename to job_name if provided
         secure_job_name = None
@@ -516,6 +610,42 @@
             cluster_name, user["id"], user["organization_id"]
         )
 
+        # Handle mandatory storage mounts (skip for RunPod clusters)
+        storage_mounts = {}
+        platform_info = get_cluster_platform_info_util(actual_cluster_name)
+        is_runpod = False
+        if platform_info and platform_info.get("platform"):
+            platform = platform_info["platform"]
+            if platform == "multi-cloud":
+                from routes.instances.utils import (
+                    determine_actual_cloud_from_skypilot_status,
+                )
+
+                # Determine the actual cloud used by SkyPilot
+                actual_platform = determine_actual_cloud_from_skypilot_status(
+                    actual_cluster_name
+                )
+                platform = actual_platform if actual_platform else platform
+            is_runpod = platform == "runpod"
+
+        if (
+            not is_runpod
+            and os.getenv("TRANSFORMERLAB_BUCKET_NAME")
+            and os.getenv("TRANSFORMERLAB_BUCKET_SOURCE")
+        ):
+            import sky
+
+            transformerlab_bucket = sky.Storage(
+                name=os.getenv("TRANSFORMERLAB_BUCKET_NAME"),
+                mode=sky.StorageMode.MOUNT,
+                source=os.getenv("TRANSFORMERLAB_BUCKET_SOURCE"),
+                persistent=True,
+            )
+            storage_mounts["/workspace"] = transformerlab_bucket
+            print(
+                f"[Jobs] Added mandatory transformerlab bucket: {transformerlab_bucket}"
+            )
+
         # Default num_nodes to 1 if not provided or invalid
         try:
             if num_nodes is None:
@@ -539,6 +669,8 @@
             zone=zone,
             job_name=secure_job_name,
             num_nodes=effective_num_nodes,
+            env_vars=hook_env_vars,
+            storage_mounts=storage_mounts,
         )
 
         # Record usage event
diff --git a/src/lattice/routes/jobs/utils.py b/src/lattice/routes/jobs/utils.py
index e7f7f59..9a2326c 100644
--- a/src/lattice/routes/jobs/utils.py
+++ b/src/lattice/routes/jobs/utils.py
@@ -133,6 +133,8 @@ def submit_job_to_existing_cluster(
     zone: Optional[str] = None,
     job_name: Optional[str] = None,
     num_nodes: Optional[int] = None,
+    env_vars: Optional[dict] = None,
+    storage_mounts: Optional[dict] = None,
 ):
     try:
         # Create job name with metadata if it's a special job type
@@ -148,14 +150,24 @@
     except Exception:
         effective_num_nodes = 1
 
+    # Prepare environment variables
+    envs = None
+    if env_vars and isinstance(env_vars, dict):
+        envs = env_vars.copy()
+
     task = sky.Task(
         name=final_job_name,
         run=command,
         setup=setup,
         num_nodes=effective_num_nodes,
+        envs=envs,
     )
     if file_mounts:
         task.set_file_mounts(file_mounts)
+
+    # Set storage mounts if provided
+    if storage_mounts:
+        task.set_storage_mounts(storage_mounts)
 
     resources_kwargs = {}
     if cpus:
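
Note (illustrative, not part of the diff): a minimal sketch of how an accessible hook's setup commands and env vars are merged with the user-supplied values in routes.py above; the hook data and job id below are hypothetical placeholders.

    # Hypothetical hook data, for illustration only
    hook_setup_commands = ["pip install -r ~/hooks/requirements.txt"]
    hook_env_vars = {"WANDB_PROJECT": "demo"}
    setup = "echo user-setup"
    tlab_job_id = "job-123"  # hypothetical tlab_job_id form value

    # Hook setup commands run before the user's setup, joined with ";"
    combined_setup = ";".join(hook_setup_commands)
    setup = f"{combined_setup};{setup}" if setup else combined_setup
    # setup == "pip install -r ~/hooks/requirements.txt;echo user-setup"

    # The TransformerLab job id is forwarded to the task as an env var
    if tlab_job_id:
        hook_env_vars["_TFL_JOB_ID"] = tlab_job_id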
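
Note (illustrative, not part of the diff): a minimal sketch of the sky.Task that submit_job_to_existing_cluster builds once env_vars and storage_mounts are threaded through; the name/run/setup strings are placeholders, and the bucket settings come from the TRANSFORMERLAB_BUCKET_* environment variables checked in routes.py.

    import os
    import sky

    # Env vars collected from hooks (plus _TFL_JOB_ID) become task envs
    task = sky.Task(
        name="example-job",      # placeholder
        run="python train.py",   # placeholder
        setup="echo setup",      # placeholder
        num_nodes=1,
        envs={"_TFL_JOB_ID": "job-123"},
    )

    # On non-RunPod clusters, the mandatory bucket is mounted at /workspace
    bucket = sky.Storage(
        name=os.getenv("TRANSFORMERLAB_BUCKET_NAME"),
        mode=sky.StorageMode.MOUNT,
        source=os.getenv("TRANSFORMERLAB_BUCKET_SOURCE"),
        persistent=True,
    )
    task.set_storage_mounts({"/workspace": bucket})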