
Commit a4fd112

VProv and mryab authored
Add support for the Training Method for finetuning, and for Direct-Preference Optimization (DPO) (#262)
* Add support for the Training Method for finetuning, and for Direct-Preference Optimization (DPO)

---------

Co-authored-by: Max Ryabinin <[email protected]>
1 parent 271af98 commit a4fd112
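
Background for the new `dpo_beta` parameter, as a reference sketch: in the standard DPO objective of Rafailov et al. (2023), beta scales the log-probability ratios between the trained policy and a frozen reference model, so larger values keep the policy closer to the reference:

\[
\mathcal{L}_{\mathrm{DPO}}(\theta) = -\,\mathbb{E}_{(x,\,y_w,\,y_l)}\left[\log \sigma\!\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}\right)\right]
\]

where \(y_w\) is the preferred output and \(y_l\) the non-preferred one. The diff below only forwards `dpo_beta` to the API; whether the training service implements exactly this loss is not shown here.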

File tree: 7 files changed (+541, −68 lines)


src/together/cli/api/finetune.py

Lines changed: 16 additions & 0 deletions
@@ -110,6 +110,18 @@ def fine_tuning(ctx: click.Context) -> None:
     default="all-linear",
     help="Trainable modules for LoRA adapters. For example, 'all-linear', 'q_proj,v_proj'",
 )
+@click.option(
+    "--training-method",
+    type=click.Choice(["sft", "dpo"]),
+    default="sft",
+    help="Training method to use. Options: sft (supervised fine-tuning), dpo (Direct Preference Optimization)",
+)
+@click.option(
+    "--dpo-beta",
+    type=float,
+    default=0.1,
+    help="Beta parameter for DPO training (only used when '--training-method' is 'dpo')",
+)
 @click.option(
     "--suffix", type=str, default=None, help="Suffix for the fine-tuned model name"
 )
@@ -166,6 +178,8 @@ def create(
     wandb_name: str,
     confirm: bool,
     train_on_inputs: bool | Literal["auto"],
+    training_method: str,
+    dpo_beta: float,
     from_checkpoint: str,
 ) -> None:
     """Start fine-tuning"""
@@ -195,6 +209,8 @@ def create(
         wandb_project_name=wandb_project_name,
         wandb_name=wandb_name,
         train_on_inputs=train_on_inputs,
+        training_method=training_method,
+        dpo_beta=dpo_beta,
         from_checkpoint=from_checkpoint,
     )
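
With these options in place, a DPO job could be launched from the command line with something like `together fine-tuning create --training-file <FILE_ID> --model <MODEL> --training-method dpo --dpo-beta 0.1`; the `together fine-tuning create` command path is an assumption inferred from the Click group and `create` command shown above.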

src/together/constants.py

Lines changed: 6 additions & 0 deletions
@@ -39,12 +39,18 @@ class DatasetFormat(enum.Enum):
     GENERAL = "general"
     CONVERSATION = "conversation"
     INSTRUCTION = "instruction"
+    PREFERENCE_OPENAI = "preference_openai"
 
 
 JSONL_REQUIRED_COLUMNS_MAP = {
     DatasetFormat.GENERAL: ["text"],
     DatasetFormat.CONVERSATION: ["messages"],
     DatasetFormat.INSTRUCTION: ["prompt", "completion"],
+    DatasetFormat.PREFERENCE_OPENAI: [
+        "input",
+        "preferred_output",
+        "non_preferred_output",
+    ],
 }
 REQUIRED_COLUMNS_MESSAGE = ["role", "content"]
 POSSIBLE_ROLES_CONVERSATION = ["system", "user", "assistant"]
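
The new `preference_openai` format requires each JSONL record to carry an `input` plus a preferred and a non-preferred output. A minimal sketch of one record, written as the Python dict that would serialize to one JSONL line; the diff only pins down the three top-level keys, so the nested OpenAI-style message shapes are an assumption:

import json

# Hypothetical preference record: only the three top-level keys are
# required by JSONL_REQUIRED_COLUMNS_MAP; the nested shapes are assumed.
record = {
    "input": {"messages": [{"role": "user", "content": "Name a prime number."}]},
    "preferred_output": [{"role": "assistant", "content": "7"}],
    "non_preferred_output": [{"role": "assistant", "content": "8"}],
}
print(json.dumps(record))  # one line of the training JSONL file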

src/together/resources/finetune.py

Lines changed: 35 additions & 2 deletions
@@ -23,6 +23,8 @@
     TrainingType,
     FinetuneLRScheduler,
     FinetuneLinearLRSchedulerArgs,
+    TrainingMethodDPO,
+    TrainingMethodSFT,
     FinetuneCheckpoint,
 )
 from together.types.finetune import (
@@ -39,6 +41,12 @@
 _FT_JOB_WITH_STEP_REGEX = r"^ft-[\dabcdef-]+:\d+$"
 
 
+AVAILABLE_TRAINING_METHODS = {
+    TrainingMethodSFT().method,
+    TrainingMethodDPO().method,
+}
+
+
 def createFinetuneRequest(
     model_limits: FinetuneTrainingLimits,
     training_file: str,
@@ -64,8 +72,11 @@ def createFinetuneRequest(
     wandb_project_name: str | None = None,
     wandb_name: str | None = None,
     train_on_inputs: bool | Literal["auto"] = "auto",
+    training_method: str = "sft",
+    dpo_beta: float | None = None,
     from_checkpoint: str | None = None,
 ) -> FinetuneRequest:
+
     if batch_size == "max":
         log_warn_once(
             "Starting from together>=1.3.0, "
@@ -113,11 +124,20 @@ def createFinetuneRequest(
     if weight_decay is not None and (weight_decay < 0):
         raise ValueError("Weight decay should be non-negative")
 
+    if training_method not in AVAILABLE_TRAINING_METHODS:
+        raise ValueError(
+            f"training_method must be one of {', '.join(AVAILABLE_TRAINING_METHODS)}"
+        )
+
     lrScheduler = FinetuneLRScheduler(
         lr_scheduler_type="linear",
         lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
     )
 
+    training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT()
+    if training_method == "dpo":
+        training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta)
+
     finetune_request = FinetuneRequest(
         model=model,
         training_file=training_file,
@@ -138,6 +158,7 @@ def createFinetuneRequest(
         wandb_project_name=wandb_project_name,
         wandb_name=wandb_name,
         train_on_inputs=train_on_inputs,
+        training_method=training_method_cls,
         from_checkpoint=from_checkpoint,
     )
 
@@ -240,6 +261,8 @@ def create(
         verbose: bool = False,
         model_limits: FinetuneTrainingLimits | None = None,
         train_on_inputs: bool | Literal["auto"] = "auto",
+        training_method: str = "sft",
+        dpo_beta: float | None = None,
         from_checkpoint: str | None = None,
     ) -> FinetuneResponse:
         """
@@ -286,6 +309,9 @@ def create(
                 For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
                 (Instruction format), inputs will be masked.
                 Defaults to "auto".
+            training_method (str, optional): Training method. Defaults to "sft".
+                Supported methods: "sft", "dpo".
+            dpo_beta (float, optional): DPO beta parameter. Defaults to None.
             from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
                 The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
                 The step value is optional, without it the final checkpoint will be used.
@@ -300,7 +326,6 @@ def create(
 
         if model_limits is None:
             model_limits = self.get_model_limits(model=model)
-
         finetune_request = createFinetuneRequest(
             model_limits=model_limits,
             training_file=training_file,
@@ -326,6 +351,8 @@ def create(
             wandb_project_name=wandb_project_name,
             wandb_name=wandb_name,
             train_on_inputs=train_on_inputs,
+            training_method=training_method,
+            dpo_beta=dpo_beta,
             from_checkpoint=from_checkpoint,
         )
 
@@ -344,7 +371,6 @@ def create(
             ),
             stream=False,
         )
-
         assert isinstance(response, TogetherResponse)
 
         return FinetuneResponse(**response.data)
@@ -608,6 +634,8 @@ async def create(
         verbose: bool = False,
         model_limits: FinetuneTrainingLimits | None = None,
         train_on_inputs: bool | Literal["auto"] = "auto",
+        training_method: str = "sft",
+        dpo_beta: float | None = None,
         from_checkpoint: str | None = None,
     ) -> FinetuneResponse:
         """
@@ -654,6 +682,9 @@ async def create(
                 For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
                 (Instruction format), inputs will be masked.
                 Defaults to "auto".
+            training_method (str, optional): Training method. Defaults to "sft".
+                Supported methods: "sft", "dpo".
+            dpo_beta (float, optional): DPO beta parameter. Defaults to None.
             from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
                 The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
                 The step value is optional, without it the final checkpoint will be used.
@@ -694,6 +725,8 @@ async def create(
             wandb_project_name=wandb_project_name,
             wandb_name=wandb_name,
             train_on_inputs=train_on_inputs,
+            training_method=training_method,
+            dpo_beta=dpo_beta,
             from_checkpoint=from_checkpoint,
         )
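
Putting the new parameters together, a DPO fine-tuning job could be created through the SDK roughly as follows. This is a sketch: the `Together` client and `fine_tuning.create` path match this file's resource layout, but the model name and file ID are hypothetical placeholders:

from together import Together

client = Together()  # reads TOGETHER_API_KEY from the environment

# Hypothetical IDs; training_method and dpo_beta are the parameters added here.
job = client.fine_tuning.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",
    training_file="file-xxxxxxxx",  # an uploaded preference_openai JSONL file
    training_method="dpo",          # validated against AVAILABLE_TRAINING_METHODS
    dpo_beta=0.1,                   # wrapped into TrainingMethodDPO before the request
)
print(job.id)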

src/together/types/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -31,6 +31,8 @@
     FileType,
 )
 from together.types.finetune import (
+    TrainingMethodDPO,
+    TrainingMethodSFT,
     FinetuneCheckpoint,
     FinetuneDownloadResult,
     FinetuneLinearLRSchedulerArgs,
@@ -81,6 +83,8 @@
     "TrainingType",
     "FullTrainingType",
     "LoRATrainingType",
+    "TrainingMethodDPO",
+    "TrainingMethodSFT",
     "RerankRequest",
     "RerankResponse",
     "FinetuneTrainingLimits",

src/together/types/finetune.py

Lines changed: 29 additions & 0 deletions
@@ -135,6 +135,31 @@ class LoRATrainingType(TrainingType):
     type: str = "Lora"
 
 
+class TrainingMethod(BaseModel):
+    """
+    Training method type
+    """
+
+    method: str
+
+
+class TrainingMethodSFT(TrainingMethod):
+    """
+    Training method type for SFT training
+    """
+
+    method: Literal["sft"] = "sft"
+
+
+class TrainingMethodDPO(TrainingMethod):
+    """
+    Training method type for DPO training
+    """
+
+    method: Literal["dpo"] = "dpo"
+    dpo_beta: float | None = None
+
+
 class FinetuneRequest(BaseModel):
     """
     Fine-tune request type
@@ -178,6 +203,10 @@ class FinetuneRequest(BaseModel):
     training_type: FullTrainingType | LoRATrainingType | None = None
     # train on inputs
     train_on_inputs: StrictBool | Literal["auto"] = "auto"
+    # training method
+    training_method: TrainingMethodSFT | TrainingMethodDPO = Field(
+        default_factory=TrainingMethodSFT
+    )
     # from step
     from_checkpoint: str
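
A quick sketch of how the new Pydantic models behave, based only on the fields defined above: `method` acts as a literal discriminator between the two variants, and `FinetuneRequest.training_method` defaults to SFT via `default_factory`:

from together.types import TrainingMethodDPO, TrainingMethodSFT

sft = TrainingMethodSFT()
assert sft.method == "sft"          # fixed by the Literal default

dpo = TrainingMethodDPO(dpo_beta=0.1)
assert dpo.method == "dpo" and dpo.dpo_beta == 0.1

# method is a Literal, so a mismatched value fails validation:
# TrainingMethodDPO(method="sft")  -> pydantic.ValidationError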
