Skip to content

Commit 07073bd

Browse files
authored
feat: fine tune qwen3 for knowledge specialization (vllm-project#447)
* feat: add scripts to fine tune qwen3 for knowledge specialization Signed-off-by: Huamin Chen <[email protected]> * fix Signed-off-by: Huamin Chen <[email protected]> * fix Signed-off-by: Huamin Chen <[email protected]> * fix chat template issue Signed-off-by: Huamin Chen <[email protected]> * add open math reasoning dataset for training Signed-off-by: Huamin Chen <[email protected]> * add test mode to compare baseline and trained models Signed-off-by: Huamin Chen <[email protected]> * use SFT trainer Signed-off-by: Huamin Chen <[email protected]> * use SFT trainer Signed-off-by: Huamin Chen <[email protected]> * support bigger models than 0.6B Signed-off-by: Huamin Chen <[email protected]> * optimize gpu memory Signed-off-by: Huamin Chen <[email protected]> * fix train to test transition issue Signed-off-by: Huamin Chen <[email protected]> --------- Signed-off-by: Huamin Chen <[email protected]>
1 parent 47e9bd7 commit 07073bd

File tree

9 files changed

+4474
-46
lines changed

9 files changed

+4474
-46
lines changed
Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
"""
2+
OpenMathReasoning Dataset Implementation
3+
4+
NVIDIA's OpenMathReasoning dataset - high-quality math problems with detailed
5+
chain-of-thought solutions. Contains 5.68M rows across multiple splits.
6+
7+
This implementation uses the 'cot' split which has 3.2M examples with detailed reasoning.
8+
"""
9+
10+
import os
11+
import random
12+
import re
13+
import sys
14+
from typing import List, Optional, Tuple
15+
16+
import numpy as np
17+
import pandas as pd
18+
from datasets import load_dataset
19+
20+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
21+
22+
from ..dataset_interface import DatasetInfo, DatasetInterface, Question
23+
24+
25+
class OpenMathReasoningDataset(DatasetInterface):
    """OpenMathReasoning dataset implementation for advanced mathematical reasoning.

    Rows are streamed from the Hugging Face ``nvidia/OpenMathReasoning``
    dataset (``cot`` split) and kept in an in-memory pandas DataFrame so that
    repeated calls on the same instance do not stream the data again.
    """

    # Row budget used by load_dataset() when no per-category sample size is
    # given; keeps us from streaming the whole split.
    _UNSAMPLED_LOAD_LIMIT = 50000
    # Upper bound on the row budget when a per-category sample size is given.
    _SAMPLED_LOAD_CAP = 100000
    # Fixed multiplier applied to samples_per_category when estimating the row
    # budget (presumably an estimate of the category count -- TODO confirm).
    _SAMPLE_LOAD_MULTIPLIER = 3

    def __init__(self):
        """Initialize OpenMathReasoning dataset with empty caches."""
        self._dataset_cache = None
        # The max_examples value the cached frame was streamed with; used to
        # decide whether the cache is large enough for a later request.
        self._dataset_cache_limit = 0
        self._categories_cache = None

    @property
    def dataset_name(self) -> str:
        """Name of this dataset."""
        return "OpenMathReasoning"

    @property
    def supports_cot(self) -> bool:
        """Whether the dataset provides chain-of-thought content (it does)."""
        return True  # Has detailed chain-of-thought solutions

    def _load_raw_dataset(self, max_examples: int = 10000) -> pd.DataFrame:
        """
        Load raw OpenMathReasoning dataset from Hugging Face.

        The result is cached on the instance. The cache is reused only when it
        was streamed with a budget of at least ``max_examples`` rows (or when
        the stream ran out before the budget was reached); otherwise the data
        is streamed again with the larger budget, so a small earlier request
        can no longer starve a later, larger one. A reused cache may hold more
        rows than ``max_examples``.

        Args:
            max_examples: Maximum number of examples to load (default: 10000)
                This prevents loading all 3.2M rows unnecessarily.

        Returns:
            DataFrame with one row per streamed example.
        """
        cached = self._dataset_cache
        if cached is not None:
            covers_request = self._dataset_cache_limit >= max_examples
            # Fewer rows than the budget means the stream was exhausted, so
            # streaming again could not yield anything new.
            stream_exhausted = len(cached) < self._dataset_cache_limit
            if covers_request or stream_exhausted:
                return cached

        # Local import: only this method needs it.
        from itertools import islice

        # Use STREAMING mode to avoid downloading the full 3.2M dataset
        # This way we only fetch the examples we actually need
        print(f"Loading OpenMathReasoning: {max_examples} examples (out of 3.2M total)")
        print(" Using streaming mode to avoid downloading full dataset...")

        dataset_stream = load_dataset(
            "nvidia/OpenMathReasoning", split="cot", streaming=True
        )

        # islice stops after exactly max_examples items, so no extra example
        # is pulled from the stream; max() guards against a negative budget.
        examples = []
        bounded_stream = islice(dataset_stream, max(max_examples, 0))
        for loaded, example in enumerate(bounded_stream, start=1):
            examples.append(example)
            if loaded % 1000 == 0:
                print(f" Loaded {loaded}/{max_examples} examples...", end="\r")

        print(f"\n ✓ Loaded {len(examples)} examples (streamed, not cached)")
        self._dataset_cache = pd.DataFrame(examples)
        self._dataset_cache_limit = max_examples
        # Categories were derived from the previous frame; recompute lazily.
        self._categories_cache = None
        return self._dataset_cache

    def _get_categories(self, max_examples: int = 10000) -> List[str]:
        """Get available categories in OpenMathReasoning dataset.

        Categories are the distinct ``problem_type`` values of the loaded
        rows. Returns an empty list when no rows (and therefore no
        ``problem_type`` column) were loaded.
        """
        # Load first: this may refresh the dataset cache and, with it,
        # invalidate a stale categories cache.
        df = self._load_raw_dataset(max_examples=max_examples)
        if self._categories_cache is None:
            if "problem_type" in df.columns:
                self._categories_cache = df["problem_type"].unique().tolist()
            else:
                self._categories_cache = []
        return self._categories_cache

    def get_available_categories(self) -> List[str]:
        """Get list of all available categories in the dataset."""
        return self._get_categories()

    @classmethod
    def _rows_to_load(
        cls, samples_per_category: Optional[int], max_cot_length: Optional[int]
    ) -> int:
        """Estimate how many raw rows to stream for the requested sampling."""
        if not samples_per_category:
            # No limit specified: still cap to avoid loading all 3.2M rows.
            return cls._UNSAMPLED_LOAD_LIMIT
        # If filtering by length, load more samples to compensate.
        buffer_factor = 15 if max_cot_length else 3
        estimated_needed = (
            samples_per_category * cls._SAMPLE_LOAD_MULTIPLIER * buffer_factor
        )
        return min(estimated_needed, cls._SAMPLED_LOAD_CAP)

    @staticmethod
    def _filter_by_cot_length(df: pd.DataFrame, max_cot_length: int) -> pd.DataFrame:
        """Keep rows whose ``generated_solution`` has at most ``max_cot_length`` chars."""
        print(f"\n 📏 Filtering samples by CoT length (max: {max_cot_length} chars)")
        original_count = len(df)
        if original_count == 0:
            # Nothing to filter; also avoids dividing by zero below.
            print(" ✓ Kept 0/0 samples (0.0%) after length filtering")
            return df

        # assign() returns a new frame, so the cached DataFrame is not mutated.
        df = df.assign(cot_length=df["generated_solution"].str.len())
        df = df[df["cot_length"] <= max_cot_length]
        kept_pct = len(df) / original_count * 100
        print(
            f" ✓ Kept {len(df)}/{original_count} samples ({kept_pct:.1f}%) after length filtering"
        )

        # Print distribution stats
        if len(df) > 0:
            print(" 📊 CoT Length Stats (filtered):")
            print(f" Min: {df['cot_length'].min()} chars")
            print(f" Max: {df['cot_length'].max()} chars")
            print(f" Mean: {df['cot_length'].mean():.0f} chars")
            print(f" Median: {df['cot_length'].median():.0f} chars")
        return df

    @staticmethod
    def _sample_per_category(
        df: pd.DataFrame,
        selected_categories: List[str],
        samples_per_category: int,
        seed: int,
    ) -> pd.DataFrame:
        """Draw up to ``samples_per_category`` rows from each selected category."""
        # The global RNGs are seeded as before; the sampling itself is pinned
        # by random_state below.
        np.random.seed(seed)
        random.seed(seed)

        sampled_dfs = []
        for category in selected_categories:
            category_df = df[df["problem_type"] == category]
            sample_size = min(samples_per_category, len(category_df))
            if sample_size > 0:
                sampled_dfs.append(
                    category_df.sample(n=sample_size, random_state=seed)
                )

        if not sampled_dfs:
            return pd.DataFrame()
        return pd.concat(sampled_dfs, ignore_index=True)

    @staticmethod
    def _row_to_question(index: int, row) -> Question:
        """Build a Question from one DataFrame row; ``index`` forms the question id."""
        expected_answer = row.get("expected_answer", "")
        return Question(
            question_id=f"openmr_{index}",
            question=row["problem"],
            options=[],  # Free-form, no multiple choice
            correct_answer=str(expected_answer).strip(),
            category=row.get("problem_type", "default"),
            cot_content=row["generated_solution"],  # Full solution with detailed reasoning
            metadata={
                "difficulty": "Advanced",
                "type": "math_problem",
                "problem_source": row.get("problem_source", "unknown"),
                "generation_model": row.get("generation_model", "unknown"),
                "pass_rate_72b_tir": row.get("pass_rate_72b_tir", "unknown"),
            },
        )

    def load_dataset(
        self,
        categories: Optional[List[str]] = None,
        samples_per_category: Optional[int] = None,
        seed: int = 42,
        max_cot_length: Optional[int] = None,
    ) -> Tuple[List[Question], DatasetInfo]:
        """
        Load OpenMathReasoning dataset with optional filtering and sampling.

        Args:
            categories: Filter by problem types
            samples_per_category: Number of samples per category
            seed: Random seed for sampling
            max_cot_length: Maximum character length for CoT solutions (for memory efficiency)

        Returns:
            The selected questions and a DatasetInfo describing them.

        Raises:
            ValueError: If any requested category is not among the categories
                found in the loaded rows.
        """
        max_to_load = self._rows_to_load(samples_per_category, max_cot_length)
        df = self._load_raw_dataset(max_examples=max_to_load)
        available_categories = self._get_categories(max_examples=max_to_load)

        # Filter by CoT length if specified (for memory-efficient training)
        if max_cot_length:
            df = self._filter_by_cot_length(df, max_cot_length)

        # Filter categories if specified
        if categories:
            missing_categories = set(categories) - set(available_categories)
            if missing_categories:
                raise ValueError(
                    f"Categories not found: {missing_categories}. "
                    f"Available: {available_categories}"
                )
            df = df[df["problem_type"].isin(categories)]
            selected_categories = categories
        else:
            selected_categories = available_categories

        # Sample questions if specified (per category)
        if samples_per_category:
            df = self._sample_per_category(
                df, selected_categories, samples_per_category, seed
            )

        # Convert to Question objects
        questions = [
            self._row_to_question(index, row)
            for index, (_, row) in enumerate(df.iterrows())
        ]

        dataset_info = DatasetInfo(
            name="OpenMathReasoning",
            description="NVIDIA's high-quality math problems with detailed chain-of-thought reasoning",
            categories=selected_categories,
            total_questions=len(questions),
            format_type="free_form",
            difficulty_level="Advanced",
        )

        return questions, dataset_info

    def format_prompt(self, question: Question, prompt_style: str = "plain") -> str:
        """Format prompt for OpenMathReasoning questions.

        Args:
            question: The question whose ``question`` text is embedded.
            prompt_style: ``"plain"`` or ``"explicit_cot"``.

        Raises:
            ValueError: If ``prompt_style`` is not one of the supported styles.
        """
        # Prompts are assembled from explicit "\n"-terminated pieces so their
        # exact whitespace does not depend on source indentation.
        if prompt_style == "plain":
            return (
                "Solve this math problem:\n"
                "\n"
                f"{question.question}\n"
                "\n"
                "Please provide your final answer in the following structured format:\n"
                "The answer is [your_final_answer]\n"
                "\n"
                "For example: The answer is 42"
            )

        elif prompt_style == "explicit_cot":
            return (
                "Solve this math problem step by step, showing all your reasoning:\n"
                "\n"
                f"Problem: {question.question}\n"
                "\n"
                "Please work through this step-by-step:\n"
                "1. Read the problem carefully and understand what is being asked\n"
                "2. Identify the given information and what needs to be found\n"
                "3. Choose appropriate methods and formulas\n"
                "4. Work through the solution step by step with clear explanations\n"
                "5. Verify your answer makes sense\n"
                "6. State your final answer clearly\n"
                "\n"
                "Please provide your final answer in the following structured format:\n"
                "The answer is [your_final_answer]\n"
                "\n"
                "For example: The answer is 42"
            )

        else:
            raise ValueError(f"Unknown prompt style: {prompt_style}")

examples/mcp-classifier-server/server_generative.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -378,19 +378,37 @@ def _prepare_category_tokens(self):
378378
)
379379

