1414import queue
1515import threading
1616from werkzeug .utils import secure_filename
17+ from sqlalchemy .orm import Session
18+ from config import get_db
1719from models import (
1820 JobQueueResponse ,
1921 JobLogsResponse ,
@@ -400,9 +402,11 @@ async def submit_job_to_cluster(
400402 dir_name : Optional [str ] = Form (None ),
401403 uploaded_dir_path : Optional [str ] = Form (None ),
402404 num_nodes : Optional [int ] = Form (None ),
405+ tlab_job_id : Optional [str ] = Form (None ),
403406 yaml_file : Optional [UploadFile ] = File (None ),
404407 user : dict = Depends (get_user_or_api_key ),
405408 scope_check : dict = Depends (require_scope ("compute:write" )),
409+ db : Session = Depends (get_db ),
406410):
407411 try :
408412 # Parse YAML configuration if provided
@@ -445,6 +449,7 @@ async def submit_job_to_cluster(
445449 "job_name" : job_name ,
446450 "dir_name" : dir_name ,
447451 "num_nodes" : num_nodes ,
452+ "tlab_job_id" : tlab_job_id ,
448453 }
449454
450455 # Override with YAML values where form parameters are None (excluding resource fields)
@@ -478,6 +483,7 @@ async def submit_job_to_cluster(
478483 job_name = final_config ["job_name" ]
479484 dir_name = final_config ["dir_name" ]
480485 num_nodes = final_config ["num_nodes" ]
486+ tlab_job_id = final_config ["tlab_job_id" ]
481487
482488 file_mounts = None
483489
@@ -503,8 +509,96 @@ async def submit_job_to_cluster(
503509 # Mount the entire directory at ~/<base_name>
504510 file_mounts = {f"~/{ base_name } " : uploaded_dir_path }
505511
512+ # Initialize hook environment variables
513+ hook_env_vars = {}
514+
515+ # Handle launch hooks for the organization
516+ organization_id = user .get ("organization_id" )
517+ if organization_id :
518+ from db .db_models import LaunchHook , LaunchHookFile , TeamMembership
519+
520+ # Get user's team ID
521+ user_team = (
522+ db .query (TeamMembership )
523+ .filter (
524+ TeamMembership .organization_id == organization_id ,
525+ TeamMembership .user_id == user .get ("id" ),
526+ )
527+ .first ()
528+ )
529+ user_team_id = user_team .team_id if user_team else None
530+
531+ # Get all active launch hooks for the organization
532+ active_hooks = (
533+ db .query (LaunchHook )
534+ .filter (
535+ LaunchHook .organization_id == organization_id ,
536+ LaunchHook .is_active == True , # noqa: E712
537+ )
538+ .all ()
539+ )
540+
541+ # Filter hooks based on team access
542+ accessible_hooks = []
543+ for hook in active_hooks :
544+ # If no team restrictions (allowed_team_ids is None), hook is accessible to all
545+ if hook .allowed_team_ids is None :
546+ accessible_hooks .append (hook )
547+ # If user has no team, they can't access team-restricted hooks
548+ elif user_team_id is None :
549+ continue
550+ # If user's team is in the allowed list, they can access the hook
551+ elif user_team_id in hook .allowed_team_ids :
552+ accessible_hooks .append (hook )
553+
554+ if accessible_hooks :
555+ # Initialize file_mounts if not already set
556+ if file_mounts is None :
557+ file_mounts = {}
558+
559+ # Collect all setup commands and environment variables from accessible hooks
560+ hook_setup_commands = []
561+
562+ for hook in accessible_hooks :
563+ # Add setup commands from this hook
564+ if hook .setup_commands :
565+ hook_setup_commands .append (hook .setup_commands )
566+
567+ # Collect environment variables from this hook
568+ if hook .env_vars and isinstance (hook .env_vars , dict ):
569+ hook_env_vars .update (hook .env_vars )
570+
571+ # Get files for this hook
572+ hook_files = (
573+ db .query (LaunchHookFile )
574+ .filter (
575+ LaunchHookFile .launch_hook_id == hook .id ,
576+ LaunchHookFile .is_active == True , # noqa: E712
577+ )
578+ .all ()
579+ )
580+
581+ # Mount each file to ~/hooks/<filename>
582+ for hook_file in hook_files :
583+ if os .path .exists (hook_file .file_path ):
584+ mount_path = f"~/hooks/{ hook_file .original_filename } "
585+ file_mounts [mount_path ] = hook_file .file_path
586+
587+ # Prepend hook setup commands to the main setup commands
588+ if hook_setup_commands :
589+ combined_setup = ";" .join (hook_setup_commands )
590+ if setup :
591+ setup = f"{ combined_setup } ;{ setup } "
592+ else :
593+ setup = combined_setup
594+
595+ # Set _TFL_JOB_ID environment variable if tlab_job_id is provided
596+ if tlab_job_id :
597+ hook_env_vars ["_TFL_JOB_ID" ] = tlab_job_id
598+
506599 command = command .replace ("\r " , "" )
507- setup = setup .replace ("\r " , "" )
600+ if setup :
601+ setup = setup .replace ("\r " , "" )
508602
509603 # Apply secure_filename to job_name if provided
510604 secure_job_name = None
@@ -516,6 +610,42 @@ async def submit_job_to_cluster(
516610 cluster_name , user ["id" ], user ["organization_id" ]
517611 )
518612
613+ # Handle mandatory storage mounts (skip for RunPod clusters)
614+ storage_mounts = {}
615+ platform_info = get_cluster_platform_info_util (actual_cluster_name )
616+ is_runpod = False
617+ if platform_info and platform_info .get ("platform" ):
618+ platform = platform_info ["platform" ]
619+ if platform == "multi-cloud" :
620+ from routes .instances .utils import (
621+ determine_actual_cloud_from_skypilot_status ,
622+ )
623+
624+ # Determine the actual cloud used by SkyPilot
625+ actual_platform = determine_actual_cloud_from_skypilot_status (
626+ actual_cluster_name
627+ )
628+ platform = actual_platform if actual_platform else platform
629+ is_runpod = platform == "runpod"
630+
631+ if (
632+ not is_runpod
633+ and os .getenv ("TRANSFORMERLAB_BUCKET_NAME" )
634+ and os .getenv ("TRANSFORMERLAB_BUCKET_SOURCE" )
635+ ):
636+ import sky
637+
638+ transformerlab_bucket = sky .Storage (
639+ name = os .getenv ("TRANSFORMERLAB_BUCKET_NAME" ),
640+ mode = sky .StorageMode .MOUNT ,
641+ source = os .getenv ("TRANSFORMERLAB_BUCKET_SOURCE" ),
642+ persistent = True ,
643+ )
644+ storage_mounts ["/workspace" ] = transformerlab_bucket
645+ print (
646+ f"[Jobs] Added mandatory transformerlab bucket: { transformerlab_bucket } "
647+ )
648+
519649 # Default num_nodes to 1 if not provided or invalid
520650 try :
521651 if num_nodes is None :
@@ -539,6 +669,8 @@ async def submit_job_to_cluster(
539669 zone = zone ,
540670 job_name = secure_job_name ,
541671 num_nodes = effective_num_nodes ,
672+ env_vars = hook_env_vars ,
673+ storage_mounts = storage_mounts ,
542674 )
543675
544676 # Record usage event
0 commit comments