
Commit 0ea183f

migrate to sft_on_inputs, and change defaults to match
1 parent c09481b commit 0ea183f

File tree

3 files changed (+30, -21 lines)

src/together/cli/api/finetune.py

Lines changed: 10 additions & 10 deletions
@@ -1,10 +1,10 @@
 from __future__ import annotations
 
 import json
+import re
 from datetime import datetime, timezone
 from textwrap import wrap
 from typing import Any, Literal
-import re
 
 import click
 from click.core import ParameterSource  # type: ignore[attr-defined]
@@ -13,17 +13,17 @@
 
 from together import Together
 from together.cli.api.utils import BOOL_WITH_AUTO, INT_WITH_MAX
+from together.types.finetune import (
+    DownloadCheckpointType,
+    FinetuneEventType,
+    FinetuneTrainingLimits,
+)
 from together.utils import (
     finetune_price_to_dollars,
+    format_timestamp,
     log_warn,
     log_warn_once,
     parse_timestamp,
-    format_timestamp,
-)
-from together.types.finetune import (
-    DownloadCheckpointType,
-    FinetuneTrainingLimits,
-    FinetuneEventType,
 )
 
 
@@ -348,9 +348,9 @@ def list(ctx: click.Context) -> None:
                 "Model Output Name": "\n".join(wrap(i.output_name or "", width=30)),
                 "Status": i.status,
                 "Created At": i.created_at,
-                "Price": f"""${finetune_price_to_dollars(
-                    float(str(i.total_price))
-                )}""",  # convert to string for mypy typing
+                "Price": f"""${
+                    finetune_price_to_dollars(float(str(i.total_price)))
+                }""",  # convert to string for mypy typing
             }
         )
     table = tabulate(display_list, headers="keys", tablefmt="grid", showindex=True)
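
The substantive change in this file is the reshuffled imports; the "Price" edit only moves the line break inside the triple-quoted f-string. A standalone sketch below (not SDK code; to_dollars is a hypothetical stand-in for together.utils.finetune_price_to_dollars, whose conversion factor is not shown here) illustrates that both layouts render the same string, since whitespace and newlines inside the replacement field do not affect the output.

# Standalone sketch: "to_dollars" is a hypothetical stand-in, used only to
# exercise the two f-string layouts from the diff above.
def to_dollars(price: float) -> float:
    return price  # identity; the real helper's conversion is not assumed here

total_price = "42"

# Old layout: line break after the opening call parenthesis.
old = f"""${to_dollars(
    float(str(total_price))
)}"""

# New layout: the whole expression on one line inside the braces.
new = f"""${
    to_dollars(float(str(total_price)))
}"""

assert old == new == "$42.0"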

src/together/resources/finetune.py

Lines changed: 18 additions & 8 deletions
@@ -77,7 +77,7 @@ def create_finetune_request(
     wandb_base_url: str | None = None,
     wandb_project_name: str | None = None,
     wandb_name: str | None = None,
-    train_on_inputs: bool | Literal["auto"] = "auto",
+    train_on_inputs: bool | Literal["auto"] | None = None,
     training_method: str = "sft",
     dpo_beta: float | None = None,
     from_checkpoint: str | None = None,
@@ -174,6 +174,15 @@
             f"training_method must be one of {', '.join(AVAILABLE_TRAINING_METHODS)}"
         )
 
+    if train_on_inputs is not None and training_method != "sft":
+        raise ValueError("train_on_inputs is only supported for SFT training")
+
+    if train_on_inputs is None and training_method == "sft":
+        log_warn_once(
+            "train_on_inputs is not set for SFT training, it will be set to 'auto' automatically"
+        )
+        train_on_inputs = "auto"
+
     lr_scheduler: FinetuneLRScheduler
     if lr_scheduler_type == "cosine":
         if scheduler_num_cycles <= 0.0:
@@ -191,7 +200,9 @@
             lr_scheduler_args=LinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
         )
 
-    training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT()
+    training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT(
+        train_on_inputs=train_on_inputs
+    )
     if training_method == "dpo":
         training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta)
 
@@ -214,7 +225,6 @@
         wandb_base_url=wandb_base_url,
         wandb_project_name=wandb_project_name,
         wandb_name=wandb_name,
-        train_on_inputs=train_on_inputs,
         training_method=training_method_cls,
         from_checkpoint=from_checkpoint,
     )
@@ -319,7 +329,7 @@ def create(
         wandb_name: str | None = None,
        verbose: bool = False,
        model_limits: FinetuneTrainingLimits | None = None,
-        train_on_inputs: bool | Literal["auto"] = "auto",
+        train_on_inputs: bool | Literal["auto"] | None = None,
        training_method: str = "sft",
        dpo_beta: float | None = None,
        from_checkpoint: str | None = None,
@@ -364,12 +374,12 @@ def create(
                Defaults to False.
            model_limits (FinetuneTrainingLimits, optional): Limits for the hyperparameters the model in Fine-tuning.
                Defaults to None.
-            train_on_inputs (bool or "auto"): Whether to mask the user messages in conversational data or prompts in instruction data.
+            train_on_inputs (bool or "auto", optional): Whether to mask the user messages in conversational data or prompts in instruction data.
                "auto" will automatically determine whether to mask the inputs based on the data format.
                For datasets with the "text" field (general format), inputs will not be masked.
                For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
                (Instruction format), inputs will be masked.
-                Defaults to "auto".
+                Defaults to None, or "auto" if training_method is "sft" (set in create_finetune_request).
            training_method (str, optional): Training method. Defaults to "sft".
                Supported methods: "sft", "dpo".
            dpo_beta (float, optional): DPO beta parameter. Defaults to None.
@@ -707,7 +717,7 @@ async def create(
        wandb_name: str | None = None,
        verbose: bool = False,
        model_limits: FinetuneTrainingLimits | None = None,
-        train_on_inputs: bool | Literal["auto"] = "auto",
+        train_on_inputs: bool | Literal["auto"] | None = None,
        training_method: str = "sft",
        dpo_beta: float | None = None,
        from_checkpoint: str | None = None,
@@ -757,7 +767,7 @@ async def create(
                For datasets with the "text" field (general format), inputs will not be masked.
                For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
                (Instruction format), inputs will be masked.
-                Defaults to "auto".
+                Defaults to None, or "auto" if training_method is "sft" (set in create_finetune_request).
            training_method (str, optional): Training method. Defaults to "sft".
                Supported methods: "sft", "dpo".
            dpo_beta (float, optional): DPO beta parameter. Defaults to None.
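
Net effect for callers: train_on_inputs now defaults to None, is rejected for non-SFT methods, and falls back to "auto" for SFT inside create_finetune_request. A minimal standalone sketch of that resolution logic (not the SDK itself; print stands in for together.utils.log_warn_once):

# Minimal sketch of the resolution added in create_finetune_request.
from __future__ import annotations

from typing import Literal


def resolve_train_on_inputs(
    train_on_inputs: bool | Literal["auto"] | None,
    training_method: str,
) -> bool | Literal["auto"] | None:
    # Passing the flag with a non-SFT method is now an error.
    if train_on_inputs is not None and training_method != "sft":
        raise ValueError("train_on_inputs is only supported for SFT training")
    # Leaving it unset for SFT falls back to "auto" (with a one-time warning in the SDK).
    if train_on_inputs is None and training_method == "sft":
        print("train_on_inputs is not set for SFT training; defaulting to 'auto'")
        train_on_inputs = "auto"
    return train_on_inputs


assert resolve_train_on_inputs(None, "sft") == "auto"   # unset + SFT -> "auto"
assert resolve_train_on_inputs(False, "sft") is False   # explicit value passes through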

src/together/types/finetune.py

Lines changed: 2 additions & 3 deletions
@@ -3,7 +3,7 @@
 from enum import Enum
 from typing import List, Literal, Any
 
-from pydantic import StrictBool, Field, field_validator
+from pydantic import Field, StrictBool, field_validator
 
 from together.types.abstract import BaseModel
 from together.types.common import (
@@ -149,6 +149,7 @@ class TrainingMethodSFT(TrainingMethod):
     """
 
     method: Literal["sft"] = "sft"
+    train_on_inputs: StrictBool | Literal["auto"] = "auto"
 
 
 class TrainingMethodDPO(TrainingMethod):
@@ -201,8 +202,6 @@ class FinetuneRequest(BaseModel):
     wandb_name: str | None = None
     # training type
     training_type: FullTrainingType | LoRATrainingType | None = None
-    # train on inputs
-    train_on_inputs: StrictBool | Literal["auto"] = "auto"
     # training method
     training_method: TrainingMethodSFT | TrainingMethodDPO = Field(
         default_factory=TrainingMethodSFT
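
With the field removed from FinetuneRequest, the flag now travels inside the training-method object in the request payload. A quick hedged check of the new shape, assuming the import path shown in the diff and standard pydantic v2 serialization (exact dump keys depend on the model config):

# Hedged sketch: train_on_inputs now lives on TrainingMethodSFT, not FinetuneRequest.
from together.types.finetune import TrainingMethodSFT

method = TrainingMethodSFT(train_on_inputs=False)
print(method.model_dump())  # roughly: {'method': 'sft', 'train_on_inputs': False}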
