Skip to content

Commit 4eef896

Browse files
nbroad1881mryab
andauthored
Accept strings for checkpoint type on download (#308)
--------- Co-authored-by: Max Ryabinin <[email protected]>
1 parent fbc32ea commit 4eef896

File tree

2 files changed

+19
-8
lines changed

2 files changed

+19
-8
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"
1212

1313
[tool.poetry]
1414
name = "together"
15-
version = "1.5.8"
15+
version = "1.5.9"
1616
authors = ["Together AI <[email protected]>"]
1717
description = "Python client for Together's Cloud Platform!"
1818
readme = "README.md"

src/together/resources/finetune.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import re
44
from pathlib import Path
5-
from typing import Dict, List, Literal
5+
from typing import List, Dict, Literal
66

77
from rich import print as rprint
88

@@ -545,7 +545,7 @@ def download(
545545
*,
546546
output: Path | str | None = None,
547547
checkpoint_step: int | None = None,
548-
checkpoint_type: DownloadCheckpointType = DownloadCheckpointType.DEFAULT,
548+
checkpoint_type: DownloadCheckpointType | str = DownloadCheckpointType.DEFAULT,
549549
) -> FinetuneDownloadResult:
550550
"""
551551
Downloads compressed fine-tuned model or checkpoint to local disk.
@@ -558,7 +558,7 @@ def download(
558558
Defaults to None.
559559
checkpoint_step (int, optional): Specifies step number for checkpoint to download.
560560
Defaults to -1 (download the final model)
561-
checkpoint_type (CheckpointType, optional): Specifies which checkpoint to download.
561+
checkpoint_type (CheckpointType | str, optional): Specifies which checkpoint to download.
562562
Defaults to CheckpointType.DEFAULT.
563563
564564
Returns:
@@ -582,6 +582,16 @@ def download(
582582

583583
ft_job = self.retrieve(id)
584584

585+
# convert str to DownloadCheckpointType
586+
if isinstance(checkpoint_type, str):
587+
try:
588+
checkpoint_type = DownloadCheckpointType(checkpoint_type.lower())
589+
except ValueError:
590+
enum_strs = ", ".join(e.value for e in DownloadCheckpointType)
591+
raise ValueError(
592+
f"Invalid checkpoint type: {checkpoint_type}. Choose one of {{{enum_strs}}}."
593+
)
594+
585595
if isinstance(ft_job.training_type, FullTrainingType):
586596
if checkpoint_type != DownloadCheckpointType.DEFAULT:
587597
raise ValueError(
@@ -592,10 +602,11 @@ def download(
592602
if checkpoint_type == DownloadCheckpointType.DEFAULT:
593603
checkpoint_type = DownloadCheckpointType.MERGED
594604

595-
if checkpoint_type == DownloadCheckpointType.MERGED:
596-
url += f"&checkpoint={DownloadCheckpointType.MERGED.value}"
597-
elif checkpoint_type == DownloadCheckpointType.ADAPTER:
598-
url += f"&checkpoint={DownloadCheckpointType.ADAPTER.value}"
605+
if checkpoint_type in {
606+
DownloadCheckpointType.MERGED,
607+
DownloadCheckpointType.ADAPTER,
608+
}:
609+
url += f"&checkpoint={checkpoint_type.value}"
599610
else:
600611
raise ValueError(
601612
f"Invalid checkpoint type for LoRATrainingType: {checkpoint_type}"

0 commit comments

Comments
 (0)