Merged
Changes from 4 commits
16 changes: 16 additions & 0 deletions src/together/cli/api/finetune.py
@@ -104,6 +104,18 @@ def fine_tuning(ctx: click.Context) -> None:
default="all-linear",
help="Trainable modules for LoRA adapters. For example, 'all-linear', 'q_proj,v_proj'",
)
@click.option(
"--training-method",
type=click.Choice(["sft", "dpo"]),
default="sft",
help="Training method to use. Options: sft (supervised fine-tuning), dpo (direct preference optimization)",
)
@click.option(
"--dpo-beta",
type=float,
default=0.1,
help="Beta parameter for DPO training (only used when training-method is 'dpo')",
)
@click.option(
"--suffix", type=str, default=None, help="Suffix for the fine-tuned model name"
)
@@ -152,6 +164,8 @@ def create(
wandb_name: str,
confirm: bool,
train_on_inputs: bool | Literal["auto"],
training_method: str,
dpo_beta: float,
) -> None:
"""Start fine-tuning"""
client: Together = ctx.obj
@@ -180,6 +194,8 @@ def create(
wandb_project_name=wandb_project_name,
wandb_name=wandb_name,
train_on_inputs=train_on_inputs,
training_method=training_method,
dpo_beta=dpo_beta,
)

model_limits: FinetuneTrainingLimits = client.fine_tuning.get_model_limits(
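
With both options in place, a DPO job could be started from the CLI along these lines (a sketch: the file ID and model name are placeholders, the subcommand path assumes the CLI's usual fine-tuning group, and only --training-method and --dpo-beta come from this diff):

together fine-tuning create --training-file <file-id> --model <model-name> --training-method dpo --dpo-beta 0.1
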
8 changes: 8 additions & 0 deletions src/together/constants.py
@@ -39,12 +39,20 @@ class DatasetFormat(enum.Enum):
GENERAL = "general"
CONVERSATION = "conversation"
INSTRUCTION = "instruction"
PREFERENCE = "preference"
PREFERENCE_OPENAI = "preference_openai"


JSONL_REQUIRED_COLUMNS_MAP = {
DatasetFormat.GENERAL: ["text"],
DatasetFormat.CONVERSATION: ["messages"],
DatasetFormat.INSTRUCTION: ["prompt", "completion"],
DatasetFormat.PREFERENCE: ["chosen", "rejected"],
DatasetFormat.PREFERENCE_OPENAI: [
"input",
"preferred_output",
"non_preferred_output",
],
}
REQUIRED_COLUMNS_MESSAGE = ["role", "content"]
POSSIBLE_ROLES_CONVERSATION = ["system", "user", "assistant"]
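
For illustration, a minimal sketch of checking one preference-format record against the new required-column list (the record contents are made up, and whether the chosen/rejected values are plain strings or message lists is an assumption not settled by this diff):

from together.constants import DatasetFormat, JSONL_REQUIRED_COLUMNS_MAP

# A made-up preference example pairing a preferred completion with a rejected one.
record = {
    "chosen": [{"role": "assistant", "content": "Here is a helpful answer."}],
    "rejected": [{"role": "assistant", "content": "No."}],
}

# Check that every column required by the PREFERENCE format is present.
required = JSONL_REQUIRED_COLUMNS_MAP[DatasetFormat.PREFERENCE]
missing = [column for column in required if column not in record]
assert not missing, f"missing columns: {missing}"
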
29 changes: 26 additions & 3 deletions src/together/resources/finetune.py
@@ -1,7 +1,7 @@
from __future__ import annotations

from pathlib import Path
from typing import Literal
from typing import Literal, Union

from rich import print as rprint

@@ -22,6 +22,8 @@
TrainingType,
FinetuneLRScheduler,
FinetuneLinearLRSchedulerArgs,
TrainingMethodDPO,
TrainingMethodSFT,
)
from together.types.finetune import DownloadCheckpointType
from together.utils import log_warn_once, normalize_key
@@ -52,6 +54,8 @@ def createFinetuneRequest(
wandb_project_name: str | None = None,
wandb_name: str | None = None,
train_on_inputs: bool | Literal["auto"] = "auto",
training_method: str = "sft",
dpo_beta: float | None = None,
) -> FinetuneRequest:
if batch_size == "max":
log_warn_once(
@@ -105,6 +109,12 @@ def createFinetuneRequest(
lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
)

training_method_cls: Union[TrainingMethodSFT, TrainingMethodDPO] = (
TrainingMethodSFT()
)
if training_method == "dpo":
training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta)

Member:
Nit: since you're using the | notation to specify union types above, I would use it here as well and remove the redundant import.

Contributor Author:
Ok

finetune_request = FinetuneRequest(
model=model,
training_file=training_file,
@@ -125,6 +135,7 @@
wandb_project_name=wandb_project_name,
wandb_name=wandb_name,
train_on_inputs=train_on_inputs,
training_method=training_method_cls,
)

return finetune_request
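
For reference, the reviewer's nit above would turn the dispatch into roughly the following (a sketch, not part of this diff; the | annotation is fine here since it is never evaluated at runtime and the module already has from __future__ import annotations, so the Union import could be dropped):

# Same dispatch as above, written with PEP 604 union syntax.
training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT()
if training_method == "dpo":
    training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta)
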
@@ -162,6 +173,8 @@ def create(
verbose: bool = False,
model_limits: FinetuneTrainingLimits | None = None,
train_on_inputs: bool | Literal["auto"] = "auto",
training_method: str = "sft",
dpo_beta: float | None = None,
) -> FinetuneResponse:
"""
Method to initiate a fine-tuning job
@@ -207,6 +220,9 @@
For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
(Instruction format), inputs will be masked.
Defaults to "auto".
training_method (str, optional): Training method. Defaults to "sft".
Supported methods: "sft", "dpo".
dpo_beta (float, optional): DPO beta parameter. Defaults to None.

Returns:
FinetuneResponse: Object containing information about fine-tuning job.
@@ -218,7 +234,6 @@

if model_limits is None:
model_limits = self.get_model_limits(model=model)

finetune_request = createFinetuneRequest(
model_limits=model_limits,
training_file=training_file,
@@ -244,6 +259,8 @@
wandb_project_name=wandb_project_name,
wandb_name=wandb_name,
train_on_inputs=train_on_inputs,
training_method=training_method,
dpo_beta=dpo_beta,
)

if verbose:
@@ -261,7 +278,6 @@
),
stream=False,
)

assert isinstance(response, TogetherResponse)

return FinetuneResponse(**response.data)
@@ -503,6 +519,8 @@ async def create(
verbose: bool = False,
model_limits: FinetuneTrainingLimits | None = None,
train_on_inputs: bool | Literal["auto"] = "auto",
training_method: str = "sft",
dpo_beta: float | None = None,
) -> FinetuneResponse:
"""
Async method to initiate a fine-tuning job
@@ -548,6 +566,9 @@ async def create(
For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
(Instruction format), inputs will be masked.
Defaults to "auto".
training_method (str, optional): Training method. Defaults to "sft".
Supported methods: "sft", "dpo".
dpo_beta (float, optional): DPO beta parameter. Defaults to None.

Returns:
FinetuneResponse: Object containing information about fine-tuning job.
@@ -585,6 +606,8 @@
wandb_project_name=wandb_project_name,
wandb_name=wandb_name,
train_on_inputs=train_on_inputs,
training_method=training_method,
dpo_beta=dpo_beta,
)

if verbose:
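
As a usage sketch for the Python client (an API key is assumed to be configured; the file ID and model name are placeholders):

from together import Together

client = Together()

# Launch a DPO fine-tuning job; training_method and dpo_beta are the arguments added by this PR.
job = client.fine_tuning.create(
    training_file="<file-id>",
    model="<model-name>",
    training_method="dpo",
    dpo_beta=0.1,
)
print(job)  # FinetuneResponse describing the created job
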
4 changes: 4 additions & 0 deletions src/together/types/__init__.py
@@ -31,6 +31,8 @@
FileType,
)
from together.types.finetune import (
TrainingMethodDPO,
TrainingMethodSFT,
FinetuneDownloadResult,
FinetuneLinearLRSchedulerArgs,
FinetuneList,
@@ -79,6 +81,8 @@
"TrainingType",
"FullTrainingType",
"LoRATrainingType",
"TrainingMethodDPO",
"TrainingMethodSFT",
"RerankRequest",
"RerankResponse",
"FinetuneTrainingLimits",
31 changes: 30 additions & 1 deletion src/together/types/finetune.py
@@ -1,7 +1,7 @@
from __future__ import annotations

from enum import Enum
from typing import List, Literal
from typing import List, Literal, Union

from pydantic import StrictBool, Field, validator, field_validator

@@ -135,6 +135,31 @@ class LoRATrainingType(TrainingType):
type: str = "Lora"


class TrainingMethod(BaseModel):
"""
Training method type
"""

method: str


class TrainingMethodSFT(TrainingMethod):
"""
Training method type for SFT training
"""

method: str = "sft"
Contributor:
Suggested change: replace method: str = "sft" with method: Literal["sft"] = "sft"

Contributor Author:
Added


class TrainingMethodDPO(TrainingMethod):
"""
Training method type for DPO training
"""

method: str = "dpo"
dpo_beta: float | None = None


class FinetuneRequest(BaseModel):
"""
Fine-tune request type
@@ -178,6 +203,10 @@ class FinetuneRequest(BaseModel):
training_type: FullTrainingType | LoRATrainingType | None = None
# train on inputs
train_on_inputs: StrictBool | Literal["auto"] = "auto"
# training method
training_method: Union[TrainingMethodSFT, TrainingMethodDPO] = Field(
default_factory=TrainingMethodSFT
)


class FinetuneResponse(BaseModel):
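
A small sketch of how the two method models serialize (assuming the SDK's BaseModel is a standard pydantic v2 model, consistent with the field_validator import above):

from together.types.finetune import TrainingMethodDPO, TrainingMethodSFT

# The default SFT method carries only its name...
print(TrainingMethodSFT().model_dump())               # {'method': 'sft'}
# ...while the DPO method also carries its beta hyperparameter.
print(TrainingMethodDPO(dpo_beta=0.1).model_dump())   # {'method': 'dpo', 'dpo_beta': 0.1}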