diff --git a/src/training/training_lora/classifier_model_fine_tuning_lora/ft_qwen3_generative_lora.py b/src/training/training_lora/classifier_model_fine_tuning_lora/ft_qwen3_generative_lora.py
index 01378b03..387f013e 100644
--- a/src/training/training_lora/classifier_model_fine_tuning_lora/ft_qwen3_generative_lora.py
+++ b/src/training/training_lora/classifier_model_fine_tuning_lora/ft_qwen3_generative_lora.py
@@ -28,7 +28,16 @@
# Quick test (10 samples per category = ~140 total)
python ft_qwen3_generative_lora.py --mode train --epochs 1 --max-samples-per-category 10
- # Inference
+ # Validate trained model on full validation set (auto-detects base model)
+ python ft_qwen3_generative_lora.py --mode validate --model-path qwen3_generative_classifier_r16
+
+ # Validate with specific number of samples
+ python ft_qwen3_generative_lora.py --mode validate --model-path qwen3_generative_classifier_r16 --num-val-samples 100
+
+ # Validate with explicit base model (if auto-detection fails)
+ python ft_qwen3_generative_lora.py --mode validate --model-path qwen3_generative_classifier_r16 --model Qwen/Qwen3-1.7B
+
+ # Inference demo (auto-detects base model)
python ft_qwen3_generative_lora.py --mode test --model-path qwen3_generative_classifier
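+
+    # Validate with thinking mode disabled (shorter generations; see --no-enable-thinking)
+    python ft_qwen3_generative_lora.py --mode validate --model-path qwen3_generative_classifier_r16 --no-enable-thinking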
Model:
@@ -178,19 +187,28 @@ def load_huggingface_dataset(self, max_samples_per_category=150):
cat for cat in REQUIRED_CATEGORIES if cat in category_samples
]
- target_samples_per_category = max_samples_per_category
+ # IMPORTANT: Validate and adjust target samples to ensure all categories have enough data
+ # Find the minimum available samples across all categories
+ min_available_samples = min(
+ len(category_samples[cat]) for cat in available_required_categories
+ )
+
+ # Use the smaller of: requested max_samples_per_category OR minimum available samples
+ target_samples_per_category = min(max_samples_per_category, min_available_samples)
- # Collect balanced samples
+ logger.info(f"Requested samples per category: {max_samples_per_category}")
+ logger.info(f"Minimum available samples across categories: {min_available_samples}")
+ logger.info(f"Actual samples per category (adjusted): {target_samples_per_category}")
+
+ # Collect balanced samples - now all categories will have EXACTLY the same number
filtered_texts = []
filtered_labels = []
category_counts = {}
for category in available_required_categories:
if category in category_samples:
- samples_to_take = min(
- target_samples_per_category, len(category_samples[category])
- )
- category_texts = category_samples[category][:samples_to_take]
+ # Now all categories take exactly target_samples_per_category samples
+ category_texts = category_samples[category][:target_samples_per_category]
filtered_texts.extend(category_texts)
filtered_labels.extend([category] * len(category_texts))
category_counts[category] = len(category_texts)
@@ -246,42 +264,67 @@ def prepare_datasets(self, max_samples_per_category=150):
}
-def format_instruction(question: str, category: str = None) -> str:
+def format_instruction(question: str, category: Optional[str] = None) -> List[Dict[str, str]]:
"""
- Format a question-category pair as an instruction-following example.
+ Format a question-category pair as chat messages for proper instruction fine-tuning.
+
+    Uses Qwen3's ChatML format with special tokens to separate user input from assistant output.
+    This lets training focus on generating the category name (1-2 tokens) rather than the
+    entire instruction (~200+ tokens), concentrating the training signal on the answer.
Args:
question: The question text
category: The category label (None for inference)
Returns:
- Formatted instruction string (with or without answer)
+ List of message dicts with 'role' and 'content' keys
+ Format: [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
"""
instruction = INSTRUCTION_TEMPLATE.format(question=question)
+ # User message (the instruction/question)
+ messages = [{"role": "user", "content": instruction}]
+
if category is not None:
- # Training format: instruction + answer
- return f"{instruction} {category}"
- else:
- # Inference format: instruction only
- return instruction
+ # Assistant message (the category name)
+ # This is just 1-2 tokens - much more efficient than training on entire sequence!
+ messages.append({"role": "assistant", "content": category})
+
+ return messages
def create_generative_dataset(
- texts: List[str], labels: List[str], tokenizer, max_length=512
+ texts: List[str], labels: List[str], tokenizer, max_length=512, enable_thinking: bool = True
):
"""
- Create dataset in generative format for instruction-following.
+ Create dataset in chat format for proper instruction fine-tuning.
- Format: "Question: ... Category: {label}"
- The model learns to generate the category name.
+ Uses tokenizer.apply_chat_template() to format messages with special tokens.
+ This ensures:
+ - User input (instruction) and assistant output (category) are properly separated
+ - Model trains ONLY on the category name (1-2 tokens), not the instruction (200+ tokens)
+    - Training signal is concentrated on the short answer rather than diluted across the ~200-token instruction
+ - Inference format matches training format exactly
+
+ Args:
+ texts: List of question texts
+ labels: List of category labels
+ tokenizer: Tokenizer instance
+ max_length: Maximum sequence length (default: 512)
+ enable_thinking: Enable Qwen3's thinking mode during training (default: True)
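+
+    The rendered text looks roughly like this (the exact markers, and any <think> block,
+    depend on the tokenizer's Qwen3 chat template and the enable_thinking setting):
+        <|im_start|>user
+        {instruction}<|im_end|>
+        <|im_start|>assistant
+        {category}<|im_end|>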
"""
formatted_examples = []
for text, label in zip(texts, labels):
- # Create full text: instruction + answer
- full_text = format_instruction(text, label)
- formatted_examples.append(full_text)
+ # Get messages (user instruction + assistant category)
+ messages = format_instruction(text, label)
+
+ # Apply chat template to add special tokens
+ # add_generation_prompt=False because we already have the assistant response
+ formatted_text = tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=False, enable_thinking=enable_thinking
+ )
+ formatted_examples.append(formatted_text)
# Tokenize
encodings = tokenizer(
@@ -355,22 +398,26 @@ def main(
model_name: str = "Qwen/Qwen3-0.6B",
lora_rank: int = 16,
lora_alpha: int = 32,
- lora_dropout: float = 0.05, # Lower dropout for small model
- num_epochs: int = 8, # More epochs for 0.6B
+ lora_dropout: float = 0.1, # Increased dropout to prevent overfitting (was 0.05)
+ num_epochs: int = 5, # Reduced epochs to prevent overfitting (was 8)
batch_size: int = 4, # Configurable batch size (adjust based on GPU memory)
- learning_rate: float = 3e-4, # Higher LR for small model
+ learning_rate: float = 2e-4, # Reduced LR for better convergence (was 3e-4)
max_samples_per_category: int = 150, # Samples per category for balanced dataset
num_workers: int = 0, # Number of dataloader workers (0=single process, 2-4 for multiprocessing)
output_dir: str = None,
gpu_id: Optional[int] = None,
+ early_stopping_patience: int = 3, # NEW: Stop training if validation loss doesn't improve
+ enable_thinking: bool = True, # Enable Qwen3's thinking mode during training (default: True)
):
"""Main training function for generative Qwen3 classification.
Args:
max_samples_per_category: Maximum samples per category (default: 150).
With 14 categories, this gives ~2100 total samples.
+ enable_thinking: Enable Qwen3's thinking mode during training (default: True).
"""
logger.info("Starting Qwen3 Generative Classification Fine-tuning")
+ logger.info(f"Enable thinking mode: {enable_thinking}")
logger.info("Training Qwen3 to GENERATE category labels (instruction-following)")
# GPU selection using utility function
@@ -440,8 +487,8 @@ def main(
# Prepare datasets in generative format
logger.info("Formatting dataset for instruction-following...")
- train_dataset = create_generative_dataset(train_texts, train_labels, tokenizer)
- val_dataset = create_generative_dataset(val_texts, val_labels, tokenizer)
+ train_dataset = create_generative_dataset(train_texts, train_labels, tokenizer, enable_thinking=enable_thinking)
+ val_dataset = create_generative_dataset(val_texts, val_labels, tokenizer, enable_thinking=enable_thinking)
logger.info(f"Example training input:")
logger.info(tokenizer.decode(train_dataset[0]["input_ids"][:100]))
@@ -468,12 +515,15 @@ def main(
1, 16 // batch_size
), # Maintain effective batch size of 16, minimum 1
learning_rate=learning_rate,
- weight_decay=0.01,
+ weight_decay=0.05, # Increased weight decay for regularization (was 0.01)
logging_dir=f"{output_dir}/logs",
logging_steps=10,
eval_strategy="epoch",
- save_strategy="no", # Don't save intermediate checkpoints (saves disk space!)
- save_total_limit=1, # Keep only 1 checkpoint
+ save_strategy="epoch", # Changed to save best model based on validation loss
+ save_total_limit=2, # Keep best 2 checkpoints
+ load_best_model_at_end=True, # Load best model at end of training
+ metric_for_best_model="eval_loss", # Use validation loss for early stopping
+ greater_is_better=False, # Lower loss is better
warmup_ratio=0.1,
lr_scheduler_type="cosine",
fp16=False, # Disable fp16 to avoid gradient issues
@@ -520,43 +570,74 @@ def main(
model.eval()
+ # Clear GPU cache before validation
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ logger.info("Cleared GPU cache before validation")
+
# Use validation data for testing
- num_test_samples = min(20, len(val_texts)) # Test on 20 samples
+ num_test_samples = min(100, len(val_texts)) # Limit to 100 samples for quick validation
correct = 0
total = 0
logger.info(f"Testing on {num_test_samples} validation samples...")
for i in range(num_test_samples):
+ logger.info(f"Processing validation sample {i+1}/{num_test_samples}...")
question = val_texts[i]
true_category = val_labels[i]
- prompt = format_instruction(question, category=None)
+ # Format using chat template
+ messages = format_instruction(question, category=None)
+
+ # Apply chat template with generation prompt
+ # This adds <|im_start|>assistant\n to prompt the model to respond
+ # Disable thinking mode for direct classification output
+ try:
+ prompt = tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
+ )
+ except Exception as e:
+ logger.warning(f"Chat template failed, using fallback: {e}")
+ # Fallback to simple format
+ prompt = f"Question: {question}\nCategory:"
+
inputs = tokenizer(
prompt, return_tensors="pt", max_length=512, truncation=True
).to(model.device)
with torch.no_grad():
- outputs = model.generate(
- **inputs,
- max_new_tokens=10,
- temperature=0.1,
- do_sample=False, # Greedy decoding for evaluation
- pad_token_id=tokenizer.pad_token_id,
- )
+ try:
+ outputs = model.generate(
+ **inputs,
+ max_new_tokens=10,
+ do_sample=False, # Greedy decoding for evaluation
+ pad_token_id=tokenizer.pad_token_id,
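+                        # Stop at either the model's EOS token or the ChatML end-of-turn token <|im_end|>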
+ eos_token_id=[
+ tokenizer.eos_token_id,
+ tokenizer.convert_tokens_to_ids("<|im_end|>"),
+ ],
+ )
+ except Exception as e:
+ logger.error(f"Generation failed for sample {i+1}: {e}")
+ continue
- generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+            # Periodically clear cache to prevent memory buildup
+ if torch.cuda.is_available() and i % 5 == 0:
+ torch.cuda.empty_cache()
- # Extract the category (text after "A:" or "Category:")
- if "A:" in generated_text:
- answer_text = generated_text.split("A:")[-1].strip()
- elif "Category:" in generated_text:
- answer_text = generated_text.split("Category:")[-1].strip()
- else:
- answer_text = ""
+ # Decode only the generated part (skip the input prompt)
+ generated_ids = outputs[0][inputs["input_ids"].shape[1] :]
+ generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
+
+            # Remove <think>/</think> markers that Qwen3 may emit
+            generated_text = (
+                generated_text.replace("<think>", "").replace("</think>", "").strip()
+            )
+ # With chat template, model generates just the category directly
# Clean up answer (take first line, remove punctuation at end)
- answer_text = answer_text.split("\n")[0].strip().strip(".,!?;:").lower()
+ answer_text = generated_text.split("\n")[0].strip().strip(".,!?;:").lower()
# Match against known categories (handle multi-word categories like "computer science")
predicted_category = "unknown"
@@ -595,11 +676,308 @@ def main(
log_memory_usage("Post-training")
-def demo_inference(model_path: str, model_name: str = "Qwen/Qwen3-0.6B"):
- """Demonstrate inference with trained generative model."""
+def validate_model(
+ model_path: str,
+ model_name: Optional[str] = None,
+ max_samples_per_category: int = 150,
+ num_val_samples: Optional[int] = None,
+ gpu_id: Optional[int] = None,
+ enable_thinking: bool = True,
+):
+ """
+ Validate a trained model on the full validation set.
+
+ Args:
+ model_path: Path to the saved model
+ model_name: Base model name (default: None = auto-detect from adapter_config.json)
+ max_samples_per_category: Maximum samples per category for dataset loading
+ num_val_samples: Number of validation samples to test (None = all)
+ gpu_id: GPU ID to use (None = auto-select)
+ enable_thinking: Enable Qwen3's thinking mode during generation (default: True)
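+
+    Returns:
+        Tuple of (overall_accuracy, category_accuracies), where category_accuracies maps
+        each category to its correct/total counts and accuracy percentage.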
+ """
+ logger.info("=" * 80)
+ logger.info("VALIDATION MODE: Testing trained model on validation set")
+ logger.info("=" * 80)
+ logger.info(f"Model path: {model_path}")
+
+ # GPU selection
+ device_str, selected_gpu = set_gpu_device(
+ gpu_id=gpu_id, auto_select=(gpu_id is None)
+ )
+ logger.info(f"Using device: {device_str} (GPU {selected_gpu})")
+
+ clear_gpu_memory()
+ log_memory_usage("Pre-validation")
+
+ try:
+ # Auto-detect base model from adapter config if not specified
+ if model_name is None:
+ adapter_config_path = os.path.join(model_path, "adapter_config.json")
+ if os.path.exists(adapter_config_path):
+ logger.info("Auto-detecting base model from adapter_config.json...")
+ with open(adapter_config_path, "r") as f:
+ adapter_config = json.load(f)
+ model_name = adapter_config.get("base_model_name_or_path", "Qwen/Qwen3-0.6B")
+ logger.info(f"Detected base model: {model_name}")
+ else:
+ model_name = "Qwen/Qwen3-0.6B"
+ logger.warning(f"adapter_config.json not found, using default: {model_name}")
+
+ # Load label mapping
+ logger.info("Loading label mapping...")
+ with open(os.path.join(model_path, "label_mapping.json"), "r") as f:
+ mapping_data = json.load(f)
+
+ # Load dataset
+ logger.info("Loading validation dataset...")
+ dataset_loader = MMLU_Dataset()
+ datasets = dataset_loader.prepare_datasets(max_samples_per_category)
+ val_texts, val_labels = datasets["validation"]
+
+ logger.info(f"Total validation samples: {len(val_texts)}")
+
+ # Limit samples if specified
+ if num_val_samples is not None and num_val_samples < len(val_texts):
+ val_texts = val_texts[:num_val_samples]
+ val_labels = val_labels[:num_val_samples]
+ logger.info(f"Limited to {num_val_samples} samples for validation")
+
+ # Load tokenizer
+ logger.info("Loading tokenizer...")
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.eos_token
+
+ # Load base model with appropriate dtype
+ logger.info(f"Loading base model: {model_name}")
+ use_fp16 = False
+ if torch.cuda.is_available():
+ try:
+ compute_capability = torch.cuda.get_device_capability()
+ use_fp16 = compute_capability[0] >= 7
+ except Exception:
+ use_fp16 = False
+
+ base_model = AutoModelForCausalLM.from_pretrained(
+ model_name,
+ torch_dtype=torch.float16 if use_fp16 else torch.float32,
+ low_cpu_mem_usage=True,
+ trust_remote_code=True,
+ )
+
+ # Load LoRA weights
+ logger.info(f"Loading LoRA weights from: {model_path}")
+ model = PeftModel.from_pretrained(base_model, model_path)
+ model = model.to(device_str)
+ model.eval()
+
+ logger.info("Model loaded successfully!")
+ log_memory_usage("Post-model-loading")
+
+ # Run validation
+ logger.info("\n" + "=" * 80)
+ logger.info(f"Running validation on {len(val_texts)} samples...")
+ logger.info(f"Enable thinking mode: {enable_thinking}")
+ logger.info("=" * 80)
+
+ correct = 0
+ total = 0
+ category_correct = {}
+ category_total = {}
+ predictions_log = []
+
+ for i, (question, true_category) in enumerate(zip(val_texts, val_labels)):
+ # Format using chat template
+ messages = format_instruction(question, category=None)
+
+ # Apply chat template with enable_thinking parameter
+ try:
+ prompt = tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True,
+ enable_thinking=enable_thinking,
+ )
+ except Exception as e:
+ logger.warning(f"Chat template failed, using fallback: {e}")
+ # Fallback to simple format
+ prompt = f"Question: {question}\nCategory:"
+
+ inputs = tokenizer(
+ prompt, return_tensors="pt", max_length=512, truncation=True
+ ).to(model.device)
+
+ # Generate prediction
+ with torch.no_grad():
+ outputs = model.generate(
+ **inputs,
+ max_new_tokens=50 if enable_thinking else 10, # More tokens if thinking is enabled
+ do_sample=False, # Greedy decoding for evaluation
+ pad_token_id=tokenizer.pad_token_id,
+ eos_token_id=[
+ tokenizer.eos_token_id,
+ tokenizer.convert_tokens_to_ids("<|im_end|>"),
+ ],
+ )
+
+ # Decode only the generated part (skip the input prompt)
+ generated_ids = outputs[0][inputs["input_ids"].shape[1] :]
+ generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
+
+ # Remove thinking tokens if present
+ if enable_thinking:
+            # Extract content after the </think> tag
+            if "</think>" in generated_text:
+                generated_text = generated_text.split("</think>")[-1].strip()
+            # Also remove any remaining <think>/</think> tags
+            generated_text = generated_text.replace("<think>", "").replace("</think>", "").strip()
+
+ # Extract the category
+ if "A:" in generated_text:
+ answer_text = generated_text.split("A:")[-1].strip()
+ elif "Category:" in generated_text:
+ answer_text = generated_text.split("Category:")[-1].strip()
+ else:
+ answer_text = generated_text
+
+ # Clean up answer
+ answer_text = answer_text.split("\n")[0].strip().strip(".,!?;:").lower()
+
+ # Match against known categories
+ predicted_category = "unknown"
+ for category in REQUIRED_CATEGORIES:
+ if answer_text.startswith(category.lower()):
+ predicted_category = category.lower()
+ break
+
+ # If no match, take first 2 words (for "computer science" etc)
+ if predicted_category == "unknown" and answer_text:
+ words = answer_text.split()
+ if len(words) >= 2:
+ predicted_category = " ".join(words[:2])
+ elif len(words) == 1:
+ predicted_category = words[0]
+ else:
+ predicted_category = answer_text
+
+ # Check correctness
+ is_correct = predicted_category == true_category.lower()
+ if is_correct:
+ correct += 1
+ total += 1
+
+ # Track per-category accuracy
+ true_cat_lower = true_category.lower()
+ if true_cat_lower not in category_correct:
+ category_correct[true_cat_lower] = 0
+ category_total[true_cat_lower] = 0
+ category_total[true_cat_lower] += 1
+ if is_correct:
+ category_correct[true_cat_lower] += 1
+
+ # Log prediction
+ predictions_log.append({
+ "question": question,
+ "true_category": true_category,
+ "predicted_category": predicted_category,
+ "correct": is_correct,
+ })
+
+ # Progress logging
+ if (i + 1) % 50 == 0 or i < 5 or i >= len(val_texts) - 5:
+ logger.info(
+ f"[{i+1}/{len(val_texts)}] Accuracy so far: {correct}/{total} = {correct/total*100:.2f}%"
+ )
+
+ # Log first 5 and last 5 examples
+ if i < 5 or i >= len(val_texts) - 5:
+ logger.info(f" Question: {question[:100]}...")
+ logger.info(f" True: {true_category} | Predicted: {predicted_category}")
+ logger.info(f" {'✓ CORRECT' if is_correct else '✗ WRONG'}")
+
+ # Calculate overall accuracy
+ overall_accuracy = (correct / total * 100) if total > 0 else 0
+
+ # Calculate per-category accuracy
+ category_accuracies = {}
+ for cat in sorted(category_total.keys()):
+ cat_acc = (
+ (category_correct[cat] / category_total[cat] * 100)
+ if category_total[cat] > 0
+ else 0
+ )
+ category_accuracies[cat] = {
+ "correct": category_correct[cat],
+ "total": category_total[cat],
+ "accuracy": cat_acc,
+ }
+
+ # Print results
+ logger.info("\n" + "=" * 80)
+ logger.info("VALIDATION RESULTS")
+ logger.info("=" * 80)
+ logger.info(f"Overall Accuracy: {correct}/{total} = {overall_accuracy:.2f}%")
+ logger.info("\nPer-Category Accuracy:")
+ logger.info("-" * 80)
+
+ for cat in sorted(category_accuracies.keys()):
+ stats = category_accuracies[cat]
+ logger.info(
+ f" {cat:20s}: {stats['correct']:3d}/{stats['total']:3d} = {stats['accuracy']:6.2f}%"
+ )
+
+ logger.info("=" * 80)
+
+ # Save results to file
+ results_file = os.path.join(model_path, "validation_results.json")
+ results_data = {
+ "overall_accuracy": overall_accuracy,
+ "correct": correct,
+ "total": total,
+ "category_accuracies": category_accuracies,
+ "predictions": predictions_log,
+ }
+
+ with open(results_file, "w") as f:
+ json.dump(results_data, f, indent=2)
+
+ logger.info(f"\nResults saved to: {results_file}")
+ log_memory_usage("Post-validation")
+
+ return overall_accuracy, category_accuracies
+
+ except Exception as e:
+ logger.error(f"Error during validation: {e}")
+ import traceback
+ traceback.print_exc()
+ raise
+
+
+def demo_inference(model_path: str, model_name: Optional[str] = None, enable_thinking: bool = True):
+ """Demonstrate inference with trained generative model.
+
+ Args:
+ model_path: Path to the saved model
+ model_name: Base model name (default: None = auto-detect)
+ enable_thinking: Enable Qwen3's thinking mode during generation (default: True)
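+
+    Example:
+        demo_inference("qwen3_generative_classifier_r16", enable_thinking=False)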
+ """
logger.info(f"Loading generative Qwen3 model from: {model_path}")
+ logger.info(f"Enable thinking mode: {enable_thinking}")
try:
+ # Auto-detect base model from adapter config if not specified
+ if model_name is None:
+ adapter_config_path = os.path.join(model_path, "adapter_config.json")
+ if os.path.exists(adapter_config_path):
+ logger.info("Auto-detecting base model from adapter_config.json...")
+ with open(adapter_config_path, "r") as f:
+ adapter_config = json.load(f)
+ model_name = adapter_config.get("base_model_name_or_path", "Qwen/Qwen3-0.6B")
+ logger.info(f"Detected base model: {model_name}")
+ else:
+ model_name = "Qwen/Qwen3-0.6B"
+ logger.warning(f"adapter_config.json not found, using default: {model_name}")
+
# Load label mapping
with open(os.path.join(model_path, "label_mapping.json"), "r") as f:
mapping_data = json.load(f)
@@ -633,6 +1011,15 @@ def demo_inference(model_path: str, model_name: str = "Qwen/Qwen3-0.6B"):
model = PeftModel.from_pretrained(base_model, model_path)
model.eval()
+ # Clear generation config to avoid warnings about unused parameters
+ if hasattr(model, 'generation_config'):
+ if hasattr(model.generation_config, 'temperature'):
+ delattr(model.generation_config, 'temperature')
+ if hasattr(model.generation_config, 'top_p'):
+ delattr(model.generation_config, 'top_p')
+ if hasattr(model.generation_config, 'top_k'):
+ delattr(model.generation_config, 'top_k')
+
# Test examples
test_examples = [
"What is the best strategy for corporate mergers and acquisitions?",
@@ -650,30 +1037,46 @@ def demo_inference(model_path: str, model_name: str = "Qwen/Qwen3-0.6B"):
total = 0
for example in test_examples:
- prompt = format_instruction(example, category=None)
+ # Format using chat template
+ messages = format_instruction(example, category=None)
+
+ # Apply chat template with generation prompt
+ prompt = tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True,
+ enable_thinking=enable_thinking,
+ )
+
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
outputs = model.generate(
**inputs,
- max_new_tokens=10,
- temperature=0.1,
- do_sample=True,
+ max_new_tokens=50 if enable_thinking else 10, # More tokens if thinking is enabled
+ do_sample=False, # Use greedy decoding for consistent results
pad_token_id=tokenizer.pad_token_id,
+ eos_token_id=[
+ tokenizer.eos_token_id,
+ tokenizer.convert_tokens_to_ids("<|im_end|>"),
+ ],
)
- generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+ # Decode only the generated part (skip the input prompt)
+ generated_ids = outputs[0][inputs["input_ids"].shape[1] :]
+ generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
- # Extract category (handle both "A:" and "Category:" formats)
- if "A:" in generated_text:
- answer_text = generated_text.split("A:")[-1].strip()
- elif "Category:" in generated_text:
- answer_text = generated_text.split("Category:")[-1].strip()
- else:
- answer_text = ""
+ # Remove thinking tokens if present
+ if enable_thinking:
+                # Extract content after the </think> tag
+                if "</think>" in generated_text:
+                    generated_text = generated_text.split("</think>")[-1].strip()
+                # Also remove any remaining <think>/</think> tags
+                generated_text = generated_text.replace("<think>", "").replace("</think>", "").strip()
+ # With chat template, model generates just the category directly
# Clean up and match against known categories
- answer_text = answer_text.split("\n")[0].strip().strip(".,!?;:").lower()
+ answer_text = generated_text.split("\n")[0].strip().strip(".,!?;:").lower()
category = "unknown"
for cat in REQUIRED_CATEGORIES:
@@ -691,7 +1094,7 @@ def demo_inference(model_path: str, model_name: str = "Qwen/Qwen3-0.6B"):
)
print(f"\nQuestion: {example}")
- print(f"Generated: {generated_text[len(prompt):50]}...")
+ print(f"Generated: {generated_text[:50]}...")
print(f"Predicted Category: {category}")
print("-" * 80)
@@ -708,26 +1111,34 @@ def demo_inference(model_path: str, model_name: str = "Qwen/Qwen3-0.6B"):
parser = argparse.ArgumentParser(
description="Qwen3 Generative Classification (Instruction-Following)"
)
- parser.add_argument("--mode", choices=["train", "test"], default="train")
+ parser.add_argument(
+ "--mode",
+ choices=["train", "validate", "test"],
+ default="train",
+ help="Mode: train (train model), validate (test trained model on validation set), test (demo inference)"
+ )
parser.add_argument(
"--model",
default="Qwen/Qwen3-0.6B",
help="Qwen3 model name (default: Qwen/Qwen3-0.6B)",
)
- parser.add_argument("--lora-rank", type=int, default=16, help="LoRA rank")
- parser.add_argument("--lora-alpha", type=int, default=32, help="LoRA alpha")
- parser.add_argument("--lora-dropout", type=float, default=0.05, help="LoRA dropout")
+ parser.add_argument("--lora-rank", type=int, default=16, help="LoRA rank (16-32 recommended for 1.7B model)")
+ parser.add_argument("--lora-alpha", type=int, default=32, help="LoRA alpha (typically 2x rank)")
+ parser.add_argument("--lora-dropout", type=float, default=0.1, help="LoRA dropout (increased to 0.1 to prevent overfitting)")
parser.add_argument(
- "--epochs", type=int, default=8, help="Number of training epochs"
+ "--epochs", type=int, default=5, help="Number of training epochs (reduced to 5 to prevent overfitting)"
)
parser.add_argument(
"--batch-size",
type=int,
- default=4,
- help="Per-device batch size (adjust based on GPU memory: 1-2 for small GPUs, 4-8 for medium, 8-16 for large). Gradient accumulation auto-adjusts to maintain effective batch size of 16.",
+ default=8,
+ help="Per-device batch size (increased default to 8 for better GPU utilization on A800). Gradient accumulation auto-adjusts to maintain effective batch size of 16.",
)
parser.add_argument(
- "--learning-rate", type=float, default=3e-4, help="Learning rate"
+ "--learning-rate", type=float, default=2e-4, help="Learning rate (reduced to 2e-4 for better convergence)"
+ )
+ parser.add_argument(
+ "--early-stopping-patience", type=int, default=3, help="Early stopping patience (stop if validation loss doesn't improve for N epochs)"
)
parser.add_argument(
"--max-samples-per-category",
@@ -742,17 +1153,35 @@ def demo_inference(model_path: str, model_name: str = "Qwen/Qwen3-0.6B"):
help="Number of dataloader workers (0=single process for debugging, 2-4=multiprocessing for better performance)",
)
parser.add_argument("--output-dir", type=str, default=None)
- parser.add_argument("--gpu-id", type=int, default=None)
+ parser.add_argument("--gpu-id", type=int, default=None, help="GPU ID to use (None = auto-select)")
parser.add_argument(
"--model-path",
type=str,
default="qwen3_generative_classifier_r16",
- help="Path to saved model for inference",
+ help="Path to saved model for inference/validation",
+ )
+ parser.add_argument(
+ "--num-val-samples",
+ type=int,
+ default=None,
+ help="Number of validation samples to test (None = all samples). Only used in validate mode.",
+ )
+ parser.add_argument(
+ "--enable-thinking",
+ action="store_true",
+ default=True,
+ help="Enable Qwen3's thinking mode during generation (default: True). Use --no-enable-thinking to disable.",
+ )
+ parser.add_argument(
+ "--no-enable-thinking",
+ dest="enable_thinking",
+ action="store_false",
+ help="Disable Qwen3's thinking mode during generation.",
)
args = parser.parse_args()
- # GPU device selection is handled in main() and demo_inference() functions
+ # GPU device selection is handled in main(), validate_model(), and demo_inference() functions
# using the set_gpu_device() utility function for consistency
if args.mode == "train":
@@ -768,6 +1197,17 @@ def demo_inference(model_path: str, model_name: str = "Qwen/Qwen3-0.6B"):
num_workers=args.num_workers,
output_dir=args.output_dir,
gpu_id=args.gpu_id,
+ early_stopping_patience=args.early_stopping_patience,
+ enable_thinking=args.enable_thinking,
+ )
+ elif args.mode == "validate":
+ validate_model(
+ model_path=args.model_path,
+ model_name=args.model,
+ max_samples_per_category=args.max_samples_per_category,
+ num_val_samples=args.num_val_samples,
+ gpu_id=args.gpu_id,
+ enable_thinking=args.enable_thinking,
)
elif args.mode == "test":
- demo_inference(args.model_path, args.model)
+ demo_inference(args.model_path, args.model, enable_thinking=args.enable_thinking)