@@ -69,7 +69,7 @@ def create_finetune_request(
     wandb_base_url: str | None = None,
     wandb_project_name: str | None = None,
     wandb_name: str | None = None,
-    train_on_inputs: bool | Literal["auto"] = "auto",
+    train_on_inputs: bool | Literal["auto"] | None = None,
     training_method: str = "sft",
     dpo_beta: float | None = None,
     from_checkpoint: str | None = None,
@@ -166,6 +166,15 @@ def create_finetune_request(
             f"training_method must be one of {', '.join(AVAILABLE_TRAINING_METHODS)}"
         )

+    if train_on_inputs is not None and training_method != "sft":
+        raise ValueError("train_on_inputs is only supported for SFT training")
+
+    if train_on_inputs is None and training_method == "sft":
+        log_warn_once(
+            "train_on_inputs is not set for SFT training, it will be set to 'auto' automatically"
+        )
+        train_on_inputs = "auto"
+
     lr_scheduler: FinetuneLRScheduler
     if lr_scheduler_type == "cosine":
         if scheduler_num_cycles <= 0.0:
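
The two added guards change how the default is resolved: passing `train_on_inputs` for a non-SFT method is now rejected, and leaving it unset for SFT falls back to `"auto"` with a one-time warning. A minimal standalone sketch of that resolution logic (the function name and reduced signature here are illustrative, not the SDK's API; the real check runs inline inside `create_finetune_request`):

```python
from typing import Literal

def resolve_train_on_inputs(
    train_on_inputs: bool | Literal["auto"] | None,
    training_method: str,
) -> bool | Literal["auto"] | None:
    # Non-SFT methods must leave the flag unset.
    if train_on_inputs is not None and training_method != "sft":
        raise ValueError("train_on_inputs is only supported for SFT training")
    # Old default preserved: unset SFT requests still behave as "auto".
    if train_on_inputs is None and training_method == "sft":
        return "auto"
    return train_on_inputs

assert resolve_train_on_inputs(None, "sft") == "auto"
assert resolve_train_on_inputs(None, "dpo") is None
assert resolve_train_on_inputs(False, "sft") is False
```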
@@ -183,7 +192,9 @@ def create_finetune_request(
             lr_scheduler_args=LinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
         )

-    training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT()
+    training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT(
+        train_on_inputs=train_on_inputs
+    )
     if training_method == "dpo":
         training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta)

@@ -206,7 +217,6 @@ def create_finetune_request(
         wandb_base_url=wandb_base_url,
         wandb_project_name=wandb_project_name,
         wandb_name=wandb_name,
-        train_on_inputs=train_on_inputs,
         training_method=training_method_cls,
         from_checkpoint=from_checkpoint,
     )
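
Taken together, these two hunks move `train_on_inputs` off the top-level `FinetuneRequest` and into the SFT training-method object, so the flag is serialized under the training method rather than as a request-level field. A rough sketch of the resulting shape (field names follow the diff; the surrounding request fields and the serialized form are assumptions, not confirmed by this commit):

```python
# train_on_inputs now lives on TrainingMethodSFT, not FinetuneRequest itself.
training_method_cls = TrainingMethodSFT(train_on_inputs="auto")
request = FinetuneRequest(
    # ...other required fields elided...
    training_method=training_method_cls,
)
# Plausible serialized nesting (illustrative only):
# {"training_method": {"method": "sft", "train_on_inputs": "auto"}, ...}
```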
@@ -281,7 +291,7 @@ def create(
         wandb_name: str | None = None,
         verbose: bool = False,
         model_limits: FinetuneTrainingLimits | None = None,
-        train_on_inputs: bool | Literal["auto"] = "auto",
+        train_on_inputs: bool | Literal["auto"] | None = None,
         training_method: str = "sft",
         dpo_beta: float | None = None,
         from_checkpoint: str | None = None,
@@ -326,12 +336,12 @@ def create(
                 Defaults to False.
             model_limits (FinetuneTrainingLimits, optional): Limits for the hyperparameters the model in Fine-tuning.
                 Defaults to None.
-            train_on_inputs (bool or "auto"): Whether to mask the user messages in conversational data or prompts in instruction data.
+            train_on_inputs (bool or "auto", optional): Whether to mask the user messages in conversational data or prompts in instruction data.
                 "auto" will automatically determine whether to mask the inputs based on the data format.
                 For datasets with the "text" field (general format), inputs will not be masked.
                 For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
                 (Instruction format), inputs will be masked.
-                Defaults to "auto".
+                Defaults to None, or "auto" if training_method is "sft" (set in create_finetune_request).
             training_method (str, optional): Training method. Defaults to "sft".
                 Supported methods: "sft", "dpo".
             dpo_beta (float, optional): DPO beta parameter. Defaults to None.
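
From the caller's side, most SFT users change nothing, while DPO callers must now leave `train_on_inputs` unset. A hedged usage sketch (file and model IDs are placeholders; assumes the standard `together.Together` client, which reads `TOGETHER_API_KEY` from the environment):

```python
from together import Together

client = Together()

# SFT: train_on_inputs omitted -> resolved to "auto", with a one-time warning.
job = client.fine_tuning.create(
    training_file="file-abc123",        # placeholder uploaded-file ID
    model="example-org/example-model",  # placeholder model name
    training_method="sft",
)

# DPO: passing train_on_inputs now raises ValueError, e.g.
# client.fine_tuning.create(..., training_method="dpo", train_on_inputs=True)
```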
@@ -693,7 +703,7 @@ async def create(
         wandb_name: str | None = None,
         verbose: bool = False,
         model_limits: FinetuneTrainingLimits | None = None,
-        train_on_inputs: bool | Literal["auto"] = "auto",
+        train_on_inputs: bool | Literal["auto"] | None = None,
         training_method: str = "sft",
         dpo_beta: float | None = None,
         from_checkpoint: str | None = None,
@@ -743,7 +753,7 @@ async def create(
                 For datasets with the "text" field (general format), inputs will not be masked.
                 For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
                 (Instruction format), inputs will be masked.
-                Defaults to "auto".
+                Defaults to None, or "auto" if training_method is "sft" (set in create_finetune_request).
             training_method (str, optional): Training method. Defaults to "sft".
                 Supported methods: "sft", "dpo".
             dpo_beta (float, optional): DPO beta parameter. Defaults to None.