2 changes: 1 addition & 1 deletion pyproject.toml
@@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"

[tool.poetry]
name = "together"
version = "1.5.6"
version = "1.5.7"
authors = ["Together AI <[email protected]>"]
description = "Python client for Together's Cloud Platform!"
readme = "README.md"
14 changes: 11 additions & 3 deletions src/together/cli/api/finetune.py
@@ -258,10 +258,13 @@ def create(
raise click.BadParameter(
f"LoRA fine-tuning is not supported for the model `{model}`"
)

if training_method == "dpo":
default_batch_size = model_limits.lora_training.max_batch_size_dpo
else:
default_batch_size = model_limits.lora_training.max_batch_size
default_values = {
"lora_r": model_limits.lora_training.max_rank,
"batch_size": model_limits.lora_training.max_batch_size,
"batch_size": default_batch_size,
"learning_rate": 1e-3,
}

@@ -288,7 +291,12 @@

batch_size_source = ctx.get_parameter_source("batch_size") # type: ignore[attr-defined]
if batch_size_source == ParameterSource.DEFAULT:
training_args["batch_size"] = model_limits.full_training.max_batch_size
if training_method == "dpo":
training_args["batch_size"] = (
model_limits.full_training.max_batch_size_dpo
)
else:
training_args["batch_size"] = model_limits.full_training.max_batch_size

if n_evals <= 0 and validation_file:
log_warn(
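The CLI change above hinges on Click's parameter-source API: the default batch size is only overridden when the user did not pass `--batch-size` explicitly. A minimal, standalone sketch of that pattern (not the together CLI itself; the option names and the limits dict below are illustrative only):

```python
# Minimal sketch of the Click pattern used above: ctx.get_parameter_source()
# reports whether a value came from the command line or from the option's
# default, so the command can substitute a method-specific maximum.
import click
from click.core import ParameterSource

# Hypothetical per-method ceilings, for illustration only.
MAX_BATCH_SIZE = {"sft": 96, "dpo": 48}


@click.command()
@click.option("--training-method", type=click.Choice(["sft", "dpo"]), default="sft")
@click.option("--batch-size", type=int, default=0)
@click.pass_context
def create(ctx: click.Context, training_method: str, batch_size: int) -> None:
    if ctx.get_parameter_source("batch_size") == ParameterSource.DEFAULT:
        # --batch-size was not supplied: fall back to the per-method maximum.
        batch_size = MAX_BATCH_SIZE[training_method]
    click.echo(f"training_method={training_method} batch_size={batch_size}")


if __name__ == "__main__":
    create()
```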
24 changes: 18 additions & 6 deletions src/together/resources/finetune.py
@@ -102,6 +102,7 @@ def create_finetune_request(

training_type: TrainingType = FullTrainingType()
max_batch_size: int = 0
max_batch_size_dpo: int = 0
min_batch_size: int = 0
if lora:
if model_limits.lora_training is None:
@@ -119,7 +120,7 @@

max_batch_size = model_limits.lora_training.max_batch_size
min_batch_size = model_limits.lora_training.min_batch_size

max_batch_size_dpo = model_limits.lora_training.max_batch_size_dpo
else:
if model_limits.full_training is None:
raise ValueError(
@@ -128,13 +129,24 @@

max_batch_size = model_limits.full_training.max_batch_size
min_batch_size = model_limits.full_training.min_batch_size
max_batch_size_dpo = model_limits.full_training.max_batch_size_dpo

batch_size = batch_size if batch_size != "max" else max_batch_size
if batch_size == "max":
if training_method == "dpo":
batch_size = max_batch_size_dpo
else:
batch_size = max_batch_size

if batch_size > max_batch_size:
raise ValueError(
f"Requested batch size of {batch_size} is higher that the maximum allowed value of {max_batch_size}."
)
if training_method == "sft":
if batch_size > max_batch_size:
raise ValueError(
f"Requested batch size of {batch_size} is higher that the maximum allowed value of {max_batch_size}."
)
elif training_method == "dpo":
if batch_size > max_batch_size_dpo:
raise ValueError(
f"Requested batch size of {batch_size} is higher that the maximum allowed value of {max_batch_size_dpo}."
)

if batch_size < min_batch_size:
raise ValueError(
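Taken together, the resource-level change resolves "max" against the method-specific ceiling and then validates the requested value against that same ceiling. A simplified, standalone restatement of the logic (the helper name and its flat argument list are illustrative, not part of the library):

```python
# Illustrative restatement of the batch-size resolution above: "max" resolves
# to the ceiling for the chosen training method, and explicit values are
# validated against that ceiling plus the minimum.
from typing import Union


def resolve_batch_size(
    batch_size: Union[int, str],
    training_method: str,
    max_batch_size: int,
    max_batch_size_dpo: int,
    min_batch_size: int,
) -> int:
    ceiling = max_batch_size_dpo if training_method == "dpo" else max_batch_size
    if batch_size == "max":
        batch_size = ceiling
    if batch_size > ceiling:
        raise ValueError(
            f"Requested batch size of {batch_size} is higher than "
            f"the maximum allowed value of {ceiling}."
        )
    if batch_size < min_batch_size:
        raise ValueError(
            f"Requested batch size of {batch_size} is lower than "
            f"the minimum allowed value of {min_batch_size}."
        )
    return batch_size


# With the limits used in the unit tests below (SFT max 96, DPO max 48, min 8):
assert resolve_batch_size("max", "sft", 96, 48, 8) == 96
assert resolve_batch_size("max", "dpo", 96, 48, 8) == 48
```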
10 changes: 9 additions & 1 deletion src/together/types/finetune.py
@@ -1,7 +1,7 @@
from __future__ import annotations

from enum import Enum
from typing import List, Literal
from typing import List, Literal, Any

from pydantic import StrictBool, Field, field_validator

@@ -329,8 +329,16 @@ class FinetuneDownloadResult(BaseModel):

class FinetuneFullTrainingLimits(BaseModel):
max_batch_size: int
max_batch_size_dpo: int = -1
(Review thread on `max_batch_size_dpo: int = -1`)

Contributor: (optional) I would prefer `int | None` over the `-1` sentinel, but this works too.

Contributor (Author): Using `int | None` causes a lot of type-check problems elsewhere in our code.

Contributor: I see, thank you.

min_batch_size: int

def __init__(self, **data: Any) -> None:
super().__init__(**data)
if self.max_batch_size_dpo == -1:
half_max = self.max_batch_size // 2
rounded_half_max = (half_max // 8) * 8
self.max_batch_size_dpo = max(self.min_batch_size, rounded_half_max)


class FinetuneLoraTrainingLimits(FinetuneFullTrainingLimits):
max_rank: int
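A quick check of the derived default, assuming the class is importable from the path shown in this diff (`src/together/types/finetune.py`): omitting `max_batch_size_dpo` leaves the `-1` sentinel in place, so `__init__` derives half of `max_batch_size`, rounded down to a multiple of 8 and floored at `min_batch_size`.

```python
# Sketch of how the default is derived, per the __init__ shown above.
from together.types.finetune import FinetuneFullTrainingLimits

# Sentinel left in place: 96 // 2 = 48, (48 // 8) * 8 = 48, max(8, 48) = 48.
limits = FinetuneFullTrainingLimits(max_batch_size=96, min_batch_size=8)
print(limits.max_batch_size_dpo)  # expected: 48

# An explicit value is kept as-is.
limits = FinetuneFullTrainingLimits(
    max_batch_size=96, max_batch_size_dpo=32, min_batch_size=8
)
print(limits.max_batch_size_dpo)  # expected: 32
```

As the review thread above notes, an `int | None` default with a validator would make the "unset" state more explicit, but the `-1` sentinel keeps the field a plain `int` and avoids `Optional` handling in the rest of the type-checked code.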
34 changes: 34 additions & 0 deletions tests/unit/test_finetune_resources.py
@@ -18,10 +18,12 @@
min_learning_rate=1e-6,
full_training=FinetuneFullTrainingLimits(
max_batch_size=96,
max_batch_size_dpo=48,
min_batch_size=8,
),
lora_training=FinetuneLoraTrainingLimits(
max_batch_size=128,
max_batch_size_dpo=64,
min_batch_size=8,
max_rank=64,
target_modules=["q", "k", "v", "o", "mlp"],
@@ -83,6 +85,36 @@ def test_lora_request():
assert request.batch_size == _MODEL_LIMITS.lora_training.max_batch_size


def test_dpo_request_lora():
request = create_finetune_request(
model_limits=_MODEL_LIMITS,
model=_MODEL_NAME,
training_file=_TRAINING_FILE,
training_method="dpo",
lora=True,
)

assert request.training_type.type == "Lora"
assert request.training_type.lora_r == _MODEL_LIMITS.lora_training.max_rank
assert request.training_type.lora_alpha == _MODEL_LIMITS.lora_training.max_rank * 2
assert request.training_type.lora_dropout == 0.0
assert request.training_type.lora_trainable_modules == "all-linear"
assert request.batch_size == _MODEL_LIMITS.lora_training.max_batch_size_dpo


def test_dpo_request():
request = create_finetune_request(
model_limits=_MODEL_LIMITS,
model=_MODEL_NAME,
training_file=_TRAINING_FILE,
training_method="dpo",
lora=False,
)

assert request.training_type.type == "Full"
assert request.batch_size == _MODEL_LIMITS.full_training.max_batch_size_dpo


def test_from_checkpoint_request():
request = create_finetune_request(
model_limits=_MODEL_LIMITS,
@@ -160,6 +192,7 @@ def test_non_lora_model():
min_learning_rate=1e-6,
full_training=FinetuneFullTrainingLimits(
max_batch_size=96,
max_batch_size_dpo=48,
min_batch_size=8,
),
lora_training=None,
@@ -181,6 +214,7 @@ def test_non_full_model():
min_learning_rate=1e-6,
lora_training=FinetuneLoraTrainingLimits(
max_batch_size=96,
max_batch_size_dpo=48,
min_batch_size=8,
max_rank=64,
target_modules=["q", "k", "v", "o", "mlp"],