-
Notifications
You must be signed in to change notification settings - Fork 21
Add support for the Training Method for finetuning, and for Direct-Preference Optimization (DPO) #262
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for the Training Method for finetuning, and for Direct-Preference Optimization (DPO) #262
Changes from 4 commits
3f1ec6a
ee7e02d
fdbdc8e
ee470fc
8f50eb8
b322858
24d209e
f1f8d9a
5320437
fbd17a6
bf0b180
7357926
5f00d95
3c72b9c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,7 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from pathlib import Path | ||
| -from typing import Literal | ||
| +from typing import Literal, Union | ||
|
|
||
| from rich import print as rprint | ||
|
|
||
|
|
@@ -22,6 +22,8 @@ | |
| TrainingType, | ||
| FinetuneLRScheduler, | ||
| FinetuneLinearLRSchedulerArgs, | ||
| TrainingMethodDPO, | ||
| TrainingMethodSFT, | ||
| ) | ||
| from together.types.finetune import DownloadCheckpointType | ||
| from together.utils import log_warn_once, normalize_key | ||
|
|
@@ -52,6 +54,8 @@ def createFinetuneRequest( | |
| wandb_project_name: str | None = None, | ||
| wandb_name: str | None = None, | ||
| train_on_inputs: bool | Literal["auto"] = "auto", | ||
| training_method: str = "sft", | ||
| dpo_beta: float | None = None, | ||
| ) -> FinetuneRequest: | ||
| if batch_size == "max": | ||
| log_warn_once( | ||
|
|
@@ -105,6 +109,12 @@ def createFinetuneRequest( | |
| lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio), | ||
| ) | ||
|
|
||
| training_method_cls: Union[TrainingMethodSFT, TrainingMethodDPO] = ( | ||
|
||
| TrainingMethodSFT() | ||
| ) | ||
| if training_method == "dpo": | ||
mryab marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta) | ||
|
|
||
| finetune_request = FinetuneRequest( | ||
| model=model, | ||
| training_file=training_file, | ||
|
|
@@ -125,6 +135,7 @@ def createFinetuneRequest( | |
| wandb_project_name=wandb_project_name, | ||
| wandb_name=wandb_name, | ||
| train_on_inputs=train_on_inputs, | ||
| training_method=training_method_cls, | ||
| ) | ||
|
|
||
| return finetune_request | ||
|
|
@@ -162,6 +173,8 @@ def create( | |
| verbose: bool = False, | ||
| model_limits: FinetuneTrainingLimits | None = None, | ||
| train_on_inputs: bool | Literal["auto"] = "auto", | ||
| training_method: str = "sft", | ||
| dpo_beta: float | None = None, | ||
| ) -> FinetuneResponse: | ||
| """ | ||
| Method to initiate a fine-tuning job | ||
|
|
@@ -207,6 +220,9 @@ def create( | |
| For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields | ||
| (Instruction format), inputs will be masked. | ||
| Defaults to "auto". | ||
| training_method (str, optional): Training method. Defaults to "sft". | ||
| Supported methods: "sft", "dpo". | ||
| dpo_beta (float, optional): DPO beta parameter. Defaults to None. | ||
|
|
||
| Returns: | ||
| FinetuneResponse: Object containing information about fine-tuning job. | ||
|
|
@@ -218,7 +234,6 @@ def create( | |
|
|
||
| if model_limits is None: | ||
| model_limits = self.get_model_limits(model=model) | ||
|
|
||
| finetune_request = createFinetuneRequest( | ||
| model_limits=model_limits, | ||
| training_file=training_file, | ||
|
|
@@ -244,6 +259,8 @@ def create( | |
| wandb_project_name=wandb_project_name, | ||
| wandb_name=wandb_name, | ||
| train_on_inputs=train_on_inputs, | ||
| training_method=training_method, | ||
| dpo_beta=dpo_beta, | ||
| ) | ||
|
|
||
| if verbose: | ||
|
|
@@ -261,7 +278,6 @@ def create( | |
| ), | ||
| stream=False, | ||
| ) | ||
|
|
||
| assert isinstance(response, TogetherResponse) | ||
|
|
||
| return FinetuneResponse(**response.data) | ||
|
|
@@ -503,6 +519,8 @@ async def create( | |
| verbose: bool = False, | ||
| model_limits: FinetuneTrainingLimits | None = None, | ||
| train_on_inputs: bool | Literal["auto"] = "auto", | ||
| training_method: str = "sft", | ||
| dpo_beta: float | None = None, | ||
| ) -> FinetuneResponse: | ||
| """ | ||
| Async method to initiate a fine-tuning job | ||
|
|
@@ -548,6 +566,9 @@ async def create( | |
| For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields | ||
| (Instruction format), inputs will be masked. | ||
| Defaults to "auto". | ||
| training_method (str, optional): Training method. Defaults to "sft". | ||
| Supported methods: "sft", "dpo". | ||
| dpo_beta (float, optional): DPO beta parameter. Defaults to None. | ||
|
|
||
| Returns: | ||
| FinetuneResponse: Object containing information about fine-tuning job. | ||
|
|
@@ -585,6 +606,8 @@ async def create( | |
| wandb_project_name=wandb_project_name, | ||
| wandb_name=wandb_name, | ||
| train_on_inputs=train_on_inputs, | ||
| training_method=training_method, | ||
| dpo_beta=dpo_beta, | ||
| ) | ||
|
|
||
| if verbose: | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -1,7 +1,7 @@ | ||||||
| from __future__ import annotations | ||||||
|
|
||||||
| from enum import Enum | ||||||
| -from typing import List, Literal | ||||||
| +from typing import List, Literal, Union | ||||||
|
|
||||||
| from pydantic import StrictBool, Field, validator, field_validator | ||||||
|
|
||||||
|
|
@@ -135,6 +135,31 @@ class LoRATrainingType(TrainingType): | |||||
| type: str = "Lora" | ||||||
|
|
||||||
|
|
||||||
| class TrainingMethod(BaseModel): | ||||||
| """ | ||||||
| Training method type | ||||||
| """ | ||||||
|
|
||||||
| method: str | ||||||
|
|
||||||
|
|
||||||
| class TrainingMethodSFT(TrainingMethod): | ||||||
| """ | ||||||
| Training method type for SFT training | ||||||
| """ | ||||||
|
|
||||||
| method: str = "sft" | ||||||
|
||||||
| -method: str = "sft" | |
| +method: Literal["sft"] = "sft" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added
Uh oh!
There was an error while loading. Please reload this page.