Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/together/cli/api/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,18 @@ def fine_tuning(ctx: click.Context) -> None:
default="all-linear",
help="Trainable modules for LoRA adapters. For example, 'all-linear', 'q_proj,v_proj'",
)
@click.option(
"--training-method",
type=click.Choice(["sft", "dpo"]),
default="sft",
help="Training method to use. Options: sft (supervised fine-tuning), dpo (Direct Preference Optimization)",
)
@click.option(
"--dpo-beta",
type=float,
default=0.1,
help="Beta parameter for DPO training (only used when '--training-method' is 'dpo')",
)
@click.option(
"--suffix", type=str, default=None, help="Suffix for the fine-tuned model name"
)
Expand Down Expand Up @@ -152,6 +164,8 @@ def create(
wandb_name: str,
confirm: bool,
train_on_inputs: bool | Literal["auto"],
training_method: str,
dpo_beta: float,
) -> None:
"""Start fine-tuning"""
client: Together = ctx.obj
Expand Down Expand Up @@ -180,6 +194,8 @@ def create(
wandb_project_name=wandb_project_name,
wandb_name=wandb_name,
train_on_inputs=train_on_inputs,
training_method=training_method,
dpo_beta=dpo_beta,
)

model_limits: FinetuneTrainingLimits = client.fine_tuning.get_model_limits(
Expand Down
6 changes: 6 additions & 0 deletions src/together/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,18 @@ class DatasetFormat(enum.Enum):
GENERAL = "general"
CONVERSATION = "conversation"
INSTRUCTION = "instruction"
PREFERENCE_OPENAI = "preference_openai"


# Maps each dataset format to the JSONL column names that must be present
# in every row of an uploaded training file.
JSONL_REQUIRED_COLUMNS_MAP = {
    DatasetFormat.GENERAL: ["text"],
    DatasetFormat.CONVERSATION: ["messages"],
    DatasetFormat.INSTRUCTION: ["prompt", "completion"],
    # OpenAI-style preference data: a shared input plus one preferred and one
    # non-preferred output. NOTE(review): presumably consumed by the new DPO
    # training method added in this change — confirm against the validator.
    DatasetFormat.PREFERENCE_OPENAI: [
        "input",
        "preferred_output",
        "non_preferred_output",
    ],
}
# Keys required in every message object of a conversational ("messages") dataset.
REQUIRED_COLUMNS_MESSAGE = ["role", "content"]
# Role values accepted in conversational datasets.
POSSIBLE_ROLES_CONVERSATION = ["system", "user", "assistant"]
35 changes: 33 additions & 2 deletions src/together/resources/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
TrainingType,
FinetuneLRScheduler,
FinetuneLinearLRSchedulerArgs,
TrainingMethodDPO,
TrainingMethodSFT,
)
from together.types.finetune import DownloadCheckpointType
from together.utils import log_warn_once, normalize_key
Expand Down Expand Up @@ -52,7 +54,10 @@ def createFinetuneRequest(
wandb_project_name: str | None = None,
wandb_name: str | None = None,
train_on_inputs: bool | Literal["auto"] = "auto",
training_method: str = "sft",
dpo_beta: float | None = None,
) -> FinetuneRequest:

