diff --git a/pyproject.toml b/pyproject.toml
index 26ddfad4..94226c18 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"
 
 [tool.poetry]
 name = "together"
-version = "1.5.6"
+version = "1.5.7"
 authors = ["Together AI <support@together.ai>"]
 description = "Python client for Together's Cloud Platform!"
 readme = "README.md"
diff --git a/src/together/cli/api/finetune.py b/src/together/cli/api/finetune.py
index 467c8296..751fe6ae 100644
--- a/src/together/cli/api/finetune.py
+++ b/src/together/cli/api/finetune.py
@@ -258,10 +258,13 @@ def create(
             raise click.BadParameter(
                 f"LoRA fine-tuning is not supported for the model `{model}`"
             )
-
+        if training_method == "dpo":
+            default_batch_size = model_limits.lora_training.max_batch_size_dpo
+        else:
+            default_batch_size = model_limits.lora_training.max_batch_size
         default_values = {
             "lora_r": model_limits.lora_training.max_rank,
-            "batch_size": model_limits.lora_training.max_batch_size,
+            "batch_size": default_batch_size,
             "learning_rate": 1e-3,
         }
 
@@ -288,7 +291,12 @@ def create(
 
         batch_size_source = ctx.get_parameter_source("batch_size")  # type: ignore[attr-defined]
         if batch_size_source == ParameterSource.DEFAULT:
-            training_args["batch_size"] = model_limits.full_training.max_batch_size
+            if training_method == "dpo":
+                training_args["batch_size"] = (
+                    model_limits.full_training.max_batch_size_dpo
+                )
+            else:
+                training_args["batch_size"] = model_limits.full_training.max_batch_size
 
     if n_evals <= 0 and validation_file:
         log_warn(
diff --git a/src/together/resources/finetune.py b/src/together/resources/finetune.py
index 82a70211..275d6839 100644
--- a/src/together/resources/finetune.py
+++ b/src/together/resources/finetune.py
@@ -102,6 +102,7 @@ def create_finetune_request(
     training_type: TrainingType = FullTrainingType()
     max_batch_size: int = 0
+    max_batch_size_dpo: int = 0
     min_batch_size: int = 0
     if lora:
         if model_limits.lora_training is None:
             raise ValueError(
@@ -119,7 +120,7 @@ def create_finetune_request(
 
         max_batch_size = model_limits.lora_training.max_batch_size
         min_batch_size = model_limits.lora_training.min_batch_size
-
+        max_batch_size_dpo = model_limits.lora_training.max_batch_size_dpo
     else:
         if model_limits.full_training is None:
             raise ValueError(
@@ -128,13 +129,24 @@ def create_finetune_request(
 
         max_batch_size = model_limits.full_training.max_batch_size
         min_batch_size = model_limits.full_training.min_batch_size
+        max_batch_size_dpo = model_limits.full_training.max_batch_size_dpo
 
-    batch_size = batch_size if batch_size != "max" else max_batch_size
+    if batch_size == "max":
+        if training_method == "dpo":
+            batch_size = max_batch_size_dpo
+        else:
+            batch_size = max_batch_size
 
-    if batch_size > max_batch_size:
-        raise ValueError(
-            f"Requested batch size of {batch_size} is higher that the maximum allowed value of {max_batch_size}."
-        )
+    if training_method == "sft":
+        if batch_size > max_batch_size:
+            raise ValueError(
+                f"Requested batch size of {batch_size} is higher than the maximum allowed value of {max_batch_size}."
+            )
+    elif training_method == "dpo":
+        if batch_size > max_batch_size_dpo:
+            raise ValueError(
+                f"Requested batch size of {batch_size} is higher than the maximum allowed value of {max_batch_size_dpo}."
+            )
 
     if batch_size < min_batch_size:
         raise ValueError(
diff --git a/src/together/types/finetune.py b/src/together/types/finetune.py
index 6325ce59..5c2c2c2f 100644
--- a/src/together/types/finetune.py
+++ b/src/together/types/finetune.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from enum import Enum
-from typing import List, Literal
+from typing import List, Literal, Any
 
 from pydantic import StrictBool, Field, field_validator
 
@@ -329,8 +329,16 @@ class FinetuneDownloadResult(BaseModel):
 
 class FinetuneFullTrainingLimits(BaseModel):
     max_batch_size: int
+    max_batch_size_dpo: int = -1
     min_batch_size: int
 
+    def __init__(self, **data: Any) -> None:
+        super().__init__(**data)
+        if self.max_batch_size_dpo == -1:
+            half_max = self.max_batch_size // 2
+            rounded_half_max = (half_max // 8) * 8
+            self.max_batch_size_dpo = max(self.min_batch_size, rounded_half_max)
+
 
 class FinetuneLoraTrainingLimits(FinetuneFullTrainingLimits):
     max_rank: int
diff --git a/tests/unit/test_finetune_resources.py b/tests/unit/test_finetune_resources.py
index f7acdbca..7ff69415 100644
--- a/tests/unit/test_finetune_resources.py
+++ b/tests/unit/test_finetune_resources.py
@@ -18,10 +18,12 @@
     min_learning_rate=1e-6,
     full_training=FinetuneFullTrainingLimits(
         max_batch_size=96,
+        max_batch_size_dpo=48,
        min_batch_size=8,
     ),
     lora_training=FinetuneLoraTrainingLimits(
         max_batch_size=128,
+        max_batch_size_dpo=64,
         min_batch_size=8,
         max_rank=64,
         target_modules=["q", "k", "v", "o", "mlp"],
@@ -83,6 +85,36 @@ def test_lora_request():
     assert request.batch_size == _MODEL_LIMITS.lora_training.max_batch_size
 
 
+def test_dpo_request_lora():
+    request = create_finetune_request(
+        model_limits=_MODEL_LIMITS,
+        model=_MODEL_NAME,
+        training_file=_TRAINING_FILE,
+        training_method="dpo",
+        lora=True,
+    )
+
+    assert request.training_type.type == "Lora"
+    assert request.training_type.lora_r == _MODEL_LIMITS.lora_training.max_rank
+    assert request.training_type.lora_alpha == _MODEL_LIMITS.lora_training.max_rank * 2
+    assert request.training_type.lora_dropout == 0.0
+    assert request.training_type.lora_trainable_modules == "all-linear"
+    assert request.batch_size == _MODEL_LIMITS.lora_training.max_batch_size_dpo
+
+
+def test_dpo_request():
+    request = create_finetune_request(
+        model_limits=_MODEL_LIMITS,
+        model=_MODEL_NAME,
+        training_file=_TRAINING_FILE,
+        training_method="dpo",
+        lora=False,
+    )
+
+    assert request.training_type.type == "Full"
+    assert request.batch_size == _MODEL_LIMITS.full_training.max_batch_size_dpo
+
+
 def test_from_checkpoint_request():
     request = create_finetune_request(
         model_limits=_MODEL_LIMITS,
@@ -160,6 +192,7 @@ def test_non_lora_model():
         min_learning_rate=1e-6,
         full_training=FinetuneFullTrainingLimits(
             max_batch_size=96,
+            max_batch_size_dpo=48,
             min_batch_size=8,
         ),
         lora_training=None,
@@ -181,6 +214,7 @@ def test_non_full_model():
         min_learning_rate=1e-6,
         lora_training=FinetuneLoraTrainingLimits(
             max_batch_size=96,
+            max_batch_size_dpo=48,
             min_batch_size=8,
             max_rank=64,
             target_modules=["q", "k", "v", "o", "mlp"],
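Note on the new default: when the API response carries no explicit max_batch_size_dpo, the FinetuneFullTrainingLimits.__init__ added above derives one from the SFT limits: half the SFT maximum, rounded down to a multiple of 8, clamped to the minimum batch size. A minimal standalone sketch of that rule (the helper name default_dpo_batch_size is illustrative only, not part of this diff):

    # Sketch of the default-DPO-batch-size rule from FinetuneFullTrainingLimits.__init__.
    def default_dpo_batch_size(max_batch_size: int, min_batch_size: int) -> int:
        half_max = max_batch_size // 2          # half the SFT maximum
        rounded_half_max = (half_max // 8) * 8  # round down to a multiple of 8
        return max(min_batch_size, rounded_half_max)  # never below the minimum

    assert default_dpo_batch_size(96, 8) == 48    # matches the test fixtures (96 -> 48, 128 -> 64)
    assert default_dpo_batch_size(12, 8) == 8     # rounding down to 0 is clamped to the minimum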