# Add support for the training method for fine-tuning, and for Direct Preference Optimization (DPO) #262
Changes from all commits: 3f1ec6a · ee7e02d · fdbdc8e · ee470fc · 8f50eb8 · b322858 · 24d209e · f1f8d9a · 5320437 · fbd17a6 · bf0b180 · 7357926 · 5f00d95 · 3c72b9c
```diff
@@ -23,6 +23,8 @@
     TrainingType,
     FinetuneLRScheduler,
     FinetuneLinearLRSchedulerArgs,
+    TrainingMethodDPO,
+    TrainingMethodSFT,
     FinetuneCheckpoint,
 )
 from together.types.finetune import (
```
```diff
@@ -39,6 +41,12 @@
 _FT_JOB_WITH_STEP_REGEX = r"^ft-[\dabcdef-]+:\d+$"
 
 
+AVAILABLE_TRAINING_METHODS = {
+    TrainingMethodSFT().method,
+    TrainingMethodDPO().method,
+}
+
+
 def createFinetuneRequest(
     model_limits: FinetuneTrainingLimits,
     training_file: str,
```
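For readers without the types module handy: the set above collects each class's `method` discriminator. A minimal, hypothetical sketch of what these types might look like, assuming pydantic models (the real definitions live in `together/types/finetune.py` and may differ):

```python
# Hypothetical sketch of the training-method types; NOT the actual
# together/types/finetune.py definitions. The pydantic base class and
# field defaults here are assumptions for illustration only.
from typing import Literal

from pydantic import BaseModel


class TrainingMethodSFT(BaseModel):
    method: Literal["sft"] = "sft"


class TrainingMethodDPO(BaseModel):
    method: Literal["dpo"] = "dpo"
    dpo_beta: float | None = None


# Mirrors the set built in the hunk above.
AVAILABLE_TRAINING_METHODS = {
    TrainingMethodSFT().method,
    TrainingMethodDPO().method,
}
assert AVAILABLE_TRAINING_METHODS == {"sft", "dpo"}
```

Building the set from instances keeps it in sync with the type definitions: adding a new method class extends the validation below without touching a hand-written list.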
```diff
@@ -64,8 +72,11 @@ def createFinetuneRequest(
     wandb_project_name: str | None = None,
     wandb_name: str | None = None,
     train_on_inputs: bool | Literal["auto"] = "auto",
+    training_method: str = "sft",
+    dpo_beta: float | None = None,
     from_checkpoint: str | None = None,
 ) -> FinetuneRequest:
 
     if batch_size == "max":
         log_warn_once(
             "Starting from together>=1.3.0, "
```
```diff
@@ -113,11 +124,20 @@ def createFinetuneRequest(
     if weight_decay is not None and (weight_decay < 0):
         raise ValueError("Weight decay should be non-negative")
 
+    if training_method not in AVAILABLE_TRAINING_METHODS:
+        raise ValueError(
+            f"training_method must be one of {', '.join(AVAILABLE_TRAINING_METHODS)}"
+        )
+
     lrScheduler = FinetuneLRScheduler(
         lr_scheduler_type="linear",
         lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
     )
 
+    training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT()
```
> **Member:** Nit: maybe annotate the type as …
>
> **Contributor (author):** There were some issues with pre-commit checks when I tried to do this, as I remember.
>
> **Member:** Weird, do you remember what the error was, by any chance? Not blocking, but I'd love to know how to fix it in the future.
*(mryab marked this conversation as resolved.)*

```diff
+    if training_method == "dpo":
+        training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta)
+
     finetune_request = FinetuneRequest(
         model=model,
         training_file=training_file,
```
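For context on the new `dpo_beta` parameter (background, not part of this diff): β is the coefficient in the standard DPO objective of Rafailov et al. (2023), where $y_w$ and $y_l$ are the preferred and rejected completions for a prompt $x$; larger β keeps the fine-tuned policy $\pi_\theta$ closer to the reference policy $\pi_{\mathrm{ref}}$:

```math
\mathcal{L}_{\mathrm{DPO}}(\pi_\theta; \pi_{\mathrm{ref}}) = -\,\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\left[ \log \sigma\!\left( \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)} \right) \right]
```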
```diff
@@ -138,6 +158,7 @@ def createFinetuneRequest(
         wandb_project_name=wandb_project_name,
         wandb_name=wandb_name,
         train_on_inputs=train_on_inputs,
+        training_method=training_method_cls,
         from_checkpoint=from_checkpoint,
     )
```
```diff
@@ -240,6 +261,8 @@ def create(
         verbose: bool = False,
         model_limits: FinetuneTrainingLimits | None = None,
         train_on_inputs: bool | Literal["auto"] = "auto",
+        training_method: str = "sft",
+        dpo_beta: float | None = None,
         from_checkpoint: str | None = None,
     ) -> FinetuneResponse:
         """
```
```diff
@@ -286,6 +309,9 @@ def create(
                 For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
                 (Instruction format), inputs will be masked.
                 Defaults to "auto".
+            training_method (str, optional): Training method. Defaults to "sft".
+                Supported methods: "sft", "dpo".
+            dpo_beta (float, optional): DPO beta parameter. Defaults to None.
             from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
                 The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
                 The step value is optional, without it the final checkpoint will be used.
```
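With these parameters exposed on `create`, a DPO job can be launched like any other fine-tuning job. A hypothetical usage sketch; the model name, file ID, and β value below are placeholders, not values taken from this PR:

```python
# Hypothetical usage sketch; assumes TOGETHER_API_KEY is set in the environment.
from together import Together

client = Together()

job = client.fine_tuning.create(
    model="meta-llama/Meta-Llama-3-8B",  # placeholder model name
    training_file="file-xxxxxxxxxxxx",   # placeholder ID of an uploaded preference dataset
    training_method="dpo",               # "sft" (default) or "dpo"
    dpo_beta=0.1,                        # placeholder; only meaningful for DPO
)
print(job.id)
```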
```diff
@@ -300,7 +326,6 @@ def create(
 
         if model_limits is None:
             model_limits = self.get_model_limits(model=model)
 
         finetune_request = createFinetuneRequest(
             model_limits=model_limits,
             training_file=training_file,
```
```diff
@@ -326,6 +351,8 @@ def create(
             wandb_project_name=wandb_project_name,
             wandb_name=wandb_name,
             train_on_inputs=train_on_inputs,
+            training_method=training_method,
+            dpo_beta=dpo_beta,
             from_checkpoint=from_checkpoint,
         )
```
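Note that the public `create` method forwards the raw `training_method` string and `dpo_beta` value; `createFinetuneRequest` (shown earlier) is what validates the string and converts it into the typed `TrainingMethodSFT`/`TrainingMethodDPO` object carried by the request.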
```diff
@@ -344,7 +371,6 @@ def create(
             ),
             stream=False,
         )
 
         assert isinstance(response, TogetherResponse)
 
         return FinetuneResponse(**response.data)
```
```diff
@@ -608,6 +634,8 @@ async def create(
         verbose: bool = False,
         model_limits: FinetuneTrainingLimits | None = None,
         train_on_inputs: bool | Literal["auto"] = "auto",
+        training_method: str = "sft",
+        dpo_beta: float | None = None,
         from_checkpoint: str | None = None,
     ) -> FinetuneResponse:
         """
```
```diff
@@ -654,6 +682,9 @@ async def create(
                 For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
                 (Instruction format), inputs will be masked.
                 Defaults to "auto".
+            training_method (str, optional): Training method. Defaults to "sft".
+                Supported methods: "sft", "dpo".
+            dpo_beta (float, optional): DPO beta parameter. Defaults to None.
             from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
                 The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
                 The step value is optional, without it the final checkpoint will be used.
```
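The async client mirrors the sync path. A hypothetical sketch with the same placeholder values as the sync example above:

```python
# Hypothetical async usage sketch, mirroring the sync example.
import asyncio

from together import AsyncTogether


async def main() -> None:
    client = AsyncTogether()
    job = await client.fine_tuning.create(
        model="meta-llama/Meta-Llama-3-8B",  # placeholder model name
        training_file="file-xxxxxxxxxxxx",   # placeholder file ID
        training_method="dpo",
        dpo_beta=0.1,
    )
    print(job.id)


asyncio.run(main())
```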
```diff
@@ -694,6 +725,8 @@ async def create(
             wandb_project_name=wandb_project_name,
             wandb_name=wandb_name,
             train_on_inputs=train_on_inputs,
+            training_method=training_method,
+            dpo_beta=dpo_beta,
             from_checkpoint=from_checkpoint,
         )
```