22
33import re
44from pathlib import Path
5- from typing import Dict , List , Literal
5+ from typing import List , Dict , Literal
66
77from rich import print as rprint
88
@@ -545,7 +545,7 @@ def download(
545545 * ,
546546 output : Path | str | None = None ,
547547 checkpoint_step : int | None = None ,
548- checkpoint_type : DownloadCheckpointType = DownloadCheckpointType .DEFAULT ,
548+ checkpoint_type : DownloadCheckpointType | str = DownloadCheckpointType .DEFAULT ,
549549 ) -> FinetuneDownloadResult :
550550 """
551551 Downloads compressed fine-tuned model or checkpoint to local disk.
@@ -558,7 +558,7 @@ def download(
558558 Defaults to None.
559559 checkpoint_step (int, optional): Specifies step number for checkpoint to download.
560560 Defaults to -1 (download the final model)
561- checkpoint_type (CheckpointType, optional): Specifies which checkpoint to download.
561+ checkpoint_type (CheckpointType | str , optional): Specifies which checkpoint to download.
562562 Defaults to CheckpointType.DEFAULT.
563563
564564 Returns:
@@ -582,6 +582,16 @@ def download(
582582
583583 ft_job = self .retrieve (id )
584584
585+ # convert str to DownloadCheckpointType
586+ if isinstance (checkpoint_type , str ):
587+ try :
588+ checkpoint_type = DownloadCheckpointType (checkpoint_type .lower ())
589+ except ValueError :
590+ enum_strs = ", " .join (e .value for e in DownloadCheckpointType )
591+ raise ValueError (
592+ f"Invalid checkpoint type: { checkpoint_type } . Choose one of {{{ enum_strs } }}."
593+ )
594+
585595 if isinstance (ft_job .training_type , FullTrainingType ):
586596 if checkpoint_type != DownloadCheckpointType .DEFAULT :
587597 raise ValueError (
@@ -592,10 +602,11 @@ def download(
592602 if checkpoint_type == DownloadCheckpointType .DEFAULT :
593603 checkpoint_type = DownloadCheckpointType .MERGED
594604
595- if checkpoint_type == DownloadCheckpointType .MERGED :
596- url += f"&checkpoint={ DownloadCheckpointType .MERGED .value } "
597- elif checkpoint_type == DownloadCheckpointType .ADAPTER :
598- url += f"&checkpoint={ DownloadCheckpointType .ADAPTER .value } "
605+ if checkpoint_type in {
606+ DownloadCheckpointType .MERGED ,
607+ DownloadCheckpointType .ADAPTER ,
608+ }:
609+ url += f"&checkpoint={ checkpoint_type .value } "
599610 else :
600611 raise ValueError (
601612 f"Invalid checkpoint type for LoRATrainingType: { checkpoint_type } "
0 commit comments