diff --git a/pyproject.toml b/pyproject.toml
index 4e7109c4..98b8964e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"
 
 [tool.poetry]
 name = "together"
-version = "1.5.10"
+version = "1.5.11"
 authors = ["Together AI "]
 description = "Python client for Together's Cloud Platform!"
 readme = "README.md"
diff --git a/src/together/resources/finetune.py b/src/together/resources/finetune.py
index 0bc47c0d..4e05a772 100644
--- a/src/together/resources/finetune.py
+++ b/src/together/resources/finetune.py
@@ -101,6 +101,11 @@ def create_finetune_request(
             raise ValueError(
                 f"LoRA adapters are not supported for the selected model ({model_or_checkpoint})."
             )
+
+        if lora_dropout is not None:
+            if not 0 <= lora_dropout < 1.0:
+                raise ValueError("LoRA dropout must be in [0, 1) range.")
+
         lora_r = lora_r if lora_r is not None else model_limits.lora_training.max_rank
         lora_alpha = lora_alpha if lora_alpha is not None else lora_r * 2
         training_type = LoRATrainingType(
diff --git a/tests/unit/test_finetune_resources.py b/tests/unit/test_finetune_resources.py
index f354b2f9..5ddbe0ab 100644
--- a/tests/unit/test_finetune_resources.py
+++ b/tests/unit/test_finetune_resources.py
@@ -85,6 +85,32 @@ def test_lora_request():
     assert request.batch_size == _MODEL_LIMITS.lora_training.max_batch_size
 
 
+@pytest.mark.parametrize("lora_dropout", [-1, 0, 0.5, 1.0, 10.0])
+def test_lora_request_with_lora_dropout(lora_dropout: float):
+
+    if 0 <= lora_dropout < 1:
+        request = create_finetune_request(
+            model_limits=_MODEL_LIMITS,
+            model=_MODEL_NAME,
+            training_file=_TRAINING_FILE,
+            lora=True,
+            lora_dropout=lora_dropout,
+        )
+        assert request.training_type.lora_dropout == lora_dropout
+    else:
+        with pytest.raises(
+            ValueError,
+            match=r"LoRA dropout must be in \[0, 1\) range.",
+        ):
+            create_finetune_request(
+                model_limits=_MODEL_LIMITS,
+                model=_MODEL_NAME,
+                training_file=_TRAINING_FILE,
+                lora=True,
+                lora_dropout=lora_dropout,
+            )
+
+
 def test_dpo_request_lora():
     request = create_finetune_request(
         model_limits=_MODEL_LIMITS,
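
Usage sketch (illustrative, not part of the diff): a minimal example of how the new lora_dropout validation surfaces to callers of create_finetune_request. It assumes a FinetuneTrainingLimits instance is already available for the target model; the model name and training-file ID below are placeholders.

    from together.resources.finetune import create_finetune_request

    # `model_limits` is assumed to be a FinetuneTrainingLimits object for the
    # chosen model (as used by the unit tests above); not constructed here.
    request = create_finetune_request(
        model_limits=model_limits,
        model="example-org/example-model",  # placeholder model name
        training_file="file-1234567890",    # placeholder uploaded-file ID
        lora=True,
        lora_dropout=0.1,                   # accepted: value must satisfy 0 <= value < 1
    )

    # Passing a value outside [0, 1), e.g. lora_dropout=1.0 or -0.5, now raises:
    # ValueError: LoRA dropout must be in [0, 1) range.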