Skip to content

Commit cff84b0

Browse files
committed
Add from_step
1 parent 9e6120f commit cff84b0

File tree

4 files changed

+37
-1
lines changed

4 files changed

+37
-1
lines changed

src/together/cli/api/finetune.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from tabulate import tabulate
1212

1313
from together import Together
14-
from together.cli.api.utils import BOOL_WITH_AUTO, INT_WITH_MAX
14+
from together.cli.api.utils import BOOL_WITH_AUTO, INT_WITH_MAX, FROM_STEP_TYPE
1515
from together.utils import (
1616
finetune_price_to_dollars,
1717
log_warn,
@@ -126,6 +126,12 @@ def fine_tuning(ctx: click.Context) -> None:
126126
help="Whether to mask the user messages in conversational data or prompts in instruction data. "
127127
"`auto` will automatically determine whether to mask the inputs based on the data format.",
128128
)
129+
@click.option(
130+
"--from-step",
131+
type=FROM_STEP_TYPE,
132+
default="final",
133+
help="From which checkpoint to start a fine-tuning job"
134+
)
129135
def create(
130136
ctx: click.Context,
131137
training_file: str,
@@ -152,6 +158,7 @@ def create(
152158
wandb_name: str,
153159
confirm: bool,
154160
train_on_inputs: bool | Literal["auto"],
161+
from_step: int | Literal["final"],
155162
) -> None:
156163
"""Start fine-tuning"""
157164
client: Together = ctx.obj
@@ -180,6 +187,7 @@ def create(
180187
wandb_project_name=wandb_project_name,
181188
wandb_name=wandb_name,
182189
train_on_inputs=train_on_inputs,
190+
from_step=from_step,
183191
)
184192

185193
model_limits: FinetuneTrainingLimits = client.fine_tuning.get_model_limits(

src/together/cli/api/utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,5 +47,22 @@ def convert(
4747
)
4848

4949

50+
class FromStepParamType(click.ParamType):
51+
name = "from_step"
52+
53+
def convert(self, value: str, param: click.Parameter | None, ctx: click.Context | None) -> int | Literal["final"] | None:
54+
if value == "final":
55+
return "final"
56+
try:
57+
return int(value)
58+
except ValueError:
59+
self.fail(
60+
_("{value!r} is not a valid {type}.").format(
61+
value=value, type=self.name
62+
),
63+
)
64+
65+
5066
INT_WITH_MAX = AutoIntParamType()
5167
BOOL_WITH_AUTO = BooleanWithAutoParamType()
68+
FROM_STEP_TYPE = FromStepParamType()

src/together/resources/finetune.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def createFinetuneRequest(
5252
wandb_project_name: str | None = None,
5353
wandb_name: str | None = None,
5454
train_on_inputs: bool | Literal["auto"] = "auto",
55+
from_step: int | Literal["final"] = "final"
5556
) -> FinetuneRequest:
5657
if batch_size == "max":
5758
log_warn_once(
@@ -100,6 +101,9 @@ def createFinetuneRequest(
100101
if weight_decay is not None and (weight_decay < 0):
101102
raise ValueError("Weight decay should be non-negative")
102103

104+
if from_step == "final":
105+
from_step = -1
106+
103107
lrScheduler = FinetuneLRScheduler(
104108
lr_scheduler_type="linear",
105109
lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
@@ -125,6 +129,7 @@ def createFinetuneRequest(
125129
wandb_project_name=wandb_project_name,
126130
wandb_name=wandb_name,
127131
train_on_inputs=train_on_inputs,
132+
from_step=from_step,
128133
)
129134

130135
return finetune_request
@@ -162,6 +167,7 @@ def create(
162167
verbose: bool = False,
163168
model_limits: FinetuneTrainingLimits | None = None,
164169
train_on_inputs: bool | Literal["auto"] = "auto",
170+
from_step: int | Literal["final"] = "final",
165171
) -> FinetuneResponse:
166172
"""
167173
Method to initiate a fine-tuning job
@@ -207,6 +213,7 @@ def create(
207213
For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
208214
(Instruction format), inputs will be masked.
209215
Defaults to "auto".
216+
from_step (int or "final"): Checkpoint step from which to start the fine-tuning job. Defaults to "final".
210217
211218
Returns:
212219
FinetuneResponse: Object containing information about fine-tuning job.
@@ -244,6 +251,7 @@ def create(
244251
wandb_project_name=wandb_project_name,
245252
wandb_name=wandb_name,
246253
train_on_inputs=train_on_inputs,
254+
from_step=from_step,
247255
)
248256

249257
if verbose:

src/together/types/finetune.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,8 @@ class FinetuneRequest(BaseModel):
178178
training_type: FullTrainingType | LoRATrainingType | None = None
179179
# train on inputs
180180
train_on_inputs: StrictBool | Literal["auto"] = "auto"
181+
# from step
182+
from_step: int | None = -1
181183

182184

183185
class FinetuneResponse(BaseModel):
@@ -256,6 +258,7 @@ class FinetuneResponse(BaseModel):
256258
training_file_num_lines: int | None = Field(None, alias="TrainingFileNumLines")
257259
training_file_size: int | None = Field(None, alias="TrainingFileSize")
258260
train_on_inputs: StrictBool | Literal["auto"] | None = "auto"
261+
from_step: int | None = -1
259262

260263
@field_validator("training_type")
261264
@classmethod

0 commit comments

Comments
 (0)