380380
def _format_instruction(self, question: str) -> str:
381-
"""Format a question using the instruction template."""
381+
"""
382+
Format a question using the instruction template with chat format.
383+
384+
Uses Qwen3's ChatML format to match the training format.
385+
Returns the formatted prompt string ready for tokenization.
386+
"""
387+
# Build the instruction content
382388
if self.instruction_template:
383-
return self.instruction_template.format(question=question)
389+
instruction_content = self.instruction_template.format(question=question)
384390
else:
385391
# Fallback template
386-
return f"""You are an expert academic classifier. Classify the following question into exactly ONE category. Respond with ONLY the category name.
392+
instruction_content = f"""You are an expert academic classifier. Classify the following question into exactly ONE category. Respond with ONLY the category name.
387393
388394
Categories: {', '.join(self.category_names)}
389395
390396
Now classify this question:
391397
Q: {question}
392398
A:"""
393399

400+
# Format as chat messages (user message only, for classification)
401+
messages = [{"role": "user", "content": instruction_content}]
402+
403+
# Apply chat template with generation prompt
404+
# This adds <|im_start|>assistant\n at the end to prompt the model to respond
405+
# Disable thinking mode for direct classification output (Qwen3 is a thinking model)
406+
prompt = self.tokenizer.apply_chat_template(
407+
messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
408+
)
409+
410+
return prompt
411+
394412
def classify(self, text: str, with_probabilities: bool = False) -> dict[str, Any]:
395413
"""
396414
Classify text using the generative model.

