 # Setup logging
 logger = setup_logging()
 
+# Required categories to match legacy model (14 categories)
+REQUIRED_CATEGORIES = [
+    "biology",
+    "business",
+    "chemistry",
+    "computer science",
+    "economics",
+    "engineering",
+    "health",
+    "history",
+    "law",
+    "math",
+    "other",
+    "philosophy",
+    "physics",
+    "psychology",
+]
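+# Note: the list is alphabetical, and prepare_datasets() builds label2id from
+# this ordering first, so it also pins the label IDs expected by the legacy model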
+
 
 def create_tokenizer_for_model(model_path: str, base_model_name: str = None):
     """
@@ -135,7 +153,7 @@ def __init__(self, dataset_name="TIGER-Lab/MMLU-Pro"):
         self.id2label = {}
 
     def load_huggingface_dataset(self, max_samples=1000):
-        """Load the MMLU-Pro dataset from HuggingFace."""
+        """Load the MMLU-Pro dataset from HuggingFace with balanced category sampling."""
         logger.info(f"Loading dataset from HuggingFace: {self.dataset_name}")
 
         try:
@@ -145,17 +163,103 @@ def load_huggingface_dataset(self, max_samples=1000):
 
             # Extract questions and categories from the test split
            # Note: MMLU-Pro typically uses 'test' split for training data
-            texts = dataset["test"]["question"]
-            labels = dataset["test"]["category"]
+            all_texts = dataset["test"]["question"]
+            all_labels = dataset["test"]["category"]
+
+            logger.info(f"Total samples in dataset: {len(all_texts)}")
+
+            # Group samples by category
+            category_samples = {}
+            for text, label in zip(all_texts, all_labels):
+                if label not in category_samples:
+                    category_samples[label] = []
+                category_samples[label].append(text)
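+            # category_samples ends up as, e.g., {"math": [q1, q2, ...], "physics": [...], ...}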
+
+            logger.info(
+                f"Available categories in dataset: {sorted(category_samples.keys())}"
+            )
+            logger.info(f"Required categories: {REQUIRED_CATEGORIES}")
+
+            # Check which required categories are missing
+            missing_categories = set(REQUIRED_CATEGORIES) - set(category_samples.keys())
+            if missing_categories:
+                logger.warning(f"Missing categories in dataset: {missing_categories}")
+
+            # Calculate samples per category for balanced sampling
+            available_required_categories = [
+                cat for cat in REQUIRED_CATEGORIES if cat in category_samples
+            ]
 
-            # Limit samples for faster training
-            if max_samples and len(texts) > max_samples:
-                texts = texts[:max_samples]
-                labels = labels[:max_samples]
-                logger.info(f"Limited dataset to {max_samples} samples")
+            # Ensure minimum samples per category for stable training
+            min_samples_per_category = max(
+                50, max_samples // (len(available_required_categories) * 2)
+            )
+            target_samples_per_category = max_samples // len(
+                available_required_categories
+            )
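+            # Worked example: max_samples=5000 with all 14 categories available
+            # gives min = max(50, 5000 // 28) = 178 and target = 5000 // 14 = 357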
 
-            logger.info(f"Loaded {len(texts)} samples")
-            return texts, labels
+            logger.info(f"Available categories: {len(available_required_categories)}")
+            logger.info(f"Min samples per category: {min_samples_per_category}")
+            logger.info(f"Target samples per category: {target_samples_per_category}")
+
+            # Collect balanced samples from required categories with improved strategy
+            filtered_texts = []
+            filtered_labels = []
+            category_counts = {}
+            insufficient_categories = []
+
+            # First pass: collect available samples for each category
+            for category in available_required_categories:
+                if category in category_samples:
+                    available_samples = len(category_samples[category])
+
+                    if available_samples < min_samples_per_category:
+                        insufficient_categories.append(category)
+                        samples_to_take = available_samples  # Take all available
+                    else:
+                        samples_to_take = min(
+                            target_samples_per_category, available_samples
+                        )
+
+                    category_texts = category_samples[category][:samples_to_take]
+                    filtered_texts.extend(category_texts)
+                    filtered_labels.extend([category] * len(category_texts))
+                    category_counts[category] = len(category_texts)
+
+            # Log insufficient categories
+            if insufficient_categories:
+                logger.warning(
+                    f"Categories with insufficient samples: {insufficient_categories}"
+                )
+                for cat in insufficient_categories:
+                    logger.warning(
+                        f"  {cat}: only {category_counts.get(cat, 0)} samples available"
+                    )
+
+            logger.info(f"Final category distribution: {category_counts}")
+            logger.info(f"Total filtered samples: {len(filtered_texts)}")
+
+            # Ensure we have samples for all required categories
+            missing_categories = set(available_required_categories) - set(
+                category_counts.keys()
+            )
+            if missing_categories:
+                logger.error(
+                    f"CRITICAL: Categories with no samples: {missing_categories}"
+                )
+
+            # Validate minimum category coverage
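+            # (0.8 * 14 = 11.2, so fewer than 12 populated categories trips this check)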
+            if (
+                len(category_counts) < len(REQUIRED_CATEGORIES) * 0.8
+            ):  # At least 80% of categories
+                logger.error(
+                    f"CRITICAL: Only {len(category_counts)}/{len(REQUIRED_CATEGORIES)} categories have samples!"
+                )
+                logger.error(
+                    "This will result in poor model performance. Consider increasing max_samples or using a different dataset."
+                )
+
+            return filtered_texts, filtered_labels
 
         except Exception as e:
             logger.error(f"Error loading dataset: {e}")
@@ -167,12 +271,20 @@ def prepare_datasets(self, max_samples=1000):
         # Load the dataset
         texts, labels = self.load_huggingface_dataset(max_samples)
 
-        # Create label mapping
+        # Create label mapping using required-category order for consistency
         unique_labels = sorted(list(set(labels)))
-        self.label2id = {label: idx for idx, label in enumerate(unique_labels)}
+
+        # Ensure we use the same order as the legacy model for consistency
+        ordered_labels = [cat for cat in REQUIRED_CATEGORIES if cat in unique_labels]
+        # Add any extra categories that might exist
+        extra_labels = [cat for cat in unique_labels if cat not in REQUIRED_CATEGORIES]
+        final_labels = ordered_labels + sorted(extra_labels)
+
+        self.label2id = {label: idx for idx, label in enumerate(final_labels)}
         self.id2label = {idx: label for label, idx in self.label2id.items()}
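+        # With all 14 required categories present this yields
+        # {"biology": 0, "business": 1, ..., "psychology": 13}; extras follow, sorted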
 
-        logger.info(f"Found {len(unique_labels)} unique categories: {unique_labels}")
+        logger.info(f"Found {len(final_labels)} unique categories: {final_labels}")
+        logger.info(f"Label mapping: {self.label2id}")
 
         # Convert labels to IDs
         label_ids = [self.label2id[label] for label in labels]
@@ -245,9 +357,20 @@ def compute_loss(
                 logits.view(-1, self.model.config.num_labels), labels.view(-1)
             )
 
-        # TODO: Add feature alignment loss when original model is available
+        # Feature alignment loss to improve LoRA adaptation
         total_loss = classification_loss
 
+        if self.enable_feature_alignment:
+            # Add L2 regularization on LoRA parameters to prevent overfitting
+            l2_reg = 0.0
+            for name, param in model.named_parameters():
+                if "lora_" in name and param.requires_grad:
+                    l2_reg += torch.norm(param, p=2)
+
+            # Add feature alignment loss
+            alignment_loss = self.alignment_weight * l2_reg
+            total_loss = classification_loss + alignment_loss
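+            # i.e. total = CE(logits, labels) + alignment_weight * sum of L2 norms
+            # over all trainable LoRA parameters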
+
         return (total_loss, outputs) if return_outputs else total_loss
 
 
@@ -321,7 +444,7 @@ def main(
     lora_dropout: float = 0.1,
     num_epochs: int = 3,
     batch_size: int = 8,
-    learning_rate: float = 1e-4,
+    learning_rate: float = 3e-5,  # Reduced from 1e-4 to prevent gradient explosion
     max_samples: int = 1000,
     output_dir: str = None,
     enable_feature_alignment: bool = False,
@@ -370,13 +493,12 @@ def main(
 
     logger.info(f"Model will be saved to: {output_dir}")
 
-    # Training arguments
+    # Training arguments optimized for LoRA sequence classification, following PEFT best practices
     training_args = TrainingArguments(
         output_dir=output_dir,
         num_train_epochs=num_epochs,
         per_device_train_batch_size=batch_size,
         per_device_eval_batch_size=batch_size,
-        warmup_steps=100,
         weight_decay=0.01,
         logging_dir=f"{output_dir}/logs",
         logging_steps=10,
@@ -386,6 +508,13 @@ def main(
         metric_for_best_model="eval_f1",
         greater_is_better=True,
         learning_rate=learning_rate,
+        # PEFT optimization: enhanced stability measures
+        max_grad_norm=1.0,  # Gradient clipping to prevent explosion
+        lr_scheduler_type="cosine",  # More stable learning rate schedule for LoRA
+        warmup_ratio=0.06,  # PEFT-recommended warmup ratio for sequence classification
+        # Additional stability measures for intent classification
+        dataloader_drop_last=False,
+        eval_accumulation_steps=1,
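+        # eval_accumulation_steps=1 moves prediction tensors to the CPU after every
+        # eval step instead of accumulating them on the GPU, capping eval memory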
     )
 
     # Create trainer
@@ -419,18 +548,18 @@ def main(
         json.dump(label_mapping, f, indent=2)
 
     logger.info(f"LoRA intent classification model saved to: {output_dir}")
-    logger.info("✅ Saved both label_mapping.json and category_mapping.json")
+    logger.info("Saved both label_mapping.json and category_mapping.json")
 
     # Auto-merge LoRA adapter with base model for Rust compatibility
-    logger.info("🔄 Auto-merging LoRA adapter with base model for Rust inference...")
+    logger.info("Auto-merging LoRA adapter with base model for Rust inference...")
     try:
         merged_output_dir = f"{output_dir}_rust"
         merge_lora_adapter_to_full_model(output_dir, merged_output_dir, model_path)
-        logger.info(f"✅ Rust-compatible model saved to: {merged_output_dir}")
-        logger.info(f"   This model can be used with Rust candle-binding!")
+        logger.info(f"Rust-compatible model saved to: {merged_output_dir}")
+        logger.info("This model can be used with Rust candle-binding!")
     except Exception as e:
-        logger.warning(f"⚠️ Auto-merge failed: {e}")
-        logger.info(f"   You can manually merge using a merge script")
+        logger.warning(f"Auto-merge failed: {e}")
+        logger.info("You can manually merge using a merge script")
 
     # Final evaluation
     logger.info("Final evaluation on validation set...")
@@ -448,7 +577,7 @@ def merge_lora_adapter_to_full_model(
     This function is automatically called after training to generate Rust-compatible models.
     """
 
-    logger.info(f"🔄 Loading base model: {base_model_path}")
+    logger.info(f"Loading base model: {base_model_path}")
 
     # Load label mapping to get correct number of labels
     with open(os.path.join(lora_adapter_path, "label_mapping.json"), "r") as f:
@@ -463,17 +592,17 @@ def merge_lora_adapter_to_full_model(
     # Load tokenizer with model-specific configuration
     tokenizer = create_tokenizer_for_model(base_model_path, base_model_path)
 
-    logger.info(f"🔄 Loading LoRA adapter from: {lora_adapter_path}")
+    logger.info(f"Loading LoRA adapter from: {lora_adapter_path}")
 
     # Load LoRA model
     lora_model = PeftModel.from_pretrained(base_model, lora_adapter_path)
 
-    logger.info("🔄 Merging LoRA adapter with base model...")
+    logger.info("Merging LoRA adapter with base model...")
 
     # Merge and unload LoRA
     merged_model = lora_model.merge_and_unload()
 
-    logger.info(f"💾 Saving merged model to: {output_path}")
+    logger.info(f"Saving merged model to: {output_path}")
 
     # Create output directory
     os.makedirs(output_path, exist_ok=True)
@@ -496,7 +625,7 @@ def merge_lora_adapter_to_full_model(
         json.dump(config, f, indent=2)
 
     logger.info(
-        "✅ Updated config.json with correct intent classification label mappings"
+        "Updated config.json with correct intent classification label mappings"
     )
 
     # Copy important files from LoRA adapter
@@ -513,9 +642,9 @@ def merge_lora_adapter_to_full_model(
         shutil.copy(
             os.path.join(output_path, "label_mapping.json"), category_mapping_path
         )
-        logger.info("✅ Created category_mapping.json")
+        logger.info("Created category_mapping.json")
 
-    logger.info("✅ LoRA adapter merged successfully!")
+    logger.info("LoRA adapter merged successfully!")
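+    # Typical call (paths illustrative), mirroring the auto-merge in main():
+    # merge_lora_adapter_to_full_model("out/lora", "out/lora_rust", "bert-base-uncased")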
 
 
 def demo_inference(model_path: str, model_name: str = "modernbert-base"):
@@ -592,16 +721,11 @@ def demo_inference(model_path: str, model_name: str = "modernbert-base"):
     parser.add_argument(
         "--model",
         choices=[
-            "modernbert-base",
-            "modernbert-large",
-            "bert-base-uncased",
-            "bert-large-uncased",
-            "roberta-base",
-            "roberta-large",
-            "deberta-v3-base",
-            "deberta-v3-large",
+            "modernbert-base",  # ModernBERT base model - latest architecture
+            "bert-base-uncased",  # BERT base model - most stable and CPU-friendly
+            "roberta-base",  # RoBERTa base model - best intent classification performance
         ],
-        default="modernbert-base",
+        default="bert-base-uncased",
     )
     parser.add_argument("--lora-rank", type=int, default=8)
     parser.add_argument("--lora-alpha", type=int, default=16)
@@ -610,12 +734,12 @@ def demo_inference(model_path: str, model_name: str = "modernbert-base"):
     parser.add_argument("--alignment-weight", type=float, default=0.1)
     parser.add_argument("--epochs", type=int, default=3)
     parser.add_argument("--batch-size", type=int, default=8)
-    parser.add_argument("--learning-rate", type=float, default=1e-4)
+    parser.add_argument("--learning-rate", type=float, default=3e-5)
     parser.add_argument(
         "--max-samples",
        type=int,
-        default=1000,
-        help="Maximum samples from MMLU-Pro dataset",
+        default=5000,
+        help="Maximum samples from MMLU-Pro dataset (recommended: 5000+ for all 14 categories)",
     )
     parser.add_argument(
         "--output-dir",