
Commit 14ba59d

migrate to sft_on_inputs, and change defaults to match

1 parent 4eef896 commit 14ba59d

File tree

3 files changed: +30 -21 lines

  src/together/cli/api/finetune.py
  src/together/resources/finetune.py
  src/together/types/finetune.py

src/together/cli/api/finetune.py

Lines changed: 10 additions & 10 deletions
@@ -1,10 +1,10 @@
 from __future__ import annotations

 import json
+import re
 from datetime import datetime, timezone
 from textwrap import wrap
 from typing import Any, Literal
-import re

 import click
 from click.core import ParameterSource  # type: ignore[attr-defined]
@@ -13,17 +13,17 @@

 from together import Together
 from together.cli.api.utils import BOOL_WITH_AUTO, INT_WITH_MAX
+from together.types.finetune import (
+    DownloadCheckpointType,
+    FinetuneEventType,
+    FinetuneTrainingLimits,
+)
 from together.utils import (
     finetune_price_to_dollars,
+    format_timestamp,
     log_warn,
     log_warn_once,
     parse_timestamp,
-    format_timestamp,
-)
-from together.types.finetune import (
-    DownloadCheckpointType,
-    FinetuneTrainingLimits,
-    FinetuneEventType,
 )


@@ -348,9 +348,9 @@ def list(ctx: click.Context) -> None:
                 "Model Output Name": "\n".join(wrap(i.output_name or "", width=30)),
                 "Status": i.status,
                 "Created At": i.created_at,
-                "Price": f"""${finetune_price_to_dollars(
-                    float(str(i.total_price))
-                )}""",  # convert to string for mypy typing
+                "Price": f"""${
+                    finetune_price_to_dollars(float(str(i.total_price)))
+                }""",  # convert to string for mypy typing
             }
         )
     table = tabulate(display_list, headers="keys", tablefmt="grid", showindex=True)

src/together/resources/finetune.py

Lines changed: 18 additions & 8 deletions
@@ -69,7 +69,7 @@ def create_finetune_request(
     wandb_base_url: str | None = None,
     wandb_project_name: str | None = None,
     wandb_name: str | None = None,
-    train_on_inputs: bool | Literal["auto"] = "auto",
+    train_on_inputs: bool | Literal["auto"] | None = None,
     training_method: str = "sft",
     dpo_beta: float | None = None,
     from_checkpoint: str | None = None,
@@ -166,6 +166,15 @@
             f"training_method must be one of {', '.join(AVAILABLE_TRAINING_METHODS)}"
         )

+    if train_on_inputs is not None and training_method != "sft":
+        raise ValueError("train_on_inputs is only supported for SFT training")
+
+    if train_on_inputs is None and training_method == "sft":
+        log_warn_once(
+            "train_on_inputs is not set for SFT training, it will be set to 'auto' automatically"
+        )
+        train_on_inputs = "auto"
+
     lr_scheduler: FinetuneLRScheduler
     if lr_scheduler_type == "cosine":
         if scheduler_num_cycles <= 0.0:
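
The two new guards above change how the flag is resolved: an explicit train_on_inputs is now rejected for non-SFT runs, and an unset value is promoted to "auto" (with a one-time warning) for SFT jobs. Below is a minimal standalone sketch of that decision logic; resolve_train_on_inputs is a hypothetical helper written only for illustration, mirroring the guards in create_finetune_request, and is not part of the SDK.

from __future__ import annotations

from typing import Literal


def resolve_train_on_inputs(
    train_on_inputs: bool | Literal["auto"] | None,
    training_method: str,
) -> bool | Literal["auto"] | None:
    """Hypothetical mirror of the guards added to create_finetune_request."""
    if train_on_inputs is not None and training_method != "sft":
        raise ValueError("train_on_inputs is only supported for SFT training")
    if train_on_inputs is None and training_method == "sft":
        # The SDK calls log_warn_once(...) here before falling back to "auto".
        return "auto"
    return train_on_inputs


assert resolve_train_on_inputs(None, "sft") == "auto"   # unset + SFT -> "auto"
assert resolve_train_on_inputs(False, "sft") is False   # explicit value is kept
assert resolve_train_on_inputs(None, "dpo") is None     # DPO with no value: allowed
# resolve_train_on_inputs(True, "dpo") would raise ValueError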
@@ -183,7 +192,9 @@
             lr_scheduler_args=LinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
         )

-    training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT()
+    training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT(
+        train_on_inputs=train_on_inputs
+    )
     if training_method == "dpo":
         training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta)

@@ -206,7 +217,6 @@
         wandb_base_url=wandb_base_url,
         wandb_project_name=wandb_project_name,
         wandb_name=wandb_name,
-        train_on_inputs=train_on_inputs,
         training_method=training_method_cls,
         from_checkpoint=from_checkpoint,
     )
@@ -281,7 +291,7 @@ def create(
        wandb_name: str | None = None,
        verbose: bool = False,
        model_limits: FinetuneTrainingLimits | None = None,
-        train_on_inputs: bool | Literal["auto"] = "auto",
+        train_on_inputs: bool | Literal["auto"] | None = None,
        training_method: str = "sft",
        dpo_beta: float | None = None,
        from_checkpoint: str | None = None,
@@ -326,12 +336,12 @@ def create(
                Defaults to False.
            model_limits (FinetuneTrainingLimits, optional): Limits for the hyperparameters the model in Fine-tuning.
                Defaults to None.
-            train_on_inputs (bool or "auto"): Whether to mask the user messages in conversational data or prompts in instruction data.
+            train_on_inputs (bool or "auto", optional): Whether to mask the user messages in conversational data or prompts in instruction data.
                "auto" will automatically determine whether to mask the inputs based on the data format.
                For datasets with the "text" field (general format), inputs will not be masked.
                For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
                (Instruction format), inputs will be masked.
-                Defaults to "auto".
+                Defaults to None, or "auto" if training_method is "sft" (set in create_finetune_request).
            training_method (str, optional): Training method. Defaults to "sft".
                Supported methods: "sft", "dpo".
            dpo_beta (float, optional): DPO beta parameter. Defaults to None.
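
Since the docstring now documents a None default, a hedged usage sketch of the public call path may help; the file ID and model name below are placeholders and are not taken from this commit.

from together import Together

client = Together()  # picks up TOGETHER_API_KEY from the environment

# SFT job: train_on_inputs may be omitted entirely; create_finetune_request
# then falls back to "auto" and logs a one-time warning.
job = client.fine_tuning.create(
    training_file="file-xxxxxxxxxxxx",  # placeholder file ID
    model="placeholder-base-model",     # placeholder model name
    train_on_inputs=False,              # or True / "auto"; SFT only
)
print(job.id)

# DPO job: do not pass train_on_inputs; an explicit value now raises ValueError.
# client.fine_tuning.create(training_file="file-xxxxxxxxxxxx",
#                           model="placeholder-base-model", training_method="dpo")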
@@ -693,7 +703,7 @@ async def create(
        wandb_name: str | None = None,
        verbose: bool = False,
        model_limits: FinetuneTrainingLimits | None = None,
-        train_on_inputs: bool | Literal["auto"] = "auto",
+        train_on_inputs: bool | Literal["auto"] | None = None,
        training_method: str = "sft",
        dpo_beta: float | None = None,
        from_checkpoint: str | None = None,
@@ -743,7 +753,7 @@ async def create(
                For datasets with the "text" field (general format), inputs will not be masked.
                For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
                (Instruction format), inputs will be masked.
-                Defaults to "auto".
+                Defaults to None, or "auto" if training_method is "sft" (set in create_finetune_request).
            training_method (str, optional): Training method. Defaults to "sft".
                Supported methods: "sft", "dpo".
            dpo_beta (float, optional): DPO beta parameter. Defaults to None.
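
The async client follows the same defaulting rules; a minimal sketch, assuming the AsyncTogether client exported by the SDK and using placeholder IDs.

import asyncio

from together import AsyncTogether


async def main() -> None:
    client = AsyncTogether()
    # Leaving train_on_inputs unset on an SFT job resolves to "auto".
    job = await client.fine_tuning.create(
        training_file="file-xxxxxxxxxxxx",  # placeholder file ID
        model="placeholder-base-model",     # placeholder model name
    )
    print(job.id)


asyncio.run(main())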

src/together/types/finetune.py

Lines changed: 2 additions & 3 deletions
@@ -3,7 +3,7 @@
 from enum import Enum
 from typing import List, Literal, Any

-from pydantic import StrictBool, Field, field_validator
+from pydantic import Field, StrictBool, field_validator

 from together.types.abstract import BaseModel
 from together.types.common import (
@@ -149,6 +149,7 @@ class TrainingMethodSFT(TrainingMethod):
     """

     method: Literal["sft"] = "sft"
+    train_on_inputs: StrictBool | Literal["auto"] = "auto"


 class TrainingMethodDPO(TrainingMethod):
@@ -201,8 +202,6 @@ class FinetuneRequest(BaseModel):
     wandb_name: str | None = None
     # training type
     training_type: FullTrainingType | LoRATrainingType | None = None
-    # train on inputs
-    train_on_inputs: StrictBool | Literal["auto"] = "auto"
     # training method
     training_method: TrainingMethodSFT | TrainingMethodDPO = Field(
         default_factory=TrainingMethodSFT
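
Net effect of the type change: train_on_inputs no longer lives on FinetuneRequest and instead rides on the SFT training-method object. A small sketch of the resulting shape, assuming pydantic v2's model_dump; the commented dicts are the expected output, not captured results.

from together.types.finetune import TrainingMethodSFT

print(TrainingMethodSFT(train_on_inputs=False).model_dump())
# expected: {'method': 'sft', 'train_on_inputs': False}

# Omitting the argument keeps the new field default of "auto".
print(TrainingMethodSFT().model_dump())
# expected: {'method': 'sft', 'train_on_inputs': 'auto'}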
