
Commit 01e9b4b

fix chat template issue
Signed-off-by: Huamin Chen <[email protected]>
1 parent 8e2e77b commit 01e9b4b

File tree

5 files changed (+124, -68 lines)


examples/mcp-classifier-server/server_generative.py

Lines changed: 2 additions & 1 deletion
@@ -402,8 +402,9 @@ def _format_instruction(self, question: str) -> str:

         # Apply chat template with generation prompt
         # This adds <|im_start|>assistant\n at the end to prompt the model to respond
+        # Disable thinking mode for direct classification output (Qwen3 is a thinking model)
         prompt = self.tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
+            messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
         )

         return prompt
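
A note on what the new flag does: with a Qwen3 tokenizer, enable_thinking=False changes how the generation prompt is rendered. A minimal sketch of the effect (assuming the Qwen/Qwen3-0.6B tokenizer from the Hugging Face Hub; the message content is illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
messages = [{"role": "user", "content": "Classify: what is the capital of France?"}]

# With enable_thinking=True (Qwen3's default), the prompt ends with
# "<|im_start|>assistant\n" and the model is free to open a <think> block.
# With enable_thinking=False, the template pre-fills an empty think block,
# so generation starts directly on the final answer.
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
)
print(prompt)  # ends with: <|im_start|>assistant\n<think>\n\n</think>\n\n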

src/training/training_lora/classifier_model_fine_tuning_lora/ft_qwen3_generative_lora.py

Lines changed: 20 additions & 4 deletions
@@ -296,8 +296,9 @@ def create_generative_dataset(

         # Apply chat template to add special tokens
         # add_generation_prompt=False because we already have the assistant response
+        # Disable thinking mode to train model for direct classification
         formatted_text = tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=False
+            messages, tokenize=False, add_generation_prompt=False, enable_thinking=False
         )
         formatted_examples.append(formatted_text)
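On the training side, add_generation_prompt=False renders a complete user/assistant exchange as one supervised string. A small sketch of that call (Qwen3 tokenizer assumed; the toy messages are illustrative, not from this dataset):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
messages = [
    {"role": "user", "content": "Classify this query: what is 2 + 2?"},
    {"role": "assistant", "content": "math"},
]

# Returns the fully tagged ChatML string (both turns wrapped in
# <|im_start|>...<|im_end|>), ready to tokenize as one training example.
formatted_text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=False, enable_thinking=False
)
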
@@ -539,7 +540,7 @@ def main(
     model.eval()

     # Use validation data for testing
-    num_test_samples = min(20, len(val_texts))  # Test on 20 samples
+    num_test_samples = min(200, len(val_texts))  # Test on 200 samples
     correct = 0
     total = 0
@@ -554,8 +555,9 @@ def main(

         # Apply chat template with generation prompt
         # This adds <|im_start|>assistant\n to prompt the model to respond
+        # Disable thinking mode for direct classification output
         prompt = tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
+            messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
         )

         inputs = tokenizer(
@@ -579,6 +581,11 @@ def main(
         generated_ids = outputs[0][inputs["input_ids"].shape[1] :]
         generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

+        # Remove thinking tokens that Qwen3 generates
+        generated_text = (
+            generated_text.replace("<think>", "").replace("</think>", "").strip()
+        )
+
         # With chat template, model generates just the category directly
         # Clean up answer (take first line, remove punctuation at end)
         answer_text = generated_text.split("\n")[0].strip().strip(".,!?;:").lower()
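
The tag stripping above is defensive: with enable_thinking=False the prompt already contains an empty think block, so a well-behaved model should not emit the tags at all, but a partially trained checkpoint still can. A standalone sketch of the same cleanup (the sample string is made up):

raw = "<think>\n\n</think>\n\nbusiness"
cleaned = raw.replace("<think>", "").replace("</think>", "").strip()
print(cleaned)  # -> "business"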
@@ -679,8 +686,12 @@ def demo_inference(model_path: str, model_name: str = "Qwen/Qwen3-0.6B"):
     messages = format_instruction(example, category=None)

     # Apply chat template with generation prompt
+    # Disable thinking mode for direct classification output
     prompt = tokenizer.apply_chat_template(
-        messages, tokenize=False, add_generation_prompt=True
+        messages,
+        tokenize=False,
+        add_generation_prompt=True,
+        enable_thinking=False,
     )

     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
@@ -702,6 +713,11 @@ def demo_inference(model_path: str, model_name: str = "Qwen/Qwen3-0.6B"):
     generated_ids = outputs[0][inputs["input_ids"].shape[1] :]
     generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

+    # Remove thinking tokens that Qwen3 generates
+    generated_text = (
+        generated_text.replace("<think>", "").replace("</think>", "").strip()
+    )
+
     # With chat template, model generates just the category directly
     # Clean up and match against known categories
     answer_text = generated_text.split("\n")[0].strip().strip(".,!?;:").lower()

src/training/training_lora/mmlu_pro_solver_lora/ft_qwen3_mmlu_solver_lora.py

Lines changed: 9 additions & 3 deletions
@@ -448,8 +448,9 @@ def create_solver_dataset(

         # Apply chat template to add special tokens
         # add_generation_prompt=False because we already have the assistant response
+        # enable_thinking=False to train model for direct problem-solving without reasoning tokens
         formatted_text = tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=False
+            messages, tokenize=False, add_generation_prompt=False, enable_thinking=False
         )
         formatted_examples.append(formatted_text)

@@ -586,8 +587,9 @@ def evaluate_model_on_samples(

         # Apply chat template with generation prompt
         # This adds <|im_start|>assistant\n at the end to prompt the model to respond
+        # enable_thinking=False for direct answer generation without reasoning tokens
         prompt = tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
+            messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
         )

         inputs = tokenizer(
@@ -1147,8 +1149,12 @@ def demo_inference(
     )

     # Apply chat template with generation prompt
+    # enable_thinking=False for direct answer generation without reasoning tokens
     prompt = tokenizer.apply_chat_template(
-        messages, tokenize=False, add_generation_prompt=True
+        messages,
+        tokenize=False,
+        add_generation_prompt=True,
+        enable_thinking=False,
     )

     inputs = tokenizer(

src/training/training_lora/mmlu_pro_solver_lora/ft_qwen3_mmlu_solver_lora_no_leakage.py

Lines changed: 77 additions & 44 deletions
@@ -126,11 +126,11 @@
 logger = setup_logging()

 # Training dataset mapping for each specialist model
-# NOTE: GSM8K and MATH are free-form datasets (no multiple-choice) - not compatible
+# NOTE: Supports both multiple-choice (ARC, SciQ, etc.) and free-form (GSM8K, MATH) datasets
 TRAINING_DATASETS = {
     "math-reasoner": {
-        "datasets": ["arc"],  # ARC has reasoning/STEM questions
-        "description": "Reasoning and STEM problems",
+        "datasets": ["gsm8k", "math", "arc"],  # Math word problems + reasoning
+        "description": "Mathematical reasoning and STEM problems",
         "target_mmlu_categories": ["math", "physics", "engineering"],
     },
     "science-expert": {
@@ -155,11 +155,12 @@
     },
     "generalist": {
         "datasets": [
+            "gsm8k",
             "arc",
             "commonsenseqa",
             "truthfulqa",
-        ],  # Removed GSM8K (no options)
-        "description": "Mixed domains",
+        ],  # Mixed multiple-choice and free-form
+        "description": "Mixed domains (catch-all specialist)",
         "target_mmlu_categories": ["health", "other"],
     },
 }
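
For reference, a hypothetical consumer of this mapping might resolve a specialist's training mix like so (illustrative only, not code from this repo):

# Which benchmark datasets feed each specialist, per the mapping above.
for name, cfg in TRAINING_DATASETS.items():
    print(f"{name}: {cfg['datasets']} -> MMLU {cfg['target_mmlu_categories']}")
# e.g. math-reasoner: ['gsm8k', 'math', 'arc'] -> MMLU ['math', 'physics', 'engineering']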
@@ -294,34 +295,40 @@ def convert_bench_question_to_training_format(question_obj, dataset_name: str) -
         Dict with question, options, answer (as text), category, cot_content
         Returns None if the sample is invalid
     """
-    # Validate that we have options
-    if not question_obj.options or len(question_obj.options) < 2:
-        logger.warning(
-            f"Skipping {dataset_name} question {question_obj.question_id}: "
-            f"insufficient options ({len(question_obj.options) if question_obj.options else 0})"
-        )
-        return None
+    # Check if this is a free-form question (no multiple choice options)
+    has_options = question_obj.options and len(question_obj.options) >= 2

-    # Convert answer to actual text
-    try:
-        answer_text = convert_answer_to_text(
-            question_obj.correct_answer, question_obj.options
-        )
-    except Exception as e:
-        logger.warning(
-            f"Skipping {dataset_name} question {question_obj.question_id}: "
-            f"failed to convert answer: {e}"
+    if has_options:
+        # Multiple-choice format: Convert answer to actual text
+        try:
+            answer_text = convert_answer_to_text(
+                question_obj.correct_answer, question_obj.options
+            )
+        except Exception as e:
+            logger.warning(
+                f"Skipping {dataset_name} question {question_obj.question_id}: "
+                f"failed to convert answer: {e}"
+            )
+            return None
+    else:
+        # Free-form format: Use answer as-is (GSM8K, MATH)
+        answer_text = str(question_obj.correct_answer)
+        logger.debug(
+            f"Free-form question from {dataset_name}: "
+            f"{question_obj.question_id} (no multiple-choice options)"
         )
-        return None

     return {
         "question": question_obj.question,
-        "options": question_obj.options,
+        "options": (
+            question_obj.options if has_options else []
+        ),  # Empty list for free-form
         "answer": answer_text,  # Now always actual text, not letter/index
         "category": question_obj.category,
         "cot_content": question_obj.cot_content,
         "source_dataset": dataset_name,
         "question_id": question_obj.question_id,
+        "is_free_form": not has_options,  # Flag to indicate answer format
     }
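
A stand-in illustration of the new free-form branch (SimpleNamespace only mimics the real benchmark question object; the field values are made up):

from types import SimpleNamespace

q = SimpleNamespace(
    question="A train travels 60 miles in 1.5 hours. What is its speed in mph?",
    options=None,        # no choices, so has_options is falsy -> free-form path
    correct_answer=40,
    category="math",
    cot_content=None,
    question_id="gsm8k-demo-1",
)
sample = convert_bench_question_to_training_format(q, "gsm8k")
# sample["answer"] == "40" (str of correct_answer), sample["options"] == [],
# and sample["is_free_form"] is True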

@@ -513,9 +520,11 @@ def format_instruction(
     Uses Qwen3's ChatML format with special tokens to separate user input from assistant output.
     This ensures the model only trains on generating the answer, not the question.

+    Supports both multiple-choice (with options) and free-form (without options) formats.
+
     Args:
         question: The question text
-        options: List of answer options
+        options: List of answer options (empty list for free-form questions)
         answer: The correct answer TEXT (actual option content) or None for inference
         cot_content: Optional chain-of-thought reasoning from source dataset
         use_cot: Whether to use Chain-of-Thought format
@@ -524,31 +533,53 @@ def format_instruction(
         List of message dicts with 'role' and 'content' keys
         Format: [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
     """
-    options_text = format_options(options)
-    instruction = COT_INSTRUCTION_TEMPLATE.format(
-        question=question, options=options_text
-    )
+    # Determine if this is multiple-choice or free-form
+    is_multiple_choice = options and len(options) >= 2
+
+    if is_multiple_choice:
+        # Multiple-choice format
+        options_text = format_options(options)
+        instruction = COT_INSTRUCTION_TEMPLATE.format(
+            question=question, options=options_text
+        )
+    else:
+        # Free-form format (GSM8K, MATH, etc.)
+        instruction = f"""You are an expert problem solver. Solve the following problem step by step, showing your reasoning clearly.
+
+Problem: {question}
+
+Instructions:
+1. Read the problem carefully and identify what is being asked
+2. Break down the problem into steps
+3. Solve step by step, showing your calculations and reasoning
+4. End with "The answer is [your_final_answer]"
+
+For example, if the answer is 42, write: "The answer is 42\""""

     # User message (the question/instruction)
     messages = [{"role": "user", "content": instruction}]

     if answer is not None:
-        # Find which option matches the answer text to get the letter
-        answer_letter = None
-        answer_lower = answer.lower().strip()
-        for i, option in enumerate(options):
-            if option.lower().strip() == answer_lower:
-                answer_letter = chr(
-                    65 + i
-                )  # Convert index to letter (0->A, 1->B, etc.)
-                break
+        if is_multiple_choice:
+            # Find which option matches the answer text to get the letter
+            answer_letter = None
+            answer_lower = answer.lower().strip()
+            for i, option in enumerate(options):
+                if option.lower().strip() == answer_lower:
+                    answer_letter = chr(
+                        65 + i
+                    )  # Convert index to letter (0->A, 1->B, etc.)
+                    break

-        # If no exact match, still format but without letter
-        if answer_letter is None:
-            formatted_answer = f"The answer is {answer}"
-            logger.warning(f"Could not find letter for answer: {answer}")
+            # If no exact match, still format but without letter
+            if answer_letter is None:
+                formatted_answer = f"The answer is {answer}"
+                logger.warning(f"Could not find letter for answer: {answer}")
+            else:
+                formatted_answer = f"The answer is {answer_letter}) {answer}"
         else:
-            formatted_answer = f"The answer is {answer_letter}) {answer}"
+            # Free-form answer (no letter)
+            formatted_answer = f"The answer is {answer}"

     # Assistant message (the answer)
     if use_cot and cot_content:
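
Given those branches, a free-form sample with a known answer produces messages roughly like this (a hedged sketch; the exact assistant wording also depends on the use_cot branch that follows):

messages = format_instruction(
    question="Natalia sold 48 clips in April and half as many in May. How many in total?",
    options=[],      # empty list -> is_multiple_choice is False -> free-form prompt
    answer="72",
    cot_content=None,
    use_cot=False,
)
# messages[0]["content"] starts with "You are an expert problem solver..."
# messages[1]["content"] ends with "The answer is 72" (no letter prefix)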
@@ -592,8 +623,9 @@ def create_solver_dataset(

         # Apply chat template to add special tokens
         # add_generation_prompt=False because we already have the assistant response
+        # enable_thinking=False to train model for direct problem-solving without reasoning tokens
         formatted_text = tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=False
+            messages, tokenize=False, add_generation_prompt=False, enable_thinking=False
         )
         formatted_examples.append(formatted_text)

@@ -761,8 +793,9 @@ def evaluate_model_on_mmlu_pro(

         # Apply chat template with generation prompt
         # This adds <|im_start|>assistant\n at the end to prompt the model to respond
+        # enable_thinking=False for direct answer generation without reasoning tokens
         prompt = tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
+            messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
        )

         inputs = tokenizer(
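
Taken together, the pattern this commit standardizes on looks like the following end-to-end sketch (the base model name is illustrative; any Qwen3 chat checkpoint should behave the same way):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")

messages = [{"role": "user", "content": "2 + 2 = ? Options: A) 3 B) 4"}]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=32)
generated_text = tokenizer.decode(
    outputs[0][inputs["input_ids"].shape[1] :], skip_special_tokens=True
)
# Defensive cleanup, mirroring the diff above
generated_text = generated_text.replace("<think>", "").replace("</think>", "").strip()
print(generated_text)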