if batch_size == "max":
log_warn_once(
"Starting from together>=1.3.0, "
Expand Down Expand Up @@ -100,11 +105,24 @@ def createFinetuneRequest(
if weight_decay is not None and (weight_decay < 0):
raise ValueError("Weight decay should be non-negative")

AVAILABLE_TRAINING_METHODS = {
TrainingMethodSFT().method,
TrainingMethodDPO().method,
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is a constant, can you move it to the top of the file (outside of the function and the class definition)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

if training_method not in AVAILABLE_TRAINING_METHODS:
raise ValueError(
f"training_method must be one of {', '.join(AVAILABLE_TRAINING_METHODS)}"
)

lrScheduler = FinetuneLRScheduler(
lr_scheduler_type="linear",
lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
)

training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: maybe annotate the type as training_method_cls: TrainingMethod? It's a bit clearer and more extensible

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There were some issues with pre-commit checks when I tried to do this, as I remember

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Weird, do you remember what was the error by any chance? Not blocking, but I'd love to know how to fix it in the future

if training_method == "dpo":
training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta)

finetune_request = FinetuneRequest(
model=model,
training_file=training_file,
Expand All @@ -125,6 +143,7 @@ def createFinetuneRequest(
wandb_project_name=wandb_project_name,
wandb_name=wandb_name,
train_on_inputs=train_on_inputs,
training_method=training_method_cls,
)

return finetune_request
Expand Down Expand Up @@ -162,6 +181,8 @@ def create(
verbose: bool = False,
model_limits: FinetuneTrainingLimits | None = None,
train_on_inputs: bool | Literal["auto"] = "auto",
training_method: str = "sft",
dpo_beta: float | None = None,
) -> FinetuneResponse:
"""
Method to initiate a fine-tuning job
Expand Down Expand Up @@ -207,6 +228,9 @@ def create(
For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
(Instruction format), inputs will be masked.
Defaults to "auto".
training_method (str, optional): Training method. Defaults to "sft".
Supported methods: "sft", "dpo".
dpo_beta (float, optional): DPO beta parameter. Defaults to None.

Returns:
FinetuneResponse: Object containing information about fine-tuning job.
Expand All @@ -218,7 +242,6 @@ def create(

if model_limits is None:
model_limits = self.get_model_limits(model=model)

finetune_request = createFinetuneRequest(
model_limits=model_limits,
training_file=training_file,
Expand All @@ -244,6 +267,8 @@ def create(
wandb_project_name=wandb_project_name,
wandb_name=wandb_name,
train_on_inputs=train_on_inputs,
training_method=training_method,
dpo_beta=dpo_beta,
)

if verbose:
Expand All @@ -261,7 +286,6 @@ def create(
),
stream=False,
)

assert isinstance(response, TogetherResponse)

return FinetuneResponse(**response.data)
Expand Down Expand Up @@ -503,6 +527,8 @@ async def create(
verbose: bool = False,
model_limits: FinetuneTrainingLimits | None = None,
train_on_inputs: bool | Literal["auto"] = "auto",
training_method: str = "sft",
dpo_beta: float | None = None,
) -> FinetuneResponse:
"""
Async method to initiate a fine-tuning job
Expand Down Expand Up @@ -548,6 +574,9 @@ async def create(
For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
(Instruction format), inputs will be masked.
Defaults to "auto".
training_method (str, optional): Training method. Defaults to "sft".
Supported methods: "sft", "dpo".
dpo_beta (float, optional): DPO beta parameter. Defaults to None.

Returns:
FinetuneResponse: Object containing information about fine-tuning job.
Expand Down Expand Up @@ -585,6 +614,8 @@ async def create(
wandb_project_name=wandb_project_name,
wandb_name=wandb_name,
train_on_inputs=train_on_inputs,
training_method=training_method,
dpo_beta=dpo_beta,
)

if verbose:
Expand Down
4 changes: 4 additions & 0 deletions src/together/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
FileType,
)
from together.types.finetune import (
TrainingMethodDPO,
TrainingMethodSFT,
FinetuneDownloadResult,
FinetuneLinearLRSchedulerArgs,
FinetuneList,
Expand Down Expand Up @@ -79,6 +81,8 @@
"TrainingType",
"FullTrainingType",
"LoRATrainingType",
"TrainingMethodDPO",
"TrainingMethodSFT",
"RerankRequest",
"RerankResponse",
"FinetuneTrainingLimits",
Expand Down
29 changes: 29 additions & 0 deletions src/together/types/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,31 @@ class LoRATrainingType(TrainingType):
type: str = "Lora"


class TrainingMethod(BaseModel):
    """
    Base type for a fine-tuning training method.

    Subclasses pin ``method`` to a literal identifier (e.g. ``"sft"``,
    ``"dpo"``) carried in the fine-tune request payload.
    """

    # Wire identifier of the training method; narrowed to a Literal in subclasses.
    method: str


class TrainingMethodSFT(TrainingMethod):
    """
    Training method for supervised fine-tuning (SFT).
    """

    # Fixed discriminator value identifying this method.
    method: Literal["sft"] = "sft"


class TrainingMethodDPO(TrainingMethod):
    """
    Training method for Direct Preference Optimization (DPO).
    """

    # Fixed discriminator value identifying this method.
    method: Literal["dpo"] = "dpo"
    # Beta coefficient of the DPO loss. NOTE(review): None appears to defer
    # to a server-side default — confirm with the API documentation.
    dpo_beta: float | None = None


class FinetuneRequest(BaseModel):
"""
Fine-tune request type
Expand Down Expand Up @@ -178,6 +203,10 @@ class FinetuneRequest(BaseModel):
training_type: FullTrainingType | LoRATrainingType | None = None
# train on inputs
train_on_inputs: StrictBool | Literal["auto"] = "auto"
# training method
training_method: TrainingMethodSFT | TrainingMethodDPO = Field(
default_factory=TrainingMethodSFT
)


class FinetuneResponse(BaseModel):
Expand Down
Loading
Loading