diff --git a/pyproject.toml b/pyproject.toml index 4a7f4c99..82f8090c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api" [tool.poetry] name = "together" -version = "1.5.0" +version = "1.5.1" authors = [ "Together AI " ] diff --git a/src/together/cli/api/finetune.py b/src/together/cli/api/finetune.py index b413f323..566896bd 100644 --- a/src/together/cli/api/finetune.py +++ b/src/together/cli/api/finetune.py @@ -60,7 +60,7 @@ def fine_tuning(ctx: click.Context) -> None: @click.option( "--training-file", type=str, required=True, help="Training file ID from Files API" ) -@click.option("--model", type=str, required=True, help="Base model name") +@click.option("--model", type=str, help="Base model name") @click.option("--n-epochs", type=int, default=1, help="Number of epochs to train for") @click.option( "--validation-file", type=str, default="", help="Validation file ID from Files API" @@ -214,8 +214,15 @@ def create( from_checkpoint=from_checkpoint, ) + if model is None and from_checkpoint is None: + raise click.BadParameter("You must specify either a model or a checkpoint") + + model_name = model + if from_checkpoint is not None: + model_name = from_checkpoint.split(":")[0] + model_limits: FinetuneTrainingLimits = client.fine_tuning.get_model_limits( - model=model + model=model_name ) if lora: diff --git a/src/together/resources/finetune.py b/src/together/resources/finetune.py index 8cc48a17..08710750 100644 --- a/src/together/resources/finetune.py +++ b/src/together/resources/finetune.py @@ -50,7 +50,7 @@ def createFinetuneRequest( model_limits: FinetuneTrainingLimits, training_file: str, - model: str, + model: str | None = None, n_epochs: int = 1, validation_file: str | None = "", n_evals: int | None = 0, @@ -77,6 +77,11 @@ def createFinetuneRequest( from_checkpoint: str | None = None, ) -> FinetuneRequest: + if model is not None and from_checkpoint is not None: + raise ValueError( + "You must specify either a 
model or a checkpoint to start a job from, not both" + ) + if batch_size == "max": log_warn_once( "Starting from together>=1.3.0, " @@ -237,7 +242,7 @@ def create( self, *, training_file: str, - model: str, + model: str | None = None, n_epochs: int = 1, validation_file: str | None = "", n_evals: int | None = 0, @@ -270,7 +275,7 @@ def create( Args: training_file (str): File-ID of a file uploaded to the Together API - model (str): Name of the base model to run fine-tune job on + model (str, optional): Name of the base model to run fine-tune job on. Required when from_checkpoint is not provided. n_epochs (int, optional): Number of epochs for fine-tuning. Defaults to 1. validation file (str, optional): File ID of a file uploaded to the Together API for validation. n_evals (int, optional): Number of evaluation loops to run. Defaults to 0. @@ -320,12 +325,24 @@ def create( FinetuneResponse: Object containing information about fine-tuning job. """ + if model is None and from_checkpoint is None: + raise ValueError("You must specify either a model or a checkpoint") + requestor = api_requestor.APIRequestor( client=self._client, ) if model_limits is None: - model_limits = self.get_model_limits(model=model) + # mypy doesn't understand that model or from_checkpoint is not None + if model is not None: + model_name = model + elif from_checkpoint is not None: + model_name = from_checkpoint.split(":")[0] + else: + # unreachable (guarded above); raise so model_name is always bound + raise ValueError("You must specify either a model or a checkpoint") + model_limits = self.get_model_limits(model=model_name) + finetune_request = createFinetuneRequest( model_limits=model_limits, training_file=training_file, @@ -610,7 +627,7 @@ async def create( self, *, training_file: str, - model: str, + model: str | None = None, n_epochs: int = 1, validation_file: str | None = "", n_evals: int | None = 0, @@ -643,7 +660,7 @@ async def create( Args: training_file (str): File-ID of a file uploaded to the Together API - model (str): Name of the base model to run fine-tune job on + model (str, optional): Name of the base 
model to run fine-tune job on. Required when from_checkpoint is not provided. n_epochs (int, optional): Number of epochs for fine-tuning. Defaults to 1. validation file (str, optional): File ID of a file uploaded to the Together API for validation. n_evals (int, optional): Number of evaluation loops to run. Defaults to 0. @@ -693,12 +710,23 @@ async def create( FinetuneResponse: Object containing information about fine-tuning job. """ + if model is None and from_checkpoint is None: + raise ValueError("You must specify either a model or a checkpoint") + requestor = api_requestor.APIRequestor( client=self._client, ) if model_limits is None: - model_limits = await self.get_model_limits(model=model) + # mypy doesn't understand that model or from_checkpoint is not None + if model is not None: + model_name = model + elif from_checkpoint is not None: + model_name = from_checkpoint.split(":")[0] + else: + # unreachable (guarded above); raise so model_name is always bound + raise ValueError("You must specify either a model or a checkpoint") + model_limits = await self.get_model_limits(model=model_name) finetune_request = createFinetuneRequest( model_limits=model_limits,