@@ -308,13 +308,19 @@ def launch_cluster_with_skypilot(
308308 envs ["_TFL_REMOTE_SKYPILOT_WORKSPACE" ] = os .getenv (
309309 "_TFL_REMOTE_SKYPILOT_WORKSPACE" , "true"
310310 )
311+ # Add TFL_STORAGE_URI if TRANSFORMERLAB_BUCKET_SOURCE is set
312+ if os .getenv ("TRANSFORMERLAB_BUCKET_SOURCE" ):
313+ envs ["TFL_STORAGE_URI" ] = os .getenv ("TRANSFORMERLAB_BUCKET_SOURCE" )
311314 else :
312315 envs = {
313316 "AWS_PROFILE" : os .getenv ("AWS_PROFILE" , "transformerlab-s3" ),
314317 "_TFL_REMOTE_SKYPILOT_WORKSPACE" : os .getenv (
315318 "_TFL_REMOTE_SKYPILOT_WORKSPACE" , "true"
316319 ),
317320 }
321+ # Add TFL_STORAGE_URI if TRANSFORMERLAB_BUCKET_SOURCE is set
322+ if os .getenv ("TRANSFORMERLAB_BUCKET_SOURCE" ):
323+ envs ["TFL_STORAGE_URI" ] = os .getenv ("TRANSFORMERLAB_BUCKET_SOURCE" )
318324
319325 # Merge launch hook environment variables with existing envs
320326 if env_vars and isinstance (env_vars , dict ):
@@ -412,49 +418,14 @@ def launch_cluster_with_skypilot(
412418
413419 storage_mounts [bucket .remote_path ] = storage_obj
414420
415- # Add mandatory transformerlab bucket (skip for runpod if disabled_mandatory_mounts is True)
416- should_add_mandatory_bucket = (
417- os .getenv ("TRANSFORMERLAB_BUCKET_NAME" )
418- and os .getenv ("TRANSFORMERLAB_BUCKET_SOURCE" )
419- and not disabled_mandatory_mounts
420- )
421-
422- if should_add_mandatory_bucket :
423- transformerlab_bucket = sky .Storage (
424- name = os .getenv ("TRANSFORMERLAB_BUCKET_NAME" ),
425- mode = sky .StorageMode .MOUNT ,
426- source = os .getenv ("TRANSFORMERLAB_BUCKET_SOURCE" ),
427- persistent = True ,
428- )
429- storage_mounts ["/workspace" ] = transformerlab_bucket
430- # Set storage mounts on the task
431- task .set_storage_mounts (storage_mounts )
421+ # Set storage mounts on the task if any buckets were added
422+ if storage_mounts :
423+ task .set_storage_mounts (storage_mounts )
432424
433425 except Exception as e :
434426 print (f"[SkyPilot] Warning: Failed to process storage buckets: { e } " )
435427 finally :
436428 db .close ()
437- else :
438- # Add mandatory transformerlab bucket (skip for runpod if disabled_mandatory_mounts is True)
439- should_add_mandatory_bucket = (
440- os .getenv ("TRANSFORMERLAB_BUCKET_NAME" )
441- and os .getenv ("TRANSFORMERLAB_BUCKET_SOURCE" )
442- and not disabled_mandatory_mounts
443- )
444-
445- if should_add_mandatory_bucket :
446- transformerlab_bucket = sky .Storage (
447- name = os .getenv ("TRANSFORMERLAB_BUCKET_NAME" ),
448- mode = sky .StorageMode .MOUNT ,
449- source = os .getenv ("TRANSFORMERLAB_BUCKET_SOURCE" ),
450- persistent = True ,
451- )
452- storage_mounts ["/workspace" ] = transformerlab_bucket
453- # Set storage mounts on the task
454- task .set_storage_mounts (storage_mounts )
455- print (
456- f"[SkyPilot] Added mandatory transformerlab bucket: { transformerlab_bucket } "
457- )
458429
459430 # If no cloud is specified, create a list of resources for all available clouds
460431 if not cloud :
0 commit comments