diff --git a/pyproject.toml b/pyproject.toml index 94226c18..2b5ccafd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api" [tool.poetry] name = "together" -version = "1.5.7" +version = "1.5.8" authors = ["Together AI "] description = "Python client for Together's Cloud Platform!" readme = "README.md" diff --git a/src/together/resources/finetune.py b/src/together/resources/finetune.py index 275d6839..8d0bf97e 100644 --- a/src/together/resources/finetune.py +++ b/src/together/resources/finetune.py @@ -2,7 +2,7 @@ import re from pathlib import Path -from typing import List, Literal +from typing import Dict, List, Literal from rich import print as rprint @@ -30,16 +30,8 @@ TrainingMethodSFT, TrainingType, ) -from together.types.finetune import ( - DownloadCheckpointType, - FinetuneEvent, - FinetuneEventType, -) -from together.utils import ( - get_event_step, - log_warn_once, - normalize_key, -) +from together.types.finetune import DownloadCheckpointType +from together.utils import log_warn_once, normalize_key _FT_JOB_WITH_STEP_REGEX = r"^ft-[\dabcdef-]+:\d+$" @@ -222,68 +214,38 @@ def create_finetune_request( return finetune_request -def _process_checkpoints_from_events( - events: List[FinetuneEvent], id: str +def _parse_raw_checkpoints( + checkpoints: List[Dict[str, str]], id: str ) -> List[FinetuneCheckpoint]: """ - Helper function to process events and create checkpoint list. + Helper function to process raw checkpoints and create checkpoint list. Args: - events (List[FinetuneEvent]): List of fine-tune events to process + checkpoints (List[Dict[str, str]]): List of raw checkpoints metadata id (str): Fine-tune job ID Returns: List[FinetuneCheckpoint]: List of available checkpoints """ - checkpoints: List[FinetuneCheckpoint] = [] - - for event in events: - event_type = event.type - - if event_type == FinetuneEventType.CHECKPOINT_SAVE: - step = get_event_step(event) - checkpoint_name = f"{id}:{step}" if step is not None else id - - checkpoints.append( - FinetuneCheckpoint( - type=( - f"Intermediate (step {step})" - if step is not None - else "Intermediate" - ), - timestamp=event.created_at, - name=checkpoint_name, - ) - ) - elif event_type == FinetuneEventType.JOB_COMPLETE: - if hasattr(event, "model_path"): - checkpoints.append( - FinetuneCheckpoint( - type=( - "Final Merged" - if hasattr(event, "adapter_path") - else "Final" - ), - timestamp=event.created_at, - name=id, - ) - ) - if hasattr(event, "adapter_path"): - checkpoints.append( - FinetuneCheckpoint( - type=( - "Final Adapter" if hasattr(event, "model_path") else "Final" - ), - timestamp=event.created_at, - name=id, - ) - ) + parsed_checkpoints = [] + for checkpoint in checkpoints: + step = checkpoint["step"] + checkpoint_type = checkpoint["checkpoint_type"] + checkpoint_name = ( + f"{id}:{step}" if "intermediate" in checkpoint_type.lower() else id + ) - # Sort by timestamp (newest first) - checkpoints.sort(key=lambda x: x.timestamp, reverse=True) + parsed_checkpoints.append( + FinetuneCheckpoint( + type=checkpoint_type, + timestamp=checkpoint["created_at"], + name=checkpoint_name, + ) + ) - return checkpoints + parsed_checkpoints.sort(key=lambda x: x.timestamp, reverse=True) + return parsed_checkpoints class FineTuning: @@ -561,8 +523,21 @@ def list_checkpoints(self, id: str) -> List[FinetuneCheckpoint]: Returns: List[FinetuneCheckpoint]: List of available checkpoints """ - events = self.list_events(id).data or [] - return _process_checkpoints_from_events(events, id) + requestor = api_requestor.APIRequestor( + client=self._client, + ) + + response, _, _ = requestor.request( + options=TogetherRequest( + method="GET", + url=f"fine-tunes/{id}/checkpoints", + ), + stream=False, + ) + assert isinstance(response, TogetherResponse) + + raw_checkpoints = response.data["data"] + return _parse_raw_checkpoints(raw_checkpoints, id) def download( self, @@ -936,11 +911,9 @@ async def list_events(self, id: str) -> FinetuneListEvents: ), stream=False, ) + assert isinstance(events_response, TogetherResponse) - # FIXME: API returns "data" field with no object type (should be "list") - events_list = FinetuneListEvents(object="list", **events_response.data) - - return events_list + return FinetuneListEvents(**events_response.data) async def list_checkpoints(self, id: str) -> List[FinetuneCheckpoint]: """ @@ -950,11 +923,23 @@ async def list_checkpoints(self, id: str) -> List[FinetuneCheckpoint]: id (str): Unique identifier of the fine-tune job to list checkpoints for Returns: - List[FinetuneCheckpoint]: Object containing list of available checkpoints + List[FinetuneCheckpoint]: List of available checkpoints """ - events_list = await self.list_events(id) - events = events_list.data or [] - return _process_checkpoints_from_events(events, id) + requestor = api_requestor.APIRequestor( + client=self._client, + ) + + response, _, _ = await requestor.arequest( + options=TogetherRequest( + method="GET", + url=f"fine-tunes/{id}/checkpoints", + ), + stream=False, + ) + assert isinstance(response, TogetherResponse) + + raw_checkpoints = response.data["data"] + return _parse_raw_checkpoints(raw_checkpoints, id) async def download( self, id: str, *, output: str | None = None, checkpoint_step: int = -1