@@ -77,7 +77,7 @@ def create_finetune_request(
     wandb_base_url: str | None = None,
     wandb_project_name: str | None = None,
     wandb_name: str | None = None,
-    train_on_inputs: bool | Literal["auto"] = "auto",
+    train_on_inputs: bool | Literal["auto"] | None = None,
     training_method: str = "sft",
     dpo_beta: float | None = None,
     from_checkpoint: str | None = None,
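The signature change makes the flag tri-state: True/False mask inputs explicitly, "auto" infers masking from the data format, and None means "unset", to be resolved later inside create_finetune_request. A minimal illustration of the resulting type (the alias name is invented here for readability; the SDK spells the union inline, and the runtime `|` syntax needs Python 3.10+):

    from typing import Literal

    # Hypothetical alias, not part of the SDK.
    TrainOnInputs = bool | Literal["auto"] | None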
@@ -174,6 +174,15 @@ def create_finetune_request(
             f"training_method must be one of {', '.join(AVAILABLE_TRAINING_METHODS)}"
         )

+    if train_on_inputs is not None and training_method != "sft":
+        raise ValueError("train_on_inputs is only supported for SFT training")
+
+    if train_on_inputs is None and training_method == "sft":
+        log_warn_once(
+            "train_on_inputs is not set for SFT training, it will be set to 'auto' automatically"
+        )
+        train_on_inputs = "auto"
+
     lr_scheduler: FinetuneLRScheduler
     if lr_scheduler_type == "cosine":
         if scheduler_num_cycles <= 0.0:
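In isolation, the new guards behave as below. This is a standalone sketch, not SDK code: resolve_train_on_inputs is a hypothetical helper, and the SDK's deduplicated warning helper log_warn_once is stubbed with a plain print.

    from typing import Literal

    def log_warn_once(msg: str) -> None:  # stand-in for the SDK helper
        print(f"WARNING: {msg}")

    def resolve_train_on_inputs(
        train_on_inputs: bool | Literal["auto"] | None,
        training_method: str,
    ) -> bool | Literal["auto"] | None:
        # Passing the flag explicitly outside SFT is now a hard error.
        if train_on_inputs is not None and training_method != "sft":
            raise ValueError("train_on_inputs is only supported for SFT training")
        # Left unset under SFT: warn, then fall back to "auto".
        if train_on_inputs is None and training_method == "sft":
            log_warn_once(
                "train_on_inputs is not set for SFT training, "
                "it will be set to 'auto' automatically"
            )
            train_on_inputs = "auto"
        return train_on_inputs

    assert resolve_train_on_inputs(None, "sft") == "auto"
    assert resolve_train_on_inputs(None, "dpo") is None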
@@ -191,7 +200,9 @@ def create_finetune_request(
             lr_scheduler_args=LinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
         )

-    training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT()
+    training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT(
+        train_on_inputs=train_on_inputs
+    )
     if training_method == "dpo":
         training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta)
@@ -214,7 +225,6 @@ def create_finetune_request(
         wandb_base_url=wandb_base_url,
         wandb_project_name=wandb_project_name,
         wandb_name=wandb_name,
-        train_on_inputs=train_on_inputs,
         training_method=training_method_cls,
         from_checkpoint=from_checkpoint,
     )
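Together, this hunk and the previous one move train_on_inputs from a top-level field of the request into the SFT training-method object. Sketched as hypothetical serialized payloads (field names are inferred from the code above, not checked against the API schema):

    # Before: flag serialized at the top level of the request.
    before = {"train_on_inputs": "auto", "training_method": {"method": "sft"}}

    # After: flag carried inside the SFT training-method object.
    after = {"training_method": {"method": "sft", "train_on_inputs": "auto"}}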
@@ -319,7 +329,7 @@ def create(
         wandb_name: str | None = None,
         verbose: bool = False,
         model_limits: FinetuneTrainingLimits | None = None,
-        train_on_inputs: bool | Literal["auto"] = "auto",
+        train_on_inputs: bool | Literal["auto"] | None = None,
         training_method: str = "sft",
         dpo_beta: float | None = None,
         from_checkpoint: str | None = None,
@@ -364,12 +374,12 @@ def create(
                 Defaults to False.
             model_limits (FinetuneTrainingLimits, optional): Limits for the hyperparameters of the model in fine-tuning.
                 Defaults to None.
-            train_on_inputs (bool or "auto"): Whether to mask the user messages in conversational data or prompts in instruction data.
+            train_on_inputs (bool or "auto", optional): Whether to mask the user messages in conversational data or prompts in instruction data.
                 "auto" will automatically determine whether to mask the inputs based on the data format.
                 For datasets with the "text" field (general format), inputs will not be masked.
                 For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
                 (instruction format), inputs will be masked.
-                Defaults to "auto".
+                Defaults to None, or "auto" if training_method is "sft" (set in create_finetune_request).
             training_method (str, optional): Training method. Defaults to "sft".
                 Supported methods: "sft", "dpo".
             dpo_beta (float, optional): DPO beta parameter. Defaults to None.
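From the caller's side, only the default changes. A hedged usage sketch with the sync client (the fine_tuning.create entry point matches this file; the file ID and model name are placeholders):

    from together import Together

    client = Together()  # assumes TOGETHER_API_KEY is set in the environment

    # SFT without train_on_inputs: the request builder now receives None
    # and resolves it to "auto", emitting a one-time warning.
    job = client.fine_tuning.create(
        training_file="file-abc123",           # placeholder file ID
        model="meta-llama/Meta-Llama-3.1-8B",  # placeholder model name
    )

    # By contrast, passing train_on_inputs=True together with
    # training_method="dpo" now raises ValueError in create_finetune_request.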
@@ -707,7 +717,7 @@ async def create(
         wandb_name: str | None = None,
         verbose: bool = False,
         model_limits: FinetuneTrainingLimits | None = None,
-        train_on_inputs: bool | Literal["auto"] = "auto",
+        train_on_inputs: bool | Literal["auto"] | None = None,
         training_method: str = "sft",
         dpo_beta: float | None = None,
         from_checkpoint: str | None = None,
@@ -757,7 +767,7 @@ async def create(
757767 For datasets with the "text" field (general format), inputs will not be masked.
758768 For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
759769 (Instruction format), inputs will be masked.
760- Defaults to "auto".
770+ Defaults to None, or "auto" if training_method is "sft" (set in create_finetune_request) .
761771 training_method (str, optional): Training method. Defaults to "sft".
762772 Supported methods: "sft", "dpo".
763773 dpo_beta (float, optional): DPO beta parameter. Defaults to None.
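The async client gets the same semantics. A sketch under the same assumptions as the sync example above (placeholder IDs and model name):

    import asyncio

    from together import AsyncTogether

    async def main() -> None:
        client = AsyncTogether()
        # DPO job: train_on_inputs must be left unset (None), or the
        # request builder raises ValueError.
        job = await client.fine_tuning.create(
            training_file="file-abc123",           # placeholder file ID
            model="meta-llama/Meta-Llama-3.1-8B",  # placeholder model name
            training_method="dpo",
            dpo_beta=0.1,
        )

    asyncio.run(main())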