Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/together/cli/api/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from tabulate import tabulate

from together import Together
from together.cli.api.utils import BOOL_WITH_AUTO, INT_WITH_MAX
from together.cli.api.utils import BOOL_WITH_AUTO, INT_WITH_MAX, FROM_STEP_TYPE
from together.utils import (
finetune_price_to_dollars,
log_warn,
Expand Down Expand Up @@ -126,6 +126,12 @@ def fine_tuning(ctx: click.Context) -> None:
help="Whether to mask the user messages in conversational data or prompts in instruction data. "
"`auto` will automatically determine whether to mask the inputs based on the data format.",
)
@click.option(
"--from-step",
type=FROM_STEP_TYPE,
default="final",
help="From which checkpoint start a fine-tuning job"
)
def create(
ctx: click.Context,
training_file: str,
Expand All @@ -152,6 +158,7 @@ def create(
wandb_name: str,
confirm: bool,
train_on_inputs: bool | Literal["auto"],
from_step: int | Literal["final"],
) -> None:
"""Start fine-tuning"""
client: Together = ctx.obj
Expand Down Expand Up @@ -180,6 +187,7 @@ def create(
wandb_project_name=wandb_project_name,
wandb_name=wandb_name,
train_on_inputs=train_on_inputs,
from_step=from_step,
)

model_limits: FinetuneTrainingLimits = client.fine_tuning.get_model_limits(
Expand Down
17 changes: 17 additions & 0 deletions src/together/cli/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,22 @@ def convert(
)


class FromStepParamType(click.ParamType):
name = "from_step"

def convert(self, value: str, param: click.Parameter | None, ctx: click.Context | None) -> int | Literal["final"] | None:
if value == "final":
return "final"
try:
return int(value)
except ValueError:
self.fail(
_("{value!r} is not a valid {type}.").format(
value=value, type=self.name
),
)


INT_WITH_MAX = AutoIntParamType()
BOOL_WITH_AUTO = BooleanWithAutoParamType()
FROM_STEP_TYPE = FromStepParamType()
8 changes: 8 additions & 0 deletions src/together/resources/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def createFinetuneRequest(
wandb_project_name: str | None = None,
wandb_name: str | None = None,
train_on_inputs: bool | Literal["auto"] = "auto",
from_step: int | Literal["final"] = "final"
) -> FinetuneRequest:
if batch_size == "max":
log_warn_once(
Expand Down Expand Up @@ -100,6 +101,9 @@ def createFinetuneRequest(
if weight_decay is not None and (weight_decay < 0):
raise ValueError("Weight decay should be non-negative")

if from_step == "final":
from_step = -1

lrScheduler = FinetuneLRScheduler(
lr_scheduler_type="linear",
lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
Expand All @@ -125,6 +129,7 @@ def createFinetuneRequest(
wandb_project_name=wandb_project_name,
wandb_name=wandb_name,
train_on_inputs=train_on_inputs,
from_step=from_step,
)

return finetune_request
Expand Down Expand Up @@ -162,6 +167,7 @@ def create(
verbose: bool = False,
model_limits: FinetuneTrainingLimits | None = None,
train_on_inputs: bool | Literal["auto"] = "auto",
from_step: int | Literal["final"] = "final",
) -> FinetuneResponse:
"""
Method to initiate a fine-tuning job
Expand Down Expand Up @@ -207,6 +213,7 @@ def create(
For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
(Instruction format), inputs will be masked.
Defaults to "auto".
from_step (int or "final"): From which checkpoint start a fine-tuning job

Returns:
FinetuneResponse: Object containing information about fine-tuning job.
Expand Down Expand Up @@ -244,6 +251,7 @@ def create(
wandb_project_name=wandb_project_name,
wandb_name=wandb_name,
train_on_inputs=train_on_inputs,
from_step=from_step,
)

if verbose:
Expand Down
3 changes: 3 additions & 0 deletions src/together/types/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ class FinetuneRequest(BaseModel):
training_type: FullTrainingType | LoRATrainingType | None = None
# train on inputs
train_on_inputs: StrictBool | Literal["auto"] = "auto"
# from step
from_step: int | None = -1


class FinetuneResponse(BaseModel):
Expand Down Expand Up @@ -256,6 +258,7 @@ class FinetuneResponse(BaseModel):
training_file_num_lines: int | None = Field(None, alias="TrainingFileNumLines")
training_file_size: int | None = Field(None, alias="TrainingFileSize")
train_on_inputs: StrictBool | Literal["auto"] | None = "auto"
from_step: int | None = "-1"

@field_validator("training_type")
@classmethod
Expand Down
Loading