24 changes: 21 additions & 3 deletions examples/mcp-classifier-server/server_generative.py
@@ -378,19 +378,37 @@ def _prepare_category_tokens(self):
)

def _format_instruction(self, question: str) -> str:
"""Format a question using the instruction template."""
"""
Format a question using the instruction template with chat format.

Uses Qwen3's ChatML format to match the training format.
Returns the formatted prompt string ready for tokenization.
"""
# Build the instruction content
if self.instruction_template:
return self.instruction_template.format(question=question)
instruction_content = self.instruction_template.format(question=question)
else:
# Fallback template
return f"""You are an expert academic classifier. Classify the following question into exactly ONE category. Respond with ONLY the category name.
instruction_content = f"""You are an expert academic classifier. Classify the following question into exactly ONE category. Respond with ONLY the category name.

Categories: {', '.join(self.category_names)}

Now classify this question:
Q: {question}
A:"""

# Format as chat messages (user message only, for classification)
messages = [{"role": "user", "content": instruction_content}]

# Apply chat template with generation prompt
# This adds <|im_start|>assistant\n at the end to prompt the model to respond
# Disable thinking mode for direct classification output (Qwen3 is a thinking model)
prompt = self.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
)

return prompt

def classify(self, text: str, with_probabilities: bool = False) -> dict[str, Any]:
"""
Classify text using the generative model.
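For context on the formatting change above, here is a minimal standalone sketch of the prompt that `_format_instruction` now builds at inference time. The question text and the shortened instruction are illustrative, not the server's actual template; only the base model name and the `apply_chat_template` arguments mirror the code above.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")  # base model used elsewhere in this PR

question = "What is the derivative of x^2?"  # hypothetical example question
instruction = (
    "You are an expert academic classifier. Classify the following question into "
    "exactly ONE category. Respond with ONLY the category name.\n\n"
    f"Q: {question}\nA:"
)

messages = [{"role": "user", "content": instruction}]

# Render with the tokenizer's chat template: this wraps the instruction in ChatML turns
# and appends the assistant header so the model answers immediately.
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
)
print(prompt)  # roughly: <|im_start|>user ... <|im_end|> followed by <|im_start|>assistant
```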
30 changes: 24 additions & 6 deletions src/training/training_lora/README.md
@@ -2,12 +2,22 @@

## 📖 Overview

This directory contains **LoRA (Low-Rank Adaptation)** training scripts for fine-tuning transformer models on three classification tasks:
This directory contains **LoRA (Low-Rank Adaptation)** training scripts for fine-tuning transformer models on multiple tasks:

### Classification Tasks

- **Intent Classification** (`classifier_model_fine_tuning_lora/`)
- **PII Detection** (`pii_model_fine_tuning_lora/`)
- **Security Detection** (`prompt_guard_fine_tuning_lora/`)

### Problem Solving Tasks

- **MMLU-Pro Specialized Solvers** (`mmlu_pro_solver_lora/`) ⭐ NEW!
- Fine-tune Qwen3-0.6B models to solve graduate-level academic problems
- 6 specialized experts (math, science, humanities, law, etc.)
- Chain-of-Thought reasoning with baseline comparison
- Expected: 40-60% accuracy (vs 10% random baseline)
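
To make the task concrete, the sketch below builds an MMLU-Pro-style prompt with a chain-of-thought instruction. The wording, question, and options are illustrative only and not the exact template used by the solver scripts; MMLU-Pro items carry up to ten options (A-J).

```python
# Illustrative only: an MMLU-Pro style chain-of-thought prompt, not the repo's exact template.
question = "Which law relates the voltage across and the current through an ideal resistor?"
options = ["Ohm's law", "Faraday's law", "Ampere's law", "Kirchhoff's voltage law"]

option_lines = "\n".join(f"{chr(ord('A') + i)}. {opt}" for i, opt in enumerate(options))
prompt = (
    "Solve the following multiple-choice problem. Think step by step, "
    "then answer with a single letter.\n\n"
    f"Question: {question}\n{option_lines}\nAnswer:"
)
print(prompt)
```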

## 🧠 What is LoRA?

**LoRA (Low-Rank Adaptation)** is a parameter-efficient fine-tuning technique that:
@@ -60,22 +60,30 @@ Our LoRA implementation supports three transformer architectures:
src/training/training_lora/
├── README.md # This file
├── common_lora_utils.py # Shared utilities
├── classifier_model_fine_tuning_lora/ # Intent Classification
│ ├── ft_linear_lora.py # Training script
│ ├── ft_qwen3_generative_lora.py # Category classifier
│ ├── ft_linear_lora_verifier.go # Go verification
│ ├── train_cpu_optimized.sh # Training automation
│ └── go.mod
├── pii_model_fine_tuning_lora/ # PII Detection
│ ├── pii_bert_finetuning_lora.py # Training script
│ ├── pii_bert_finetuning_lora_verifier.go # Go verification
│ ├── train_cpu_optimized.sh # Training automation
│ ├── presidio_synth_dataset_v2.json # Training data
│ └── go.mod
└── prompt_guard_fine_tuning_lora/ # Security Detection
├── jailbreak_bert_finetuning_lora.py # Training script
├── jailbreak_bert_finetuning_lora_verifier.go # Go verification
├── train_cpu_optimized.sh # Training automation
└── go.mod
├── prompt_guard_fine_tuning_lora/ # Security Detection
│ ├── jailbreak_bert_finetuning_lora.py # Training script
│ ├── jailbreak_bert_finetuning_lora_verifier.go # Go verification
│ ├── train_cpu_optimized.sh # Training automation
│ └── go.mod
└── mmlu_pro_solver_lora/ # ⭐ MMLU-Pro Problem Solvers
├── ft_qwen3_mmlu_solver_lora[_no_leakage].py # Main training script (the _no_leakage variant avoids MMLU-Pro data leakage)
└── train_all_specialists[_no_leakage].sh # Batch training script (the _no_leakage variant avoids MMLU-Pro data leakage)
```
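
The training scripts listed above all share the same underlying recipe: the base model's weights stay frozen while small low-rank adapter matrices are trained inside selected projection layers. A minimal sketch with Hugging Face `peft` is below (shown for the causal-LM case); the rank, alpha, dropout, and target modules are illustrative defaults rather than the exact values used by these scripts.

```python
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

# Illustrative hyperparameters - the actual scripts may use different values.
base_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,                                  # rank of the low-rank update matrices
    lora_alpha=16,                        # scaling applied to the adapter output
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # attention projections to adapt
)
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()        # typically well under 1% of total parameters
```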

## 🚀 Quick Start
@@ -246,42 +246,61 @@ def prepare_datasets(self, max_samples_per_category=150):
}


def format_instruction(question: str, category: str = None) -> str:
def format_instruction(question: str, category: str = None) -> List[Dict[str, str]]:
"""
Format a question-category pair as an instruction-following example.
Format a question-category pair as chat messages for proper instruction fine-tuning.

Uses Qwen3's ChatML format with special tokens to separate user input from assistant output.
With the assistant turn clearly delimited, training can concentrate on generating the category
name (1-2 tokens) rather than the entire instruction (~200+ tokens), a far denser learning signal.

Args:
question: The question text
category: The category label (None for inference)

Returns:
Formatted instruction string (with or without answer)
List of message dicts with 'role' and 'content' keys
Format: [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
"""
instruction = INSTRUCTION_TEMPLATE.format(question=question)

# User message (the instruction/question)
messages = [{"role": "user", "content": instruction}]

if category is not None:
# Training format: instruction + answer
return f"{instruction} {category}"
else:
# Inference format: instruction only
return instruction
# Assistant message (the category name)
# This is just 1-2 tokens - much more efficient than training on the entire sequence!
messages.append({"role": "assistant", "content": category})

return messages
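
A quick usage sketch of the new contract follows; the question, the "math" label, and the tokenizer name are illustrative, and only the chat-template arguments mirror the code in this file.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")  # assumed base model

# Training example: user instruction plus the gold category as a short assistant turn.
train_messages = format_instruction("What is the integral of 1/x?", "math")
# -> [{"role": "user", "content": "...instruction..."},
#     {"role": "assistant", "content": "math"}]

# Inference example: user turn only; the generation prompt is added at decode time.
infer_messages = format_instruction("What is the integral of 1/x?", category=None)

# Rendered training text with ChatML special tokens; no generation prompt is needed
# because the assistant response is already present.
train_text = tokenizer.apply_chat_template(
    train_messages, tokenize=False, add_generation_prompt=False, enable_thinking=False
)
print(train_text)
```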


def create_generative_dataset(
texts: List[str], labels: List[str], tokenizer, max_length=512
):
"""
Create dataset in generative format for instruction-following.

Format: "Question: ... Category: {label}"
The model learns to generate the category name.
Create dataset in chat format for proper instruction fine-tuning.

Uses tokenizer.apply_chat_template() to format messages with special tokens.
This ensures:
- User input (instruction) and assistant output (category) are properly separated
- The loss can concentrate on the category name (1-2 tokens) instead of the instruction (200+ tokens)
- Training signal is far more focused: roughly 100% of target tokens are the answer, vs ~0.4% in the old format
- Inference format matches training format exactly
"""
formatted_examples = []

for text, label in zip(texts, labels):
# Create full text: instruction + answer
full_text = format_instruction(text, label)
formatted_examples.append(full_text)
# Get messages (user instruction + assistant category)
messages = format_instruction(text, label)

# Apply chat template to add special tokens
# add_generation_prompt=False because we already have the assistant response
# Disable thinking mode to train model for direct classification
formatted_text = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=False, enable_thinking=False
)
formatted_examples.append(formatted_text)

# Tokenize
encodings = tokenizer(
@@ -521,7 +540,7 @@ def main(
model.eval()

# Use validation data for testing
num_test_samples = min(20, len(val_texts)) # Test on 20 samples
num_test_samples = min(200, len(val_texts)) # Test on 200 samples
correct = 0
total = 0

@@ -531,7 +550,16 @@
question = val_texts[i]
true_category = val_labels[i]

prompt = format_instruction(question, category=None)
# Format using chat template
messages = format_instruction(question, category=None)

# Apply chat template with generation prompt
# This adds <|im_start|>assistant\n to prompt the model to respond
# Disable thinking mode for direct classification output
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
)

inputs = tokenizer(
prompt, return_tensors="pt", max_length=512, truncation=True
).to(model.device)
@@ -543,20 +571,24 @@
temperature=0.1,
do_sample=False, # Greedy decoding for evaluation
pad_token_id=tokenizer.pad_token_id,
eos_token_id=[
tokenizer.eos_token_id,
tokenizer.convert_tokens_to_ids("<|im_end|>"),
],
)

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Decode only the generated part (skip the input prompt)
generated_ids = outputs[0][inputs["input_ids"].shape[1] :]
generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

# Extract the category (text after "A:" or "Category:")
if "A:" in generated_text:
answer_text = generated_text.split("A:")[-1].strip()
elif "Category:" in generated_text:
answer_text = generated_text.split("Category:")[-1].strip()
else:
answer_text = ""
# Remove thinking tokens that Qwen3 generates
generated_text = (
generated_text.replace("<think>", "").replace("</think>", "").strip()
)

# With chat template, model generates just the category directly
# Clean up answer (take first line, remove punctuation at end)
answer_text = answer_text.split("\n")[0].strip().strip(".,!?;:").lower()
answer_text = generated_text.split("\n")[0].strip().strip(".,!?;:").lower()

# Match against known categories (handle multi-word categories like "computer science")
predicted_category = "unknown"
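
The category-matching loop itself is outside this hunk; the sketch below shows how such substring matching is typically written, with an illustrative category list rather than the module's actual constants.

```python
# Illustrative sketch of substring-based category matching (not the exact elided code).
category_names = ["math", "physics", "computer science", "law", "history"]  # example list

answer_text = "computer science"  # cleaned model output from the step above
predicted_category = "unknown"
for cat in category_names:
    # Comparing the whole lowercase category phrase handles multi-word
    # categories such as "computer science".
    if cat.lower() in answer_text:
        predicted_category = cat
        break
print(predicted_category)
```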
@@ -650,7 +682,18 @@ def demo_inference(model_path: str, model_name: str = "Qwen/Qwen3-0.6B"):
total = 0

for example in test_examples:
prompt = format_instruction(example, category=None)
# Format using chat template
messages = format_instruction(example, category=None)

# Apply chat template with generation prompt
# Disable thinking mode for direct classification output
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
@@ -660,20 +703,24 @@
temperature=0.1,
do_sample=True,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=[
tokenizer.eos_token_id,
tokenizer.convert_tokens_to_ids("<|im_end|>"),
],
)

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Decode only the generated part (skip the input prompt)
generated_ids = outputs[0][inputs["input_ids"].shape[1] :]
generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

# Extract category (handle both "A:" and "Category:" formats)
if "A:" in generated_text:
answer_text = generated_text.split("A:")[-1].strip()
elif "Category:" in generated_text:
answer_text = generated_text.split("Category:")[-1].strip()
else:
answer_text = ""
# Remove thinking tokens that Qwen3 generates
generated_text = (
generated_text.replace("<think>", "").replace("</think>", "").strip()
)

# With chat template, model generates just the category directly
# Clean up and match against known categories
answer_text = answer_text.split("\n")[0].strip().strip(".,!?;:").lower()
answer_text = generated_text.split("\n")[0].strip().strip(".,!?;:").lower()

category = "unknown"
for cat in REQUIRED_CATEGORIES:
@@ -691,7 +738,7 @@ def demo_inference(model_path: str, model_name: str = "Qwen/Qwen3-0.6B"):
)

print(f"\nQuestion: {example}")
print(f"Generated: {generated_text[len(prompt):50]}...")
print(f"Generated: {generated_text[:50]}...")
print(f"Predicted Category: {category}")
print("-" * 80)

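Putting the pieces together, here is a hedged end-to-end sketch of loading a trained adapter and classifying one question. The adapter path, the shortened instruction, and the example question are illustrative; the base model name, chat-template arguments, and decode-only-the-new-tokens pattern come from the code above.

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

BASE_MODEL = "Qwen/Qwen3-0.6B"
ADAPTER_PATH = "path/to/lora_adapter"  # hypothetical output directory of the training script

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
model = PeftModel.from_pretrained(AutoModelForCausalLM.from_pretrained(BASE_MODEL), ADAPTER_PATH)
model.eval()

question = "What is the time complexity of binary search?"  # illustrative
messages = [{"role": "user", "content": f"Classify this question into one category.\nQ: {question}\nA:"}]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
)
inputs = tokenizer(prompt, return_tensors="pt")

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=10,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
    )

# Decode only the newly generated tokens, mirroring the evaluation loop above.
answer = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(answer.replace("<think>", "").replace("</think>", "").strip())
```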