14 changes: 0 additions & 14 deletions src/together/cli/api/finetune.py
@@ -304,13 +304,8 @@ def create(
             raise click.BadParameter(
                 f"LoRA fine-tuning is not supported for the model `{model}`"
             )
-        if training_method == "dpo":
-            default_batch_size = model_limits.lora_training.max_batch_size_dpo
-        else:
-            default_batch_size = model_limits.lora_training.max_batch_size
         default_values = {
             "lora_r": model_limits.lora_training.max_rank,
-            "batch_size": default_batch_size,
             "learning_rate": 1e-3,
         }
 
@@ -335,15 +330,6 @@ def create(
                     f"Please change the job type with --lora or remove `{param}` from the arguments"
                 )
 
-        batch_size_source = ctx.get_parameter_source("batch_size")  # type: ignore[attr-defined]
-        if batch_size_source == ParameterSource.DEFAULT:
-            if training_method == "dpo":
-                training_args["batch_size"] = (
-                    model_limits.full_training.max_batch_size_dpo
-                )
-            else:
-                training_args["batch_size"] = model_limits.full_training.max_batch_size
-
     if n_evals <= 0 and validation_file:
         log_warn(
             "Warning: You have specified a validation file but the number of evaluation loops is set to 0. No evaluations will be performed."
4 changes: 2 additions & 2 deletions src/together/legacy/finetune.py
@@ -16,7 +16,7 @@ def create(
         model: str,
         n_epochs: int = 1,
         n_checkpoints: int | None = 1,
-        batch_size: int | None = 32,
+        batch_size: int | Literal["max"] = 32,
         learning_rate: float = 0.00001,
         suffix: (
             str | None
@@ -43,7 +43,7 @@ def create(
             model=model,
             n_epochs=n_epochs,
             n_checkpoints=n_checkpoints,
-            batch_size=batch_size if isinstance(batch_size, int) else "max",
+            batch_size=batch_size,
             learning_rate=learning_rate,
             suffix=suffix,
             wandb_api_key=wandb_api_key,
31 changes: 13 additions & 18 deletions src/together/resources/finetune.py
@@ -133,27 +133,22 @@ def create_finetune_request(
     min_batch_size = model_limits.full_training.min_batch_size
     max_batch_size_dpo = model_limits.full_training.max_batch_size_dpo
 
-    if batch_size == "max":
-        if training_method == "dpo":
-            batch_size = max_batch_size_dpo
-        else:
-            batch_size = max_batch_size
+    if batch_size != "max":
+        if training_method == "sft":
+            if batch_size > max_batch_size:
+                raise ValueError(
+                    f"Requested batch size of {batch_size} is higher that the maximum allowed value of {max_batch_size}."
+                )
+        elif training_method == "dpo":
+            if batch_size > max_batch_size_dpo:
+                raise ValueError(
+                    f"Requested batch size of {batch_size} is higher that the maximum allowed value of {max_batch_size_dpo}."
+                )
 
-    if training_method == "sft":
-        if batch_size > max_batch_size:
+        if batch_size < min_batch_size:
             raise ValueError(
-                f"Requested batch size of {batch_size} is higher that the maximum allowed value of {max_batch_size}."
+                f"Requested batch size of {batch_size} is lower that the minimum allowed value of {min_batch_size}."
             )
-    elif training_method == "dpo":
-        if batch_size > max_batch_size_dpo:
-            raise ValueError(
-                f"Requested batch size of {batch_size} is higher that the maximum allowed value of {max_batch_size_dpo}."
-            )
-
-    if batch_size < min_batch_size:
-        raise ValueError(
-            f"Requested batch size of {batch_size} is lower that the minimum allowed value of {min_batch_size}."
-        )
 
     if warmup_ratio > 1 or warmup_ratio < 0:
         raise ValueError(f"Warmup ratio should be between 0 and 1 (got {warmup_ratio})")
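
For context, a minimal sketch of the behaviour this hunk leaves in place: explicit integer batch sizes are still range-checked on the client, while the literal "max" is forwarded for the backend to resolve. This is illustrative only, assuming a stand-in TrainingLimits container rather than the SDK's real model-limit objects.

# Illustrative sketch only -- TrainingLimits and the numbers below are stand-ins,
# not the together SDK's model-limit objects.
from dataclasses import dataclass
from typing import Literal, Union


@dataclass
class TrainingLimits:
    min_batch_size: int
    max_batch_size: int
    max_batch_size_dpo: int


def check_batch_size(
    batch_size: Union[int, Literal["max"]],
    limits: TrainingLimits,
    training_method: str = "sft",
) -> Union[int, Literal["max"]]:
    # "max" is no longer resolved client-side; it is passed through so the API can
    # pick the model-specific maximum. Only explicit integers are validated here.
    if batch_size != "max":
        ceiling = (
            limits.max_batch_size_dpo
            if training_method == "dpo"
            else limits.max_batch_size
        )
        if batch_size > ceiling:
            raise ValueError(f"Batch size {batch_size} exceeds the maximum of {ceiling}.")
        if batch_size < limits.min_batch_size:
            raise ValueError(
                f"Batch size {batch_size} is below the minimum of {limits.min_batch_size}."
            )
    return batch_size


limits = TrainingLimits(min_batch_size=1, max_batch_size=96, max_batch_size_dpo=48)
print(check_batch_size("max", limits))      # "max" -- left for the backend to resolve
print(check_batch_size(32, limits, "dpo"))  # 32 -- within the DPO ceiling
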
2 changes: 1 addition & 1 deletion src/together/types/finetune.py
@@ -195,7 +195,7 @@ class FinetuneRequest(BaseModel):
     # number of evaluation loops to run
     n_evals: int | None = None
     # training batch size
-    batch_size: int | None = None
+    batch_size: int | Literal["max"] | None = None
     # up to 40 character suffix for output model name
     suffix: str | None = None
     # weights & biases api key
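
A hedged illustration of what the widened annotation accepts, using a throwaway pydantic model (BatchSizeDemo is hypothetical, not the SDK's FinetuneRequest; assumes pydantic v2 is installed): the int | Literal["max"] | None union validates an explicit integer, the string "max", or an omitted value, and rejects anything else.

# Hypothetical demo model -- only the batch_size annotation mirrors the change above.
from typing import Literal, Union

from pydantic import BaseModel, ValidationError


class BatchSizeDemo(BaseModel):
    batch_size: Union[int, Literal["max"], None] = None


print(BatchSizeDemo(batch_size=16).batch_size)     # 16
print(BatchSizeDemo(batch_size="max").batch_size)  # max
print(BatchSizeDemo().batch_size)                  # None

try:
    BatchSizeDemo(batch_size="huge")
except ValidationError as exc:
    print(f"rejected: {len(exc.errors())} validation error(s)")
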
8 changes: 4 additions & 4 deletions tests/unit/test_finetune_resources.py
@@ -44,7 +44,7 @@ def test_simple_request():
     assert request.n_epochs > 0
     assert request.warmup_ratio == 0.0
     assert request.training_type.type == "Full"
-    assert request.batch_size == _MODEL_LIMITS.full_training.max_batch_size
+    assert request.batch_size == "max"
 
 
 def test_validation_file():
@@ -82,7 +82,7 @@ def test_lora_request():
     assert request.training_type.lora_alpha == _MODEL_LIMITS.lora_training.max_rank * 2
     assert request.training_type.lora_dropout == 0.0
     assert request.training_type.lora_trainable_modules == "all-linear"
-    assert request.batch_size == _MODEL_LIMITS.lora_training.max_batch_size
+    assert request.batch_size == "max"
 
 
 @pytest.mark.parametrize("lora_dropout", [-1, 0, 0.5, 1.0, 10.0])
@@ -124,7 +124,7 @@ def test_dpo_request_lora():
     assert request.training_type.lora_alpha == _MODEL_LIMITS.lora_training.max_rank * 2
     assert request.training_type.lora_dropout == 0.0
     assert request.training_type.lora_trainable_modules == "all-linear"
-    assert request.batch_size == _MODEL_LIMITS.lora_training.max_batch_size_dpo
+    assert request.batch_size == "max"
 
 
 def test_dpo_request():
@@ -137,7 +137,7 @@ def test_dpo_request():
     )
 
     assert request.training_type.type == "Full"
-    assert request.batch_size == _MODEL_LIMITS.full_training.max_batch_size_dpo
+    assert request.batch_size == "max"
 
 
 def test_from_checkpoint_request():