@@ -275,24 +275,6 @@ async def _get_latest_tag(inference_framework: LLMInferenceFramework) -> str:
     return config_map[inference_framework]
 
 
-def _include_safetensors_bin_or_pt(model_files: List[str]) -> Optional[str]:
-    """
-    This function is used to determine whether to include "*.safetensors", "*.bin", or "*.pt" files
-    based on which file type is present most often in the checkpoint folder. The most
-    frequently present file type is included.
-    In case of ties, priority is given to "*.safetensors", then "*.bin", then "*.pt".
-    """
-    num_safetensors = len([f for f in model_files if f.endswith(".safetensors")])
-    num_bin = len([f for f in model_files if f.endswith(".bin")])
-    num_pt = len([f for f in model_files if f.endswith(".pt")])
-    maximum = max(num_safetensors, num_bin, num_pt)
-    if num_safetensors == maximum:
-        return "*.safetensors"
-    if num_bin == maximum:
-        return "*.bin"
-    return "*.pt"
-
-
 def _model_endpoint_entity_to_get_llm_model_endpoint_response(
     model_endpoint: ModelEndpoint,
 ) -> GetLLMModelEndpointV1Response:
@@ -354,6 +336,10 @@ def validate_checkpoint_path_uri(checkpoint_path: str) -> None:
         raise ObjectHasInvalidValueException(
             f"Only S3 paths are supported. Given checkpoint path: {checkpoint_path}."
         )
+    if checkpoint_path.endswith(".tar"):
+        raise ObjectHasInvalidValueException(
+            f"Tar files are not supported. Given checkpoint path: {checkpoint_path}."
+        )
 
 
 def get_checkpoint_path(model_name: str, checkpoint_path_override: Optional[str]) -> str:
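For reference, a minimal standalone sketch of the tightened validation (the exception class is a stand-in, the S3 prefix check is simplified relative to the real helper, and the paths are hypothetical):

```python
class ObjectHasInvalidValueException(Exception):
    """Stand-in for the exception type used in model_engine_server."""

def validate_checkpoint_path_uri(checkpoint_path: str) -> None:
    # Simplified mirror of the updated validation: S3-only paths, and now no tar archives.
    if not checkpoint_path.startswith("s3://"):
        raise ObjectHasInvalidValueException(
            f"Only S3 paths are supported. Given checkpoint path: {checkpoint_path}."
        )
    if checkpoint_path.endswith(".tar"):
        raise ObjectHasInvalidValueException(
            f"Tar files are not supported. Given checkpoint path: {checkpoint_path}."
        )

validate_checkpoint_path_uri("s3://my-bucket/llama-2-7b/")   # accepted (hypothetical path)
try:
    validate_checkpoint_path_uri("s3://my-bucket/weights.tar")  # now rejected
except ObjectHasInvalidValueException as e:
    print(e)  # Tar files are not supported. ...
```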
@@ -370,6 +356,14 @@ def get_checkpoint_path(model_name: str, checkpoint_path_override: Optional[str]
     return checkpoint_path
 
 
+def validate_checkpoint_files(checkpoint_files: List[str]) -> None:
+    """Require safetensors in the checkpoint path."""
+    model_files = [f for f in checkpoint_files if "model" in f]
+    num_safetensors = len([f for f in model_files if f.endswith(".safetensors")])
+    if num_safetensors == 0:
+        raise ObjectHasInvalidValueException("No safetensors found in the checkpoint path.")
+
+
 class CreateLLMModelBundleV1UseCase:
     def __init__(
         self,
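Similarly, a small self-contained sketch of what the new `validate_checkpoint_files` helper enforces (the exception class below is a stand-in for the real one and the file lists are hypothetical):

```python
from typing import List

class ObjectHasInvalidValueException(Exception):
    """Stand-in for the exception type used in model_engine_server."""

def validate_checkpoint_files(checkpoint_files: List[str]) -> None:
    """Mirror of the helper added above: at least one safetensors weight file is required."""
    model_files = [f for f in checkpoint_files if "model" in f]
    num_safetensors = len([f for f in model_files if f.endswith(".safetensors")])
    if num_safetensors == 0:
        raise ObjectHasInvalidValueException("No safetensors found in the checkpoint path.")

# Accepted: safetensors weights are present (hypothetical file list)
validate_checkpoint_files(["config.json", "model-00001-of-00002.safetensors"])

# Rejected: only .bin weights, which this change no longer downloads
try:
    validate_checkpoint_files(["config.json", "pytorch_model.bin"])
except ObjectHasInvalidValueException as e:
    print(e)  # No safetensors found in the checkpoint path.
```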
@@ -557,27 +551,14 @@ def load_model_weights_sub_commands(
         else:
             s5cmd = "./s5cmd"
 
-        base_path = checkpoint_path.split("/")[-1]
-        if base_path.endswith(".tar"):
-            # If the checkpoint file is a tar file, extract it into final_weights_folder
-            subcommands.extend(
-                [
-                    f"{s5cmd} cp {checkpoint_path} .",
-                    f"mkdir -p {final_weights_folder}",
-                    f"tar --no-same-owner -xf {base_path} -C {final_weights_folder}",
-                ]
-            )
-        else:
-            # Let's check whether to exclude "*.safetensors" or "*.bin" files
-            checkpoint_files = self.llm_artifact_gateway.list_files(checkpoint_path)
-            model_files = [f for f in checkpoint_files if "model" in f]
-
-            include_str = _include_safetensors_bin_or_pt(model_files)
-            file_selection_str = f"--include '*.model' --include '*.json' --include '{include_str}' --exclude 'optimizer*'"
-            subcommands.append(
-                f"{s5cmd} --numworkers 512 cp --concurrency 10 {file_selection_str} {os.path.join(checkpoint_path, '*')} {final_weights_folder}"
-            )
+        checkpoint_files = self.llm_artifact_gateway.list_files(checkpoint_path)
+        validate_checkpoint_files(checkpoint_files)
 
+        # filter to configs ('*.model' and '*.json') and weights ('*.safetensors')
+        file_selection_str = "--include '*.model' --include '*.json' --include '*.safetensors' --exclude 'optimizer*'"
+        subcommands.append(
+            f"{s5cmd} --numworkers 512 cp --concurrency 10 {file_selection_str} {os.path.join(checkpoint_path, '*')} {final_weights_folder}"
+        )
         return subcommands
 
     def load_model_files_sub_commands_trt_llm(
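To illustrate the simplified download path, a sketch of the single s5cmd command the new code emits (the bucket and destination folder names are hypothetical; the flags mirror the `file_selection_str` above):

```python
import os

# Hypothetical inputs; in the use case these come from the endpoint's checkpoint config.
s5cmd = "./s5cmd"
checkpoint_path = "s3://my-bucket/llama-2-7b"
final_weights_folder = "model_files"

file_selection_str = (
    "--include '*.model' --include '*.json' --include '*.safetensors' --exclude 'optimizer*'"
)
print(
    f"{s5cmd} --numworkers 512 cp --concurrency 10 "
    f"{file_selection_str} {os.path.join(checkpoint_path, '*')} {final_weights_folder}"
)
# ./s5cmd --numworkers 512 cp --concurrency 10 --include '*.model' --include '*.json'
#   --include '*.safetensors' --exclude 'optimizer*' s3://my-bucket/llama-2-7b/* model_files
```

With tar extraction removed, every checkpoint is copied from a folder, and only tokenizer/config files plus safetensors weights are downloaded; optimizer state is skipped.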
@@ -591,19 +572,9 @@ def load_model_files_sub_commands_trt_llm(
         See llm-engine/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm/config.pbtxt
         and llm-engine/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/config.pbtxt
         """
-        subcommands = []
-
-        base_path = checkpoint_path.split("/")[-1]
-
-        if base_path.endswith(".tar"):
-            raise ObjectHasInvalidValueException(
-                "Checkpoint for TensorRT-LLM models must be a folder, not a tar file."
-            )
-        else:
-            subcommands.append(
-                f"./s5cmd --numworkers 512 cp --concurrency 50 {os.path.join(checkpoint_path, '*')} ./"
-            )
-
+        subcommands = [
+            f"./s5cmd --numworkers 512 cp --concurrency 50 {os.path.join(checkpoint_path, '*')} ./"
+        ]
         return subcommands
 
     async def create_deepspeed_bundle(