Skip to content

Commit fde110f

Browse files
committed
Use /checkpoints instead of events parsing
1 parent c09481b commit fde110f

File tree

2 files changed

+99
-4
lines changed

2 files changed

+99
-4
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"
1212

1313
[tool.poetry]
1414
name = "together"
15-
version = "1.5.7"
15+
version = "1.5.8"
1616
authors = ["Together AI <[email protected]>"]
1717
description = "Python client for Together's Cloud Platform!"
1818
readme = "README.md"

src/together/resources/finetune.py

Lines changed: 98 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import re
44
from pathlib import Path
5-
from typing import List, Literal
5+
from typing import Dict, List, Literal
66

77
from rich import print as rprint
88

@@ -222,6 +222,49 @@ def create_finetune_request(
222222
return finetune_request
223223

224224

225+
def _parse_raw_checkpoints(
226+
checkpoints: List[Dict[str, str]], id: str
227+
) -> List[FinetuneCheckpoint]:
228+
"""
229+
Helper function to process raw checkpoints and create checkpoint list.
230+
231+
Args:
232+
checkpoints (List[Dict[str, str]]): List of raw checkpoints metadata
233+
id (str): Fine-tune job ID
234+
235+
Returns:
236+
List[FinetuneCheckpoint]: List of available checkpoints
237+
"""
238+
had_adapters = any(ckpt["path"].endswith("_adapter") for ckpt in checkpoints)
239+
240+
parsed_checkpoints = []
241+
for checkpoint in checkpoints:
242+
checkpoint_path = checkpoint["path"]
243+
step = checkpoint["step"]
244+
245+
is_final = int(step) == 0
246+
checkpoint_name = f"{id}:step" if step else id
247+
248+
if is_final:
249+
if checkpoint_path.endswith("_adapter"):
250+
checkpoint_type = "Final Adapter"
251+
else:
252+
checkpoint_type = "Final Merged" if had_adapters else "Final"
253+
else:
254+
checkpoint_type = "Intermediate"
255+
256+
parsed_checkpoints.append(
257+
FinetuneCheckpoint(
258+
type=checkpoint_type,
259+
timestamp=checkpoint["created_at"],
260+
name=checkpoint_name,
261+
)
262+
)
263+
264+
parsed_checkpoints.sort(key=lambda x: x.timestamp, reverse=True)
265+
return parsed_checkpoints
266+
267+
225268
def _process_checkpoints_from_events(
226269
events: List[FinetuneEvent], id: str
227270
) -> List[FinetuneCheckpoint]:
@@ -551,7 +594,7 @@ def list_events(self, id: str) -> FinetuneListEvents:
551594

552595
return FinetuneListEvents(**response.data)
553596

554-
def list_checkpoints(self, id: str) -> List[FinetuneCheckpoint]:
597+
def list_checkpoints_from_events(self, id: str) -> List[FinetuneCheckpoint]:
555598
"""
556599
List available checkpoints for a fine-tuning job
557600
@@ -564,6 +607,32 @@ def list_checkpoints(self, id: str) -> List[FinetuneCheckpoint]:
564607
events = self.list_events(id).data or []
565608
return _process_checkpoints_from_events(events, id)
566609

610+
def list_checkpoints(self, id: str) -> List[FinetuneCheckpoint]:
611+
"""
612+
List available checkpoints for a fine-tuning job
613+
614+
Args:
615+
id (str): Unique identifier of the fine-tune job to list checkpoints for
616+
617+
Returns:
618+
List[FinetuneCheckpoint]: List of available checkpoints
619+
"""
620+
requestor = api_requestor.APIRequestor(
621+
client=self._client,
622+
)
623+
624+
response, _, _ = requestor.request(
625+
options=TogetherRequest(
626+
method="GET",
627+
url=f"fine-tunes/{id}/checkpoints",
628+
),
629+
stream=False,
630+
)
631+
assert isinstance(response, TogetherResponse)
632+
633+
raw_checkpoints = response.data["data"]
634+
return _parse_raw_checkpoints(raw_checkpoints, id)
635+
567636
def download(
568637
self,
569638
id: str,
@@ -942,7 +1011,7 @@ async def list_events(self, id: str) -> FinetuneListEvents:
9421011

9431012
return events_list
9441013

945-
async def list_checkpoints(self, id: str) -> List[FinetuneCheckpoint]:
1014+
async def list_checkpoints_from_events(self, id: str) -> List[FinetuneCheckpoint]:
9461015
"""
9471016
List available checkpoints for a fine-tuning job
9481017
@@ -956,6 +1025,32 @@ async def list_checkpoints(self, id: str) -> List[FinetuneCheckpoint]:
9561025
events = events_list.data or []
9571026
return _process_checkpoints_from_events(events, id)
9581027

1028+
async def list_checkpoints(self, id: str) -> List[FinetuneCheckpoint]:
1029+
"""
1030+
List available checkpoints for a fine-tuning job
1031+
1032+
Args:
1033+
id (str): Unique identifier of the fine-tune job to list checkpoints for
1034+
1035+
Returns:
1036+
List[FinetuneCheckpoint]: List of available checkpoints
1037+
"""
1038+
requestor = api_requestor.APIRequestor(
1039+
client=self._client,
1040+
)
1041+
1042+
response, _, _ = await requestor.arequest(
1043+
options=TogetherRequest(
1044+
method="GET",
1045+
url=f"fine-tunes/{id}/checkpoints",
1046+
),
1047+
stream=False,
1048+
)
1049+
assert isinstance(response, TogetherResponse)
1050+
1051+
raw_checkpoints = response.data["data"]
1052+
return _parse_raw_checkpoints(raw_checkpoints, id)
1053+
9591054
async def download(
9601055
self, id: str, *, output: str | None = None, checkpoint_step: int = -1
9611056
) -> str:

0 commit comments

Comments
 (0)