Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"

[tool.poetry]
name = "together"
version = "1.5.7"
version = "1.5.8"
authors = ["Together AI <[email protected]>"]
description = "Python client for Together's Cloud Platform!"
readme = "README.md"
Expand Down
127 changes: 56 additions & 71 deletions src/together/resources/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import re
from pathlib import Path
from typing import List, Literal
from typing import Dict, List, Literal

from rich import print as rprint

Expand Down Expand Up @@ -30,16 +30,8 @@
TrainingMethodSFT,
TrainingType,
)
from together.types.finetune import (
DownloadCheckpointType,
FinetuneEvent,
FinetuneEventType,
)
from together.utils import (
get_event_step,
log_warn_once,
normalize_key,
)
from together.types.finetune import DownloadCheckpointType
from together.utils import log_warn_once, normalize_key


_FT_JOB_WITH_STEP_REGEX = r"^ft-[\dabcdef-]+:\d+$"
Expand Down Expand Up @@ -222,68 +214,38 @@ def create_finetune_request(
return finetune_request


def _process_checkpoints_from_events(
events: List[FinetuneEvent], id: str
def _parse_raw_checkpoints(
checkpoints: List[Dict[str, str]], id: str
) -> List[FinetuneCheckpoint]:
"""
Helper function to process events and create checkpoint list.
Helper function to process raw checkpoints and create checkpoint list.

Args:
events (List[FinetuneEvent]): List of fine-tune events to process
checkpoints (List[Dict[str, str]]): List of raw checkpoints metadata
id (str): Fine-tune job ID

Returns:
List[FinetuneCheckpoint]: List of available checkpoints
"""
checkpoints: List[FinetuneCheckpoint] = []

for event in events:
event_type = event.type

if event_type == FinetuneEventType.CHECKPOINT_SAVE:
step = get_event_step(event)
checkpoint_name = f"{id}:{step}" if step is not None else id

checkpoints.append(
FinetuneCheckpoint(
type=(
f"Intermediate (step {step})"
if step is not None
else "Intermediate"
),
timestamp=event.created_at,
name=checkpoint_name,
)
)
elif event_type == FinetuneEventType.JOB_COMPLETE:
if hasattr(event, "model_path"):
checkpoints.append(
FinetuneCheckpoint(
type=(
"Final Merged"
if hasattr(event, "adapter_path")
else "Final"
),
timestamp=event.created_at,
name=id,
)
)

if hasattr(event, "adapter_path"):
checkpoints.append(
FinetuneCheckpoint(
type=(
"Final Adapter" if hasattr(event, "model_path") else "Final"
),
timestamp=event.created_at,
name=id,
)
)
parsed_checkpoints = []
for checkpoint in checkpoints:
step = checkpoint["step"]
checkpoint_type = checkpoint["checkpoint_type"]
checkpoint_name = (
f"{id}:{step}" if "intermediate" in checkpoint_type.lower() else id
)

# Sort by timestamp (newest first)
checkpoints.sort(key=lambda x: x.timestamp, reverse=True)
parsed_checkpoints.append(
FinetuneCheckpoint(
type=checkpoint_type,
timestamp=checkpoint["created_at"],
name=checkpoint_name,
)
)

return checkpoints
parsed_checkpoints.sort(key=lambda x: x.timestamp, reverse=True)
return parsed_checkpoints


class FineTuning:
Expand Down Expand Up @@ -561,8 +523,21 @@ def list_checkpoints(self, id: str) -> List[FinetuneCheckpoint]:
Returns:
List[FinetuneCheckpoint]: List of available checkpoints
"""
events = self.list_events(id).data or []
return _process_checkpoints_from_events(events, id)
requestor = api_requestor.APIRequestor(
client=self._client,
)

response, _, _ = requestor.request(
options=TogetherRequest(
method="GET",
url=f"fine-tunes/{id}/checkpoints",
),
stream=False,
)
assert isinstance(response, TogetherResponse)

raw_checkpoints = response.data["data"]
return _parse_raw_checkpoints(raw_checkpoints, id)

def download(
self,
Expand Down Expand Up @@ -936,11 +911,9 @@ async def list_events(self, id: str) -> FinetuneListEvents:
),
stream=False,
)
assert isinstance(events_response, TogetherResponse)

# FIXME: API returns "data" field with no object type (should be "list")
events_list = FinetuneListEvents(object="list", **events_response.data)

return events_list
return FinetuneListEvents(**events_response.data)

async def list_checkpoints(self, id: str) -> List[FinetuneCheckpoint]:
"""
Expand All @@ -950,11 +923,23 @@ async def list_checkpoints(self, id: str) -> List[FinetuneCheckpoint]:
id (str): Unique identifier of the fine-tune job to list checkpoints for

Returns:
List[FinetuneCheckpoint]: Object containing list of available checkpoints
List[FinetuneCheckpoint]: List of available checkpoints
"""
events_list = await self.list_events(id)
events = events_list.data or []
return _process_checkpoints_from_events(events, id)
requestor = api_requestor.APIRequestor(
client=self._client,
)

response, _, _ = await requestor.arequest(
options=TogetherRequest(
method="GET",
url=f"fine-tunes/{id}/checkpoints",
),
stream=False,
)
assert isinstance(response, TogetherResponse)

raw_checkpoints = response.data["data"]
return _parse_raw_checkpoints(raw_checkpoints, id)

async def download(
self, id: str, *, output: str | None = None, checkpoint_step: int = -1
Expand Down