
Commit f759400

Add validation for lora_dropout
1 parent 554d3e8 commit f759400

File tree

2 files changed: +30 −0

src/together/resources/finetune.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -101,6 +101,10 @@ def create_finetune_request(
         raise ValueError(
             f"LoRA adapters are not supported for the selected model ({model_or_checkpoint})."
         )
+
+    if not 0 <= lora_dropout < 1.0:
+        raise ValueError("LoRA dropout must be in [0, 1) range.")
+
     lora_r = lora_r if lora_r is not None else model_limits.lora_training.max_rank
     lora_alpha = lora_alpha if lora_alpha is not None else lora_r * 2
     training_type = LoRATrainingType(
```
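
The guard is a half-open interval check: dropout is a probability, so 0 (dropout disabled) is valid, while 1.0 would zero out every activation and is rejected. A minimal standalone sketch of the same logic follows; `validate_lora_dropout` is an illustrative name, not part of the SDK:

```python
def validate_lora_dropout(lora_dropout: float) -> None:
    # Mirrors the guard added in create_finetune_request: dropout must be
    # at least 0, and strictly below 1.0.
    if not 0 <= lora_dropout < 1.0:
        raise ValueError("LoRA dropout must be in [0, 1) range.")

# Same values the new unit test parametrizes over.
for value in (-1, 0, 0.5, 1.0, 10.0):
    try:
        validate_lora_dropout(value)
        print(f"{value}: accepted")
    except ValueError as err:
        print(f"{value}: rejected ({err})")
```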

tests/unit/test_finetune_resources.py

Lines changed: 26 additions & 0 deletions
```diff
@@ -85,6 +85,32 @@ def test_lora_request():
     assert request.batch_size == _MODEL_LIMITS.lora_training.max_batch_size
 
 
+@pytest.mark.parametrize("lora_dropout", [-1, 0, 0.5, 1.0, 10.0])
+def test_lora_request_with_lora_dropout(lora_dropout: float):
+
+    if 0 <= lora_dropout < 1:
+        request = create_finetune_request(
+            model_limits=_MODEL_LIMITS,
+            model=_MODEL_NAME,
+            training_file=_TRAINING_FILE,
+            lora=True,
+            lora_dropout=lora_dropout,
+        )
+        assert request.training_type.lora_dropout == lora_dropout
+    else:
+        with pytest.raises(
+            ValueError,
+            match=r"LoRA dropout must be in \[0, 1\) range.",
+        ):
+            create_finetune_request(
+                model_limits=_MODEL_LIMITS,
+                model=_MODEL_NAME,
+                training_file=_TRAINING_FILE,
+                lora=True,
+                lora_dropout=lora_dropout,
+            )
+
+
 def test_dpo_request_lora():
     request = create_finetune_request(
         model_limits=_MODEL_LIMITS,
```
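
The single parametrized test exercises both branches of the new check: in-range values assert that `lora_dropout` is passed through to the training type, and out-of-range values assert the `ValueError`. Assuming a standard pytest setup, just this test can be run from the repository root with `pytest tests/unit/test_finetune_resources.py -k lora_dropout`.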
