
Commit ed87efa

feat(data): add --train-only-last-turn option for thinking models (#419)
Add a new CLI argument to train_eagle3.py that enables training only on the last assistant turn in each conversation. This is useful for 'thinking' models (like DeepSeek-R1) or distilled datasets where the conversation history lacks the thought process present in the current generation.

Changes:
- Add train_only_last_turn parameter to GeneralParser, HarmonyParser, ThinkingParser
- Add train_only_last_turn parameter to preprocess_conversations and build_eagle3_dataset
- Add --train-only-last-turn CLI argument to train_eagle3.py

Co-authored-by: yiliu <123>
1 parent c183a3a commit ed87efa

File tree

3 files changed: 41 additions, 5 deletions

- scripts/train_eagle3.py
- specforge/data/parse.py
- specforge/data/preprocessing.py

scripts/train_eagle3.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -109,6 +109,12 @@ def parse_args() -> Tuple[ArgumentParser, Namespace]:
         action="store_true",
         help="Whether the input data is preformatted text with the chat template already applied to the conversation messages.",
     )
+    dataset_group.add_argument(
+        "--train-only-last-turn",
+        action="store_true",
+        help="If set, only the last assistant turn in each conversation contributes to the loss. "
+        "Useful for thinking models where conversation history may lack thought processes.",
+    )
     dataset_group.add_argument("--build-dataset-num-proc", type=int, default=8)
     dataset_group.add_argument(
         "--dataloader-num-workers",
@@ -422,6 +428,7 @@ def build_dataloaders(
             is_preformatted=args.is_preformatted,
             processor=processor,
             num_proc=args.build_dataset_num_proc,
+            train_only_last_turn=args.train_only_last_turn,
         )
         vocab_mapping_path = generate_vocab_mapping_file(
             dataset=train_eagle3_dataset,
@@ -462,6 +469,7 @@ def build_dataloaders(
             processor=processor,
             num_proc=args.build_dataset_num_proc,
             is_preformatted=args.is_preformatted,
+            train_only_last_turn=args.train_only_last_turn,
         )
     elif args.eval_hidden_states_path is not None:
        eval_eagle3_dataset = build_offline_eagle3_dataset(
```
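
As a quick illustration of the new flag's default behavior, a standalone argparse sketch (not the project's actual parser, which defines many more options): because of `action="store_true"`, the option defaults to False, so existing training runs are unaffected unless the flag is passed.

```python
from argparse import ArgumentParser

# Standalone sketch mirroring the new flag added in this commit.
parser = ArgumentParser()
parser.add_argument("--train-only-last-turn", action="store_true")

print(parser.parse_args([]).train_only_last_turn)                           # False
print(parser.parse_args(["--train-only-last-turn"]).train_only_last_turn)   # True
```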

specforge/data/parse.py

Lines changed: 21 additions & 4 deletions
```diff
@@ -54,6 +54,7 @@ def parse(
         conversation: "Conversation",
         max_length: int,
         preformatted: bool = False,
+        train_only_last_turn: bool = False,
         **kwargs,
     ) -> Dict[str, List[torch.Tensor]]:
         if not preformatted:
@@ -138,7 +139,12 @@ def parse(
         )
         input_ids = encoding.input_ids[0]
         loss_mask = torch.zeros(len(input_ids), dtype=torch.long)
-        for match in re.finditer(assistant_pattern, conversation, re.DOTALL):
+
+        matches = list(re.finditer(assistant_pattern, conversation, re.DOTALL))
+        if train_only_last_turn and matches:
+            matches = [matches[-1]]  # Only keep the last match
+
+        for match in matches:
             content_start_char = match.start(1)
             content_end_char = match.end(1)
 
@@ -200,7 +206,11 @@ def build_single_turn_prompt(
         return prompt_text
 
     def parse(
-        self, conversation: "Conversation", max_length: int, preformatted: bool = False
+        self,
+        conversation: "Conversation",
+        max_length: int,
+        preformatted: bool = False,
+        train_only_last_turn: bool = False,
     ) -> List[torch.Tensor]:
         # conversation = process_harmony_conversations(conversation)
         if not preformatted:
@@ -243,7 +253,11 @@ def parse(
         )
 
         # Find all matching segments
-        for match in pattern.finditer(conversation):
+        matches = list(pattern.finditer(conversation))
+        if train_only_last_turn and matches:
+            matches = [matches[-1]]  # Only keep the last match
+
+        for match in matches:
             # match.start(0) is the start index of the full match (including `<|start|>assistant`)
             # match.start(1) is the start index of the first capture group (excluding `<|start|>assistant`)
             # match.end(1) is the end index of the content
@@ -288,10 +302,13 @@ def parse(
         conversation: "Conversation",
         max_length: int,
         preformatted: bool = False,
+        train_only_last_turn: bool = False,
         **kwargs,
     ) -> Dict[str, List[torch.Tensor]]:
         if self.chat_template.enable_thinking:
             kwargs["enable_thinking"] = True
         else:
             pass
-        return super().parse(conversation, max_length, preformatted, **kwargs)
+        return super().parse(
+            conversation, max_length, preformatted, train_only_last_turn, **kwargs
+        )
```
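
To make the selection logic above concrete, a toy sketch: the real assistant_pattern is derived from the chat template inside the parser, so the markup and pattern here are illustrative assumptions, not the project's actual template.

```python
import re

# Toy pattern standing in for the template-derived assistant_pattern.
assistant_pattern = r"<assistant>(.*?)</assistant>"
conversation = (
    "<user>hi</user><assistant>history reply, no thoughts</assistant>"
    "<user>solve it</user><assistant>final reply with thinking</assistant>"
)

train_only_last_turn = True
matches = list(re.finditer(assistant_pattern, conversation, re.DOTALL))
if train_only_last_turn and matches:
    matches = [matches[-1]]  # keep only the last assistant turn

for m in matches:
    print(m.group(1))  # -> "final reply with thinking"
```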

specforge/data/preprocessing.py

Lines changed: 12 additions & 1 deletion
```diff
@@ -117,6 +117,7 @@ def preprocess_conversations(
     chat_template: ChatTemplate,
     max_length: int = 2048,
     is_preformatted: bool = False,
+    train_only_last_turn: bool = False,
     **kwargs,
 ) -> Dict[str, List[torch.Tensor]]:
     """
@@ -129,6 +130,7 @@ def preprocess_conversations(
         chat_template: The chat template to use for formatting/identifying spans.
         max_length: The maximum length of the tokenized input.
         is_preformatted: Whether the input is already formatted text strings.
+        train_only_last_turn: If True, only the last assistant turn contributes to the loss.
 
     Returns:
         A dictionary containing:
@@ -158,7 +160,11 @@ def preprocess_conversations(
             # if the source is None, skip it
             continue
         input_ids, loss_mask = parser.parse(
-            source, max_length, preformatted=is_preformatted, **kwargs_item
+            source,
+            max_length,
+            preformatted=is_preformatted,
+            train_only_last_turn=train_only_last_turn,
+            **kwargs_item,
         )
         results["input_ids"].append(input_ids[None, :])
         results["loss_mask"].append(loss_mask[None, :])
@@ -294,6 +300,7 @@ def build_eagle3_dataset(
     is_vlm: Optional[bool] = False,
     processor: Optional[ImageProcessingMixin] = None,
     is_preformatted: Optional[bool] = False,
+    train_only_last_turn: Optional[bool] = False,
 ) -> HFDataset:
     """
     build eagle3 dataset
@@ -319,6 +326,8 @@ def build_eagle3_dataset(
             the assistant spans for loss mask generation.
             If True, expects "text" column with ready-to-train text.
             If False, expects "conversations" column with ShareGPT format.
+        train_only_last_turn: If True, only the last assistant turn contributes to the loss.
+            Useful for thinking models where history may not contain thoughts.
 
     Returns:
         The processed HF dataset.
@@ -360,6 +369,7 @@ def preprocess_function(examples):
                 template,
                 max_length,
                 is_preformatted=True,
+                train_only_last_turn=train_only_last_turn,
             )
         else:
             # Handle ShareGPT conversations
@@ -376,6 +386,7 @@ def preprocess_function(examples):
                 template,
                 max_length,
                 is_preformatted=False,
+                train_only_last_turn=train_only_last_turn,
                 **examples,
             )
 
```
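
The net effect is on the loss mask that preprocess_conversations returns: with the flag set, only the token span of the final assistant turn stays unmasked. A self-contained sketch with made-up token spans:

```python
import torch

# Two assistant turns at hypothetical token ranges [5, 12) and [20, 40).
input_len = 40
assistant_spans = [(5, 12), (20, 40)]

def make_mask(spans, train_only_last_turn):
    if train_only_last_turn and spans:
        spans = [spans[-1]]  # same selection rule as the parsers above
    mask = torch.zeros(input_len, dtype=torch.long)
    for start, end in spans:
        mask[start:end] = 1
    return mask

print(make_mask(assistant_spans, False).sum())  # tensor(27): both turns trained
print(make_mask(assistant_spans, True).sum())   # tensor(20): only the last turn
```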
