@@ -23,6 +23,8 @@
     TrainingType,
     FinetuneLRScheduler,
     FinetuneLinearLRSchedulerArgs,
+    TrainingMethodDPO,
+    TrainingMethodSFT,
     FinetuneCheckpoint,
 )
 from together.types.finetune import (
@@ -39,6 +41,12 @@
 _FT_JOB_WITH_STEP_REGEX = r"^ft-[\dabcdef-]+:\d+$"
 
 
+AVAILABLE_TRAINING_METHODS = {
+    TrainingMethodSFT().method,
+    TrainingMethodDPO().method,
+}
+
+
 def createFinetuneRequest(
     model_limits: FinetuneTrainingLimits,
     training_file: str,
@@ -64,8 +72,11 @@ def createFinetuneRequest(
     wandb_project_name: str | None = None,
     wandb_name: str | None = None,
     train_on_inputs: bool | Literal["auto"] = "auto",
+    training_method: str = "sft",
+    dpo_beta: float | None = None,
     from_checkpoint: str | None = None,
 ) -> FinetuneRequest:
+
     if batch_size == "max":
         log_warn_once(
             "Starting from together>=1.3.0, "
@@ -113,11 +124,20 @@ def createFinetuneRequest(
     if weight_decay is not None and (weight_decay < 0):
         raise ValueError("Weight decay should be non-negative")
 
+    if training_method not in AVAILABLE_TRAINING_METHODS:
+        raise ValueError(
+            f"training_method must be one of {', '.join(AVAILABLE_TRAINING_METHODS)}"
+        )
+
     lrScheduler = FinetuneLRScheduler(
         lr_scheduler_type="linear",
         lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
     )
 
+    training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT()
+    if training_method == "dpo":
+        training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta)
+
     finetune_request = FinetuneRequest(
         model=model,
         training_file=training_file,
@@ -138,6 +158,7 @@ def createFinetuneRequest(
         wandb_project_name=wandb_project_name,
         wandb_name=wandb_name,
         train_on_inputs=train_on_inputs,
+        training_method=training_method_cls,
         from_checkpoint=from_checkpoint,
     )
 
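For reference, the new dispatch in createFinetuneRequest can be exercised directly. A rough sketch, assuming the helper's remaining parameters have usable defaults and that model limits can be fetched through the client's fine_tuning resource (as the create() wrapper below does); the file ID and model name are placeholders:

# Illustrative only: names mirror the diff, argument values are hypothetical.
limits = client.fine_tuning.get_model_limits(model="meta-llama/Meta-Llama-3.1-8B-Instruct")
request = createFinetuneRequest(
    model_limits=limits,
    training_file="file-abc123",
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    training_method="dpo",
    dpo_beta=0.1,
)
# "dpo" selects TrainingMethodDPO(dpo_beta=0.1); any string outside
# AVAILABLE_TRAINING_METHODS raises ValueError before a request is built.
assert isinstance(request.training_method, TrainingMethodDPO)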
@@ -240,6 +261,8 @@ def create(
         verbose: bool = False,
         model_limits: FinetuneTrainingLimits | None = None,
         train_on_inputs: bool | Literal["auto"] = "auto",
+        training_method: str = "sft",
+        dpo_beta: float | None = None,
         from_checkpoint: str | None = None,
     ) -> FinetuneResponse:
         """
@@ -286,6 +309,9 @@ def create(
                 For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
                 (Instruction format), inputs will be masked.
                 Defaults to "auto".
+            training_method (str, optional): Training method. Defaults to "sft".
+                Supported methods: "sft", "dpo".
+            dpo_beta (float, optional): DPO beta parameter. Defaults to None.
             from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
                 The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
                 The step value is optional, without it the final checkpoint will be used.
@@ -300,7 +326,6 @@ def create(
 
         if model_limits is None:
             model_limits = self.get_model_limits(model=model)
-
         finetune_request = createFinetuneRequest(
             model_limits=model_limits,
             training_file=training_file,
@@ -326,6 +351,8 @@ def create(
             wandb_project_name=wandb_project_name,
             wandb_name=wandb_name,
             train_on_inputs=train_on_inputs,
+            training_method=training_method,
+            dpo_beta=dpo_beta,
             from_checkpoint=from_checkpoint,
         )
 
@@ -344,7 +371,6 @@ def create(
             ),
             stream=False,
         )
-
         assert isinstance(response, TogetherResponse)
 
         return FinetuneResponse(**response.data)
@@ -608,6 +634,8 @@ async def create(
         verbose: bool = False,
         model_limits: FinetuneTrainingLimits | None = None,
         train_on_inputs: bool | Literal["auto"] = "auto",
+        training_method: str = "sft",
+        dpo_beta: float | None = None,
         from_checkpoint: str | None = None,
     ) -> FinetuneResponse:
         """
@@ -654,6 +682,9 @@ async def create(
                 For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
                 (Instruction format), inputs will be masked.
                 Defaults to "auto".
+            training_method (str, optional): Training method. Defaults to "sft".
+                Supported methods: "sft", "dpo".
+            dpo_beta (float, optional): DPO beta parameter. Defaults to None.
             from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
                 The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
                 The step value is optional, without it the final checkpoint will be used.
@@ -694,6 +725,8 @@ async def create(
             wandb_project_name=wandb_project_name,
             wandb_name=wandb_name,
             train_on_inputs=train_on_inputs,
+            training_method=training_method,
+            dpo_beta=dpo_beta,
             from_checkpoint=from_checkpoint,
         )
 
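Taken together, the sync and async create() methods now accept the same two keyword arguments. A minimal usage sketch, assuming a Together client with TOGETHER_API_KEY set in the environment and a preference dataset already uploaded via the Files API; the file ID and model name below are placeholders:

from together import Together

client = Together()

# Preference data for DPO is expected to be uploaded ahead of time.
job = client.fine_tuning.create(
    training_file="file-abc123",                    # placeholder file ID
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",  # placeholder model name
    training_method="dpo",                          # defaults to "sft" when omitted
    dpo_beta=0.1,                                   # only used by DPO; defaults to None
)
print(job.id)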