Skip to content
27 changes: 20 additions & 7 deletions src/together/resources/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import re
from pathlib import Path
from typing import Dict, List, Literal
from typing import List, Literal, Union, Literal

from rich import print as rprint

Expand Down Expand Up @@ -545,7 +545,9 @@ def download(
*,
output: Path | str | None = None,
checkpoint_step: int | None = None,
checkpoint_type: DownloadCheckpointType = DownloadCheckpointType.DEFAULT,
checkpoint_type: Union[
DownloadCheckpointType, str
] = DownloadCheckpointType.DEFAULT,
) -> FinetuneDownloadResult:
"""
Downloads compressed fine-tuned model or checkpoint to local disk.
Expand All @@ -558,7 +560,7 @@ def download(
Defaults to None.
checkpoint_step (int, optional): Specifies step number for checkpoint to download.
Defaults to -1 (download the final model)
checkpoint_type (CheckpointType, optional): Specifies which checkpoint to download.
checkpoint_type (Union[CheckpointType, str], optional): Specifies which checkpoint to download.
Defaults to CheckpointType.DEFAULT.

Returns:
Expand All @@ -582,6 +584,16 @@ def download(

ft_job = self.retrieve(id)

# convert str to DownloadCheckpointType
if isinstance(checkpoint_type, str):
try:
checkpoint_type = DownloadCheckpointType(checkpoint_type.lower())
except ValueError:
enum_strs = ", ".join(e.value for e in DownloadCheckpointType)
raise ValueError(
f"Invalid checkpoint type: {checkpoint_type}. Choose one of {{{enum_strs}}}."
)

if isinstance(ft_job.training_type, FullTrainingType):
if checkpoint_type != DownloadCheckpointType.DEFAULT:
raise ValueError(
Expand All @@ -592,10 +604,11 @@ def download(
if checkpoint_type == DownloadCheckpointType.DEFAULT:
checkpoint_type = DownloadCheckpointType.MERGED

if checkpoint_type == DownloadCheckpointType.MERGED:
url += f"&checkpoint={DownloadCheckpointType.MERGED.value}"
elif checkpoint_type == DownloadCheckpointType.ADAPTER:
url += f"&checkpoint={DownloadCheckpointType.ADAPTER.value}"
if checkpoint_type in {
DownloadCheckpointType.MERGED,
DownloadCheckpointType.ADAPTER,
}:
url += f"&checkpoint={checkpoint_type.value}"
else:
raise ValueError(
f"Invalid checkpoint type for LoRATrainingType: {checkpoint_type}"
Expand Down
Loading