16 changes: 16 additions & 0 deletions src/together/cli/api/finetune.py
@@ -110,6 +110,18 @@ def fine_tuning(ctx: click.Context) -> None:
default="all-linear",
help="Trainable modules for LoRA adapters. For example, 'all-linear', 'q_proj,v_proj'",
)
@click.option(
"--training-method",
type=click.Choice(["sft", "dpo"]),
default="sft",
help="Training method to use. Options: sft (supervised fine-tuning), dpo (Direct Preference Optimization)",
)
@click.option(
"--dpo-beta",
type=float,
default=0.1,
help="Beta parameter for DPO training (only used when '--training-method' is 'dpo')",
)
@click.option(
"--suffix", type=str, default=None, help="Suffix for the fine-tuned model name"
)
@@ -166,6 +178,8 @@ def create(
wandb_name: str,
confirm: bool,
train_on_inputs: bool | Literal["auto"],
training_method: str,
dpo_beta: float,
from_checkpoint: str,
) -> None:
"""Start fine-tuning"""
@@ -195,6 +209,8 @@ def create(
wandb_project_name=wandb_project_name,
wandb_name=wandb_name,
train_on_inputs=train_on_inputs,
training_method=training_method,
dpo_beta=dpo_beta,
from_checkpoint=from_checkpoint,
)

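A quick usage note for the two new flags above: the sketch below drives the click group from Python via click's CliRunner, to keep everything in one language. The import path and the --training-file/--model flag names are assumptions based on the file path and the rest of the CLI, not something this diff shows; only --training-method and --dpo-beta come from this PR.

from click.testing import CliRunner
from together.cli.api.finetune import fine_tuning  # assumed import path for the group shown above

runner = CliRunner()
result = runner.invoke(
    fine_tuning,
    [
        "create",
        "--training-file", "file-your-preference-dataset",  # hypothetical uploaded file ID
        "--model", "your-base-model",                        # hypothetical model name
        "--training-method", "dpo",  # new flag from this PR
        "--dpo-beta", "0.1",         # new flag from this PR, only used with dpo
    ],
)
print(result.output)  # starting a real job requires a valid TOGETHER_API_KEY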
6 changes: 6 additions & 0 deletions src/together/constants.py
@@ -39,12 +39,18 @@ class DatasetFormat(enum.Enum):
GENERAL = "general"
CONVERSATION = "conversation"
INSTRUCTION = "instruction"
PREFERENCE_OPENAI = "preference_openai"


JSONL_REQUIRED_COLUMNS_MAP = {
DatasetFormat.GENERAL: ["text"],
DatasetFormat.CONVERSATION: ["messages"],
DatasetFormat.INSTRUCTION: ["prompt", "completion"],
DatasetFormat.PREFERENCE_OPENAI: [
"input",
"preferred_output",
"non_preferred_output",
],
}
REQUIRED_COLUMNS_MESSAGE = ["role", "content"]
POSSIBLE_ROLES_CONVERSATION = ["system", "user", "assistant"]
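For reference, here is a minimal sketch of one PREFERENCE_OPENAI record that satisfies the new required-columns check. Only the three top-level keys come from JSONL_REQUIRED_COLUMNS_MAP above; the nested message layout follows the OpenAI-style preference format and is an assumption for illustration, not something this diff specifies.

import json

# Hypothetical preference record; only the three required top-level columns
# ("input", "preferred_output", "non_preferred_output") are taken from the PR.
record = {
    "input": {
        "messages": [{"role": "user", "content": "What is the capital of France?"}]
    },
    "preferred_output": [
        {"role": "assistant", "content": "The capital of France is Paris."}
    ],
    "non_preferred_output": [
        {"role": "assistant", "content": "It might be Lyon."}
    ],
}

# Write a one-record JSONL file in the expected line-per-example layout.
with open("preference_data.jsonl", "w") as f:
    f.write(json.dumps(record) + "\n")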
37 changes: 35 additions & 2 deletions src/together/resources/finetune.py
@@ -23,6 +23,8 @@
TrainingType,
FinetuneLRScheduler,
FinetuneLinearLRSchedulerArgs,
TrainingMethodDPO,
TrainingMethodSFT,
FinetuneCheckpoint,
)
from together.types.finetune import (
@@ -39,6 +41,12 @@
_FT_JOB_WITH_STEP_REGEX = r"^ft-[\dabcdef-]+:\d+$"


AVAILABLE_TRAINING_METHODS = {
TrainingMethodSFT().method,
TrainingMethodDPO().method,
}


def createFinetuneRequest(
model_limits: FinetuneTrainingLimits,
training_file: str,
@@ -64,8 +72,11 @@ def createFinetuneRequest(
wandb_project_name: str | None = None,
wandb_name: str | None = None,
train_on_inputs: bool | Literal["auto"] = "auto",
training_method: str = "sft",
dpo_beta: float | None = None,
from_checkpoint: str | None = None,
) -> FinetuneRequest:

if batch_size == "max":
log_warn_once(
"Starting from together>=1.3.0, "
@@ -113,11 +124,20 @@ def createFinetuneRequest(
if weight_decay is not None and (weight_decay < 0):
raise ValueError("Weight decay should be non-negative")

if training_method not in AVAILABLE_TRAINING_METHODS:
raise ValueError(
f"training_method must be one of {', '.join(AVAILABLE_TRAINING_METHODS)}"
)

lrScheduler = FinetuneLRScheduler(
lr_scheduler_type="linear",
lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
)

training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT()
Member:
Nit: maybe annotate the type as training_method_cls: TrainingMethod? It's a bit clearer and more extensible.

Contributor Author:
There were some issues with pre-commit checks when I tried to do this, as I remember.

Member:
Weird, do you remember what the error was, by any chance? Not blocking, but I'd love to know how to fix it in the future.

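As I read the suggestion above, it amounts to annotating against the shared TrainingMethod base class instead of spelling out the union. A tiny standalone sketch of that variant (not what this PR merged) follows; the literal beta value stands in for the createFinetuneRequest argument.

from together.types.finetune import TrainingMethod, TrainingMethodDPO, TrainingMethodSFT

# Annotating with the base class keeps the later reassignment valid for any
# concrete training method without listing every subclass in the annotation.
training_method_cls: TrainingMethod = TrainingMethodSFT()
training_method_cls = TrainingMethodDPO(dpo_beta=0.1)  # still type-checks against the base class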
if training_method == "dpo":
training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta)

finetune_request = FinetuneRequest(
model=model,
training_file=training_file,
@@ -138,6 +158,7 @@ def createFinetuneRequest(
wandb_project_name=wandb_project_name,
wandb_name=wandb_name,
train_on_inputs=train_on_inputs,
training_method=training_method_cls,
from_checkpoint=from_checkpoint,
)

@@ -240,6 +261,8 @@ def create(
verbose: bool = False,
model_limits: FinetuneTrainingLimits | None = None,
train_on_inputs: bool | Literal["auto"] = "auto",
training_method: str = "sft",
dpo_beta: float | None = None,
from_checkpoint: str | None = None,
) -> FinetuneResponse:
"""
Expand Down Expand Up @@ -286,6 +309,9 @@ def create(
For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
(Instruction format), inputs will be masked.
Defaults to "auto".
training_method (str, optional): Training method. Defaults to "sft".
Supported methods: "sft", "dpo".
dpo_beta (float, optional): DPO beta parameter. Defaults to None.
from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
The step value is optional, without it the final checkpoint will be used.
@@ -300,7 +326,6 @@

if model_limits is None:
model_limits = self.get_model_limits(model=model)

finetune_request = createFinetuneRequest(
model_limits=model_limits,
training_file=training_file,
@@ -326,6 +351,8 @@
wandb_project_name=wandb_project_name,
wandb_name=wandb_name,
train_on_inputs=train_on_inputs,
training_method=training_method,
dpo_beta=dpo_beta,
from_checkpoint=from_checkpoint,
)

@@ -344,7 +371,6 @@
),
stream=False,
)

assert isinstance(response, TogetherResponse)

return FinetuneResponse(**response.data)
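A hedged usage sketch for the two new parameters on the synchronous create() above. The file ID and model name are placeholders, and the client attribute name (client.fine_tuning) is assumed from the rest of the SDK; training_method and dpo_beta are the parameters added in this PR.

from together import Together

client = Together()  # assumes TOGETHER_API_KEY is set in the environment

job = client.fine_tuning.create(
    training_file="file-your-preference-dataset",  # hypothetical uploaded file ID
    model="your-base-model",                       # hypothetical model name
    training_method="dpo",  # defaults to "sft" when omitted
    dpo_beta=0.1,           # only used when training_method == "dpo"
)
print(job.id)  # identifier of the newly created fine-tuning job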
@@ -608,6 +634,8 @@ async def create(
verbose: bool = False,
model_limits: FinetuneTrainingLimits | None = None,
train_on_inputs: bool | Literal["auto"] = "auto",
training_method: str = "sft",
dpo_beta: float | None = None,
from_checkpoint: str | None = None,
) -> FinetuneResponse:
"""
Expand Down Expand Up @@ -654,6 +682,9 @@ async def create(
For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
(Instruction format), inputs will be masked.
Defaults to "auto".
training_method (str, optional): Training method. Defaults to "sft".
Supported methods: "sft", "dpo".
dpo_beta (float, optional): DPO beta parameter. Defaults to None.
from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
The step value is optional, without it the final checkpoint will be used.
@@ -694,6 +725,8 @@ async def create(
wandb_project_name=wandb_project_name,
wandb_name=wandb_name,
train_on_inputs=train_on_inputs,
training_method=training_method,
dpo_beta=dpo_beta,
from_checkpoint=from_checkpoint,
)

4 changes: 4 additions & 0 deletions src/together/types/__init__.py
@@ -31,6 +31,8 @@
FileType,
)
from together.types.finetune import (
TrainingMethodDPO,
TrainingMethodSFT,
FinetuneCheckpoint,
FinetuneDownloadResult,
FinetuneLinearLRSchedulerArgs,
@@ -81,6 +83,8 @@
"TrainingType",
"FullTrainingType",
"LoRATrainingType",
"TrainingMethodDPO",
"TrainingMethodSFT",
"RerankRequest",
"RerankResponse",
"FinetuneTrainingLimits",
29 changes: 29 additions & 0 deletions src/together/types/finetune.py
@@ -135,6 +135,31 @@ class LoRATrainingType(TrainingType):
type: str = "Lora"


class TrainingMethod(BaseModel):
"""
Training method type
"""

method: str


class TrainingMethodSFT(TrainingMethod):
"""
Training method type for SFT training
"""

method: Literal["sft"] = "sft"


class TrainingMethodDPO(TrainingMethod):
"""
Training method type for DPO training
"""

method: Literal["dpo"] = "dpo"
dpo_beta: float | None = None


class FinetuneRequest(BaseModel):
"""
Fine-tune request type
@@ -178,6 +203,10 @@ class FinetuneRequest(BaseModel):
training_type: FullTrainingType | LoRATrainingType | None = None
# train on inputs
train_on_inputs: StrictBool | Literal["auto"] = "auto"
# training method
training_method: TrainingMethodSFT | TrainingMethodDPO = Field(
default_factory=TrainingMethodSFT
)
# from step
from_checkpoint: str

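To make the shape of the new training-method models concrete, here is a small serialization sketch; it assumes pydantic v2's model_dump(), which the surrounding BaseModel usage suggests but this diff does not show directly.

from together.types.finetune import TrainingMethodDPO, TrainingMethodSFT

# Each model carries a fixed "method" literal; the DPO variant also carries its beta.
print(TrainingMethodSFT().model_dump())              # {'method': 'sft'}
print(TrainingMethodDPO(dpo_beta=0.1).model_dump())  # {'method': 'dpo', 'dpo_beta': 0.1}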