src/training/training_lora/README.md

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,22 @@
22

33
## 📖 Overview
44

5-
This directory contains **LoRA (Low-Rank Adaptation)** training scripts for fine-tuning transformer models on three classification tasks:
5+
This directory contains **LoRA (Low-Rank Adaptation)** training scripts for fine-tuning transformer models on multiple tasks:
6+
7+
### Classification Tasks
68

79
- **Intent Classification** (`classifier_model_fine_tuning_lora/`)
810
- **PII Detection** (`pii_model_fine_tuning_lora/`)
911
- **Security Detection** (`prompt_guard_fine_tuning_lora/`)
1012

13+
### Problem Solving Tasks
14+
15+
- **MMLU-Pro Specialized Solvers** (`mmlu_pro_solver_lora/`) ⭐ NEW!
16+
- Fine-tune Qwen3-0.6B models to solve graduate-level academic problems
17+
- 6 specialized experts (math, science, humanities, law, etc.)
18+
- Chain-of-Thought reasoning with baseline comparison
19+
- Expected: 40-60% accuracy (vs 10% random baseline)
20+
1121
## 🧠 What is LoRA?
1222

1323
**LoRA (Low-Rank Adaptation)** is a parameter-efficient fine-tuning technique that:
@@ -60,22 +70,30 @@ Our LoRA implementation supports three transformer architectures:
6070
src/training/training_lora/
6171
├── README.md # This file
6272
├── common_lora_utils.py # Shared utilities
73+
6374
├── classifier_model_fine_tuning_lora/ # Intent Classification
6475
│ ├── ft_linear_lora.py # Training script
76+
│ ├── ft_qwen3_generative_lora.py # Category classifier
6577
│ ├── ft_linear_lora_verifier.go # Go verification
6678
│ ├── train_cpu_optimized.sh # Training automation
6779
│ └── go.mod
80+
6881
├── pii_model_fine_tuning_lora/ # PII Detection
6982
│ ├── pii_bert_finetuning_lora.py # Training script
7083
│ ├── pii_bert_finetuning_lora_verifier.go # Go verification
7184
│ ├── train_cpu_optimized.sh # Training automation
7285
│ ├── presidio_synth_dataset_v2.json # Training data
7386
│ └── go.mod
74-
└── prompt_guard_fine_tuning_lora/ # Security Detection
75-
├── jailbreak_bert_finetuning_lora.py # Training script
76-
├── jailbreak_bert_finetuning_lora_verifier.go # Go verification
77-
├── train_cpu_optimized.sh # Training automation
78-
└── go.mod
87+
88+
├── prompt_guard_fine_tuning_lora/ # Security Detection
89+
│ ├── jailbreak_bert_finetuning_lora.py # Training script
90+
│ ├── jailbreak_bert_finetuning_lora_verifier.go # Go verification
91+
│ ├── train_cpu_optimized.sh # Training automation
92+
│ └── go.mod
93+
94+
└── mmlu_pro_solver_lora/ # ⭐ MMLU-Pro Problem Solvers
95+
├── ft_qwen3_mmlu_solver_lora[_no_leakage].py # Main training script; the _no_leakage variant avoids MMLU-Pro data leakage
96+
└── train_all_specialists[_no_leakage].sh # Batch training; the _no_leakage variant avoids MMLU-Pro data leakage
7997
```
8098

8199
## 🚀 Quick Start

0 commit comments

Comments
 (0)