diff --git a/README.md b/README.md
index 408da2fd..a54430b0 100644
--- a/README.md
+++ b/README.md
@@ -67,7 +67,7 @@ print(response.choices[0].message.content)
 response = client.chat.completions.create(
     model="meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
     messages=[{
-        "role": "user",
+        "role": "user",
         "content": [
             {
                 "type": "text",
@@ -91,7 +91,7 @@ response = client.chat.completions.create(
         "role": "user",
         "content": [
             {
-                "type": "text",
+                "type": "text",
                 "text": "Compare these two images."
             },
             {
diff --git a/pyproject.toml b/pyproject.toml
index a77fb7f7..e0877f97 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"
 
 [tool.poetry]
 name = "together"
-version = "1.4.2"
+version = "1.4.3"
 authors = [
     "Together AI "
 ]
diff --git a/src/together/cli/api/finetune.py b/src/together/cli/api/finetune.py
index 7bc02744..ad81339d 100644
--- a/src/together/cli/api/finetune.py
+++ b/src/together/cli/api/finetune.py
@@ -1,9 +1,10 @@
 from __future__ import annotations
 
 import json
-from datetime import datetime
+from datetime import datetime, timezone
 from textwrap import wrap
 from typing import Any, Literal
+import re
 
 import click
 from click.core import ParameterSource  # type: ignore[attr-defined]
@@ -17,8 +18,13 @@
     log_warn,
     log_warn_once,
     parse_timestamp,
+    format_timestamp,
+)
+from together.types.finetune import (
+    DownloadCheckpointType,
+    FinetuneTrainingLimits,
+    FinetuneEventType,
 )
-from together.types.finetune import DownloadCheckpointType, FinetuneTrainingLimits
 
 
 _CONFIRMATION_MESSAGE = (
@@ -126,6 +132,14 @@ def fine_tuning(ctx: click.Context) -> None:
     help="Whether to mask the user messages in conversational data or prompts in instruction data. "
     "`auto` will automatically determine whether to mask the inputs based on the data format.",
 )
+@click.option(
+    "--from-checkpoint",
+    type=str,
+    default=None,
+    help="The checkpoint identifier to continue training from a previous fine-tuning job. "
+    "The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}. "
+    "The step value is optional; without it, the final checkpoint will be used.",
+)
 def create(
     ctx: click.Context,
     training_file: str,
@@ -152,6 +166,7 @@ def create(
     wandb_name: str,
     confirm: bool,
     train_on_inputs: bool | Literal["auto"],
+    from_checkpoint: str,
 ) -> None:
     """Start fine-tuning"""
     client: Together = ctx.obj
@@ -180,6 +195,7 @@ def create(
         wandb_project_name=wandb_project_name,
         wandb_name=wandb_name,
         train_on_inputs=train_on_inputs,
+        from_checkpoint=from_checkpoint,
     )
 
     model_limits: FinetuneTrainingLimits = client.fine_tuning.get_model_limits(
@@ -261,7 +277,9 @@ def list(ctx: click.Context) -> None:
 
     response.data = response.data or []
 
-    response.data.sort(key=lambda x: parse_timestamp(x.created_at or ""))
+    # Use a default datetime for None values so the sort key is always comparable
+    epoch_start = datetime.fromtimestamp(0, tz=timezone.utc)
+    response.data.sort(key=lambda x: parse_timestamp(x.created_at or "") or epoch_start)
 
     display_list = []
     for i in response.data:
@@ -344,6 +362,34 @@ def list_events(ctx: click.Context, fine_tune_id: str) -> None:
     click.echo(table)
 
 
+@fine_tuning.command()
+@click.pass_context
+@click.argument("fine_tune_id", type=str, required=True)
+def list_checkpoints(ctx: click.Context, fine_tune_id: str) -> None:
+    """List available checkpoints for a fine-tuning job"""
+    client: Together = ctx.obj
+
+    checkpoints = client.fine_tuning.list_checkpoints(fine_tune_id)
+
+    display_list = []
+    for checkpoint in checkpoints:
+        display_list.append(
+            {
+                "Type": checkpoint.type,
+                "Timestamp": format_timestamp(checkpoint.timestamp),
+                "Name": checkpoint.name,
+            }
+        )
+
+    if display_list:
+        click.echo(f"Job {fine_tune_id} contains the following checkpoints:")
+        table = tabulate(display_list, headers="keys", tablefmt="grid")
+        click.echo(table)
+        click.echo("\nTo download a checkpoint, use `together fine-tuning download`")
+    else:
+        click.echo(f"No checkpoints found for job {fine_tune_id}")
+
+
 @fine_tuning.command()
 @click.pass_context
 @click.argument("fine_tune_id", type=str, required=True)
@@ -358,7 +404,7 @@ def list_events(ctx: click.Context, fine_tune_id: str) -> None:
     "--checkpoint-step",
     type=int,
     required=False,
-    default=-1,
+    default=None,
     help="Download fine-tuning checkpoint. Defaults to latest.",
 )
 @click.option(
@@ -372,7 +418,7 @@ def download(
     ctx: click.Context,
     fine_tune_id: str,
     output_dir: str,
-    checkpoint_step: int,
+    checkpoint_step: int | None,
     checkpoint_type: DownloadCheckpointType,
 ) -> None:
     """Download fine-tuning checkpoint"""
diff --git a/src/together/legacy/finetune.py b/src/together/legacy/finetune.py
index fe53be0e..a8a973bb 100644
--- a/src/together/legacy/finetune.py
+++ b/src/together/legacy/finetune.py
@@ -161,7 +161,7 @@ def download(
         cls,
         fine_tune_id: str,
         output: str | None = None,
-        step: int = -1,
+        step: int | None = None,
     ) -> Dict[str, Any]:
         """Legacy finetuning download function."""
 
diff --git a/src/together/resources/finetune.py b/src/together/resources/finetune.py
index b58cdae2..11d445db 100644
--- a/src/together/resources/finetune.py
+++ b/src/together/resources/finetune.py
@@ -1,7 +1,8 @@
 from __future__ import annotations
 
+import re
 from pathlib import Path
-from typing import Literal
+from typing import Literal, List
 
 from rich import print as rprint
 
@@ -22,9 +23,20 @@
     TrainingType,
     FinetuneLRScheduler,
     FinetuneLinearLRSchedulerArgs,
+    FinetuneCheckpoint,
 )
-from together.types.finetune import DownloadCheckpointType
-from together.utils import log_warn_once, normalize_key
+from together.types.finetune import (
+    DownloadCheckpointType,
+    FinetuneEventType,
+    FinetuneEvent,
+)
+from together.utils import (
+    log_warn_once,
+    normalize_key,
+    get_event_step,
+)
+
+_FT_JOB_WITH_STEP_REGEX = r"^ft-[\dabcdef-]+:\d+$"
 
 
 def createFinetuneRequest(
@@ -52,6 +64,7 @@ def createFinetuneRequest(
     wandb_project_name: str | None = None,
     wandb_name: str | None = None,
     train_on_inputs: bool | Literal["auto"] = "auto",
+    from_checkpoint: str | None = None,
 ) -> FinetuneRequest:
     if batch_size == "max":
         log_warn_once(
@@ -125,11 +138,76 @@ def createFinetuneRequest(
         wandb_project_name=wandb_project_name,
         wandb_name=wandb_name,
         train_on_inputs=train_on_inputs,
+        from_checkpoint=from_checkpoint,
     )
 
     return finetune_request
 
 
+def _process_checkpoints_from_events(
+    events: List[FinetuneEvent], id: str
+) -> List[FinetuneCheckpoint]:
+    """
+    Helper function to process fine-tune events and build a list of available checkpoints.
+
+    Args:
+        events (List[FinetuneEvent]): List of fine-tune events to process
+        id (str): Fine-tune job ID
+
+    Returns:
+        List[FinetuneCheckpoint]: List of available checkpoints
+    """
+    checkpoints: List[FinetuneCheckpoint] = []
+
+    for event in events:
+        event_type = event.type
+
+        if event_type == FinetuneEventType.CHECKPOINT_SAVE:
+            step = get_event_step(event)
+            checkpoint_name = f"{id}:{step}" if step is not None else id
+
+            checkpoints.append(
+                FinetuneCheckpoint(
+                    type=(
+                        f"Intermediate (step {step})"
+                        if step is not None
+                        else "Intermediate"
+                    ),
+                    timestamp=event.created_at,
+                    name=checkpoint_name,
+                )
+            )
+        elif event_type == FinetuneEventType.JOB_COMPLETE:
+            if hasattr(event, "model_path"):
+                checkpoints.append(
+                    FinetuneCheckpoint(
+                        type=(
+                            "Final Merged"
+                            if hasattr(event, "adapter_path")
+                            else "Final"
+                        ),
+                        timestamp=event.created_at,
+                        name=id,
+                    )
+                )
+
+            if hasattr(event, "adapter_path"):
+                checkpoints.append(
+                    FinetuneCheckpoint(
+                        type=(
+                            "Final Adapter" if hasattr(event, "model_path") else "Final"
+                        ),
+                        timestamp=event.created_at,
+                        name=id,
+                    )
+                )
+
+    # Sort by timestamp (newest first)
+    checkpoints.sort(key=lambda x: x.timestamp, reverse=True)
+
+    return checkpoints
+
+
 class FineTuning:
     def __init__(self, client: TogetherClient) -> None:
         self._client = client
@@ -162,6 +240,7 @@ def create(
         verbose: bool = False,
         model_limits: FinetuneTrainingLimits | None = None,
         train_on_inputs: bool | Literal["auto"] = "auto",
+        from_checkpoint: str | None = None,
     ) -> FinetuneResponse:
         """
         Method to initiate a fine-tuning job
@@ -207,6 +286,9 @@ def create(
                 For datasets with the "messages" field (conversational format)
                 or "prompt" and "completion" fields (Instruction format), inputs will be masked.
                 Defaults to "auto".
+            from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
+                The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
+                The step value is optional; without it, the final checkpoint will be used.
 
         Returns:
             FinetuneResponse: Object containing information about fine-tuning job.
@@ -244,6 +326,7 @@ def create(
             wandb_project_name=wandb_project_name,
             wandb_name=wandb_name,
             train_on_inputs=train_on_inputs,
+            from_checkpoint=from_checkpoint,
         )
 
         if verbose:
@@ -366,17 +449,29 @@ def list_events(self, id: str) -> FinetuneListEvents:
             ),
             stream=False,
         )
 
-        assert isinstance(response, TogetherResponse)
 
         return FinetuneListEvents(**response.data)
 
+    def list_checkpoints(self, id: str) -> List[FinetuneCheckpoint]:
+        """
+        List available checkpoints for a fine-tuning job
+
+        Args:
+            id (str): Unique identifier of the fine-tune job to list checkpoints for
+
+        Returns:
+            List[FinetuneCheckpoint]: List of available checkpoints
+        """
+        events = self.list_events(id).data or []
+        return _process_checkpoints_from_events(events, id)
+
     def download(
         self,
         id: str,
         *,
         output: Path | str | None = None,
-        checkpoint_step: int = -1,
+        checkpoint_step: int | None = None,
         checkpoint_type: DownloadCheckpointType = DownloadCheckpointType.DEFAULT,
     ) -> FinetuneDownloadResult:
         """
@@ -397,9 +492,19 @@ def download(
             FinetuneDownloadResult: Object containing downloaded model metadata
         """
 
+        if re.match(_FT_JOB_WITH_STEP_REGEX, id) is not None:
+            if checkpoint_step is None:
+                checkpoint_step = int(id.split(":")[1])
+                id = id.split(":")[0]
+            else:
+                raise ValueError(
+                    f"Fine-tuning job ID {id} contains a colon to specify the step to download, but `checkpoint_step` "
+                    "was also set. Remove one of the step specifiers to proceed."
+                )
+
         url = f"finetune/download?ft_id={id}"
 
-        if checkpoint_step > 0:
+        if checkpoint_step is not None:
             url += f"&checkpoint_step={checkpoint_step}"
 
         ft_job = self.retrieve(id)
@@ -503,6 +608,7 @@ async def create(
         verbose: bool = False,
         model_limits: FinetuneTrainingLimits | None = None,
         train_on_inputs: bool | Literal["auto"] = "auto",
+        from_checkpoint: str | None = None,
     ) -> FinetuneResponse:
         """
         Async method to initiate a fine-tuning job
@@ -548,6 +654,9 @@ async def create(
                 For datasets with the "messages" field (conversational format)
                 or "prompt" and "completion" fields (Instruction format), inputs will be masked.
                 Defaults to "auto".
+            from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
+                The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
+                The step value is optional; without it, the final checkpoint will be used.
 
         Returns:
             FinetuneResponse: Object containing information about fine-tuning job.
@@ -585,6 +694,7 @@ async def create(
             wandb_project_name=wandb_project_name,
             wandb_name=wandb_name,
             train_on_inputs=train_on_inputs,
+            from_checkpoint=from_checkpoint,
         )
 
         if verbose:
@@ -687,30 +797,45 @@ async def cancel(self, id: str) -> FinetuneResponse:
 
     async def list_events(self, id: str) -> FinetuneListEvents:
         """
-        Async method to lists events of a fine-tune job
+        List fine-tuning events
 
         Args:
-            id (str): Fine-tune ID to list events for. A string that starts with `ft-`.
+            id (str): Unique identifier of the fine-tune job to list events for
 
         Returns:
-            FinetuneListEvents: Object containing list of fine-tune events
+            FinetuneListEvents: Object containing list of fine-tune job events
         """
 
         requestor = api_requestor.APIRequestor(
             client=self._client,
         )
 
-        response, _, _ = await requestor.arequest(
+        events_response, _, _ = await requestor.arequest(
             options=TogetherRequest(
                 method="GET",
-                url=f"fine-tunes/{id}/events",
+                url=f"fine-tunes/{normalize_key(id)}/events",
             ),
             stream=False,
         )
 
-        assert isinstance(response, TogetherResponse)
+        # FIXME: API returns "data" field with no object type (should be "list")
+        events_list = FinetuneListEvents(object="list", **events_response.data)
 
-        return FinetuneListEvents(**response.data)
+        return events_list
+
+    async def list_checkpoints(self, id: str) -> List[FinetuneCheckpoint]:
+        """
+        List available checkpoints for a fine-tuning job
+
+        Args:
+            id (str): Unique identifier of the fine-tune job to list checkpoints for
+
+        Returns:
+            List[FinetuneCheckpoint]: List of available checkpoints
+        """
+        events_list = await self.list_events(id)
+        events = events_list.data or []
+        return _process_checkpoints_from_events(events, id)
 
     async def download(
         self, id: str, *, output: str | None = None, checkpoint_step: int = -1
diff --git a/src/together/types/__init__.py b/src/together/types/__init__.py
index c3100cd1..1a7419a5 100644
--- a/src/together/types/__init__.py
+++ b/src/together/types/__init__.py
@@ -31,6 +31,7 @@
     FileType,
 )
 from together.types.finetune import (
+    FinetuneCheckpoint,
     FinetuneDownloadResult,
     FinetuneLinearLRSchedulerArgs,
     FinetuneList,
@@ -59,6 +60,7 @@
     "ChatCompletionResponse",
     "EmbeddingRequest",
     "EmbeddingResponse",
+    "FinetuneCheckpoint",
    "FinetuneRequest",
    "FinetuneResponse",
    "FinetuneList",
diff --git a/src/together/types/endpoints.py b/src/together/types/endpoints.py
index 3f52831a..0db1de21 100644
--- a/src/together/types/endpoints.py
+++ b/src/together/types/endpoints.py
@@ -86,9 +86,9 @@ class BaseEndpoint(TogetherJSONModel):
     model: str = Field(description="The model deployed on this endpoint")
     type: str = Field(description="The type of endpoint")
     owner: str = Field(description="The owner of this endpoint")
-    state: Literal["PENDING", "STARTING", "STARTED", "STOPPING", "STOPPED", "FAILED", "ERROR"] = (
-        Field(description="Current state of the endpoint")
-    )
+    state: Literal[
+        "PENDING", "STARTING", "STARTED", "STOPPING", "STOPPED", "FAILED", "ERROR"
+    ] = Field(description="Current state of the endpoint")
     created_at: datetime = Field(description="Timestamp when the endpoint was created")
diff --git a/src/together/types/finetune.py b/src/together/types/finetune.py
index 05bc8c42..e3811292 100644
--- a/src/together/types/finetune.py
+++ b/src/together/types/finetune.py
@@ -178,6 +178,8 @@ class FinetuneRequest(BaseModel):
     training_type: FullTrainingType | LoRATrainingType | None = None
     # train on inputs
     train_on_inputs: StrictBool | Literal["auto"] = "auto"
+    # checkpoint to continue training from
+    from_checkpoint: str | None = None
 
 
 class FinetuneResponse(BaseModel):
@@ -256,6 +258,7 @@ class FinetuneResponse(BaseModel):
     training_file_num_lines: int | None = Field(None, alias="TrainingFileNumLines")
     training_file_size: int | None = Field(None, alias="TrainingFileSize")
     train_on_inputs: StrictBool | Literal["auto"] | None = "auto"
+    from_checkpoint: str | None = None
 
     @field_validator("training_type")
     @classmethod
@@ -320,3 +323,16 @@ class FinetuneLRScheduler(BaseModel):
 
 class FinetuneLinearLRSchedulerArgs(BaseModel):
     min_lr_ratio: float | None = 0.0
+
+
+class FinetuneCheckpoint(BaseModel):
+    """
+    Fine-tuning checkpoint information
+    """
+
+    # checkpoint type (e.g. "Intermediate", "Final", "Final Merged", "Final Adapter")
+    type: str
+    # timestamp when the checkpoint was created
+    timestamp: str
+    # checkpoint name/identifier
+    name: str
diff --git a/src/together/utils/__init__.py b/src/together/utils/__init__.py
index 0e59966f..a9e88c3b 100644
--- a/src/together/utils/__init__.py
+++ b/src/together/utils/__init__.py
@@ -8,6 +8,8 @@
     finetune_price_to_dollars,
     normalize_key,
     parse_timestamp,
+    format_timestamp,
+    get_event_step,
 )
 
 
@@ -23,6 +25,8 @@
     "enforce_trailing_slash",
     "normalize_key",
     "parse_timestamp",
+    "format_timestamp",
+    "get_event_step",
     "finetune_price_to_dollars",
     "convert_bytes",
     "convert_unix_timestamp",
diff --git a/src/together/utils/tools.py b/src/together/utils/tools.py
index 7ac68000..2e84307a 100644
--- a/src/together/utils/tools.py
+++ b/src/together/utils/tools.py
@@ -3,6 +3,8 @@
 import logging
 import os
 from datetime import datetime
+import re
+from typing import Any
 
 
 logger = logging.getLogger("together")
@@ -23,18 +25,67 @@ def normalize_key(key: str) -> str:
     return key.replace("/", "--").replace("_", "-").replace(" ", "-").lower()
 
 
-def parse_timestamp(timestamp: str) -> datetime:
+def parse_timestamp(timestamp: str) -> datetime | None:
+    """Parse a timestamp string into a datetime object or None if the string is empty.
+
+    Args:
+        timestamp (str): Timestamp string to parse
+
+    Returns:
+        datetime | None: Parsed datetime, or None if the string is empty
+    """
+    if timestamp == "":
+        return None
+
     formats = ["%Y-%m-%dT%H:%M:%S.%fZ", "%Y-%m-%dT%H:%M:%SZ"]
     for fmt in formats:
         try:
             return datetime.strptime(timestamp, fmt)
         except ValueError:
             continue
+
     raise ValueError("Timestamp does not match any expected format")
 
 
-# Convert fine-tune nano-dollar price to dollars
+def format_timestamp(timestamp_str: str) -> str:
+    """Format a timestamp to a readable date string.
+
+    Args:
+        timestamp_str (str): A timestamp string
+
+    Returns:
+        str: Formatted timestamp string (MM/DD/YYYY, HH:MM AM/PM)
+    """
+    timestamp = parse_timestamp(timestamp_str)
+    if timestamp is None:
+        return ""
+    return timestamp.strftime("%m/%d/%Y, %I:%M %p")
+
+
+def get_event_step(event: Any) -> str | None:
+    """Extract the step number from a checkpoint event.
+
+    Args:
+        event: A checkpoint event object
+
+    Returns:
+        str | None: The step number as a string, or None if not found
+    """
+    step = getattr(event, "step", None)
+    if step is not None:
+        return str(step)
+    return None
+
+
 def finetune_price_to_dollars(price: float) -> float:
+    """Convert fine-tuning job price to dollars
+
+    Args:
+        price (float): Fine-tuning job price in nanodollars
+
+    Returns:
+        float: Price in dollars
+    """
     return price / NANODOLLAR
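Not part of the patch above: a minimal usage sketch of the checkpoint workflow this change introduces, for reviewers. It only uses the API surface added in this diff (`list_checkpoints` and the `from_checkpoint` argument); the job ID, training file, and model name are placeholders, not values from the patch.

    from together import Together

    client = Together()

    # New in this change: list the checkpoints saved for an existing fine-tuning job.
    for checkpoint in client.fine_tuning.list_checkpoints("ft-00000000-0000-0000-0000-000000000000"):
        print(checkpoint.type, checkpoint.timestamp, checkpoint.name)

    # New `from_checkpoint` argument: continue training from a previous job's checkpoint.
    # The value follows {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}; the ":10" step suffix is
    # optional, and the final checkpoint is used when it is omitted.
    job = client.fine_tuning.create(
        training_file="file-placeholder-id",
        model="placeholder-base-model",
        from_checkpoint="ft-00000000-0000-0000-0000-000000000000:10",
    )

The same checkpoint listing is exposed on the CLI by the new `list_checkpoints` command added in src/together/cli/api/finetune.py, which prints a table of checkpoint type, timestamp, and name for a given job ID.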