Skip to content

Commit 1ce13d6

Browse files
committed
tests
1 parent 2cbd4c0 commit 1ce13d6

File tree

2 files changed

+33
-4
lines changed

2 files changed

+33
-4
lines changed

src/together/resources/finetune.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -203,10 +203,10 @@ def create_finetune_request(
203203
lr_scheduler_args=LinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
204204
)
205205

206-
training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT(
207-
train_on_inputs=train_on_inputs
208-
)
209-
if training_method == "dpo":
206+
training_method_cls: TrainingMethodSFT | TrainingMethodDPO
207+
if training_method == "sft":
208+
training_method_cls = TrainingMethodSFT(train_on_inputs=train_on_inputs)
209+
elif training_method == "dpo":
210210
training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta)
211211

212212
finetune_request = FinetuneRequest(

tests/unit/test_finetune_resources.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,3 +281,32 @@ def test_bad_training_method():
281281
training_file=_TRAINING_FILE,
282282
training_method="NON_SFT",
283283
)
284+
285+
286+
@pytest.mark.parametrize("train_on_inputs", [True, False, "auto", None])
287+
def test_train_on_inputs_for_sft(train_on_inputs):
288+
request = create_finetune_request(
289+
model_limits=_MODEL_LIMITS,
290+
model=_MODEL_NAME,
291+
training_file=_TRAINING_FILE,
292+
training_method="sft",
293+
train_on_inputs=train_on_inputs,
294+
)
295+
assert request.training_method.method == "sft"
296+
if isinstance(train_on_inputs, bool):
297+
assert request.training_method.train_on_inputs is train_on_inputs
298+
else:
299+
assert request.training_method.train_on_inputs == "auto"
300+
301+
302+
def test_train_on_inputs_not_supported_for_dpo():
303+
with pytest.raises(
304+
ValueError, match="train_on_inputs is only supported for SFT training"
305+
):
306+
_ = create_finetune_request(
307+
model_limits=_MODEL_LIMITS,
308+
model=_MODEL_NAME,
309+
training_file=_TRAINING_FILE,
310+
training_method="dpo",
311+
train_on_inputs=True,
312+
)

0 commit comments

Comments
 (0)