 # Setup logging
 logger = setup_logging()

+# Required categories to match legacy model (14 categories)
+REQUIRED_CATEGORIES = [
+    "biology",
+    "business",
+    "chemistry",
+    "computer science",
+    "economics",
+    "engineering",
+    "health",
+    "history",
+    "law",
+    "math",
+    "other",
+    "philosophy",
+    "physics",
+    "psychology",
+]
+

def create_tokenizer_for_model(model_path: str, base_model_name: str = None):
    """
@@ -135,7 +153,7 @@ def __init__(self, dataset_name="TIGER-Lab/MMLU-Pro"):
        self.id2label = {}

    def load_huggingface_dataset(self, max_samples=1000):
-        """Load the MMLU-Pro dataset from HuggingFace."""
+        """Load the MMLU-Pro dataset from HuggingFace with balanced category sampling."""
        logger.info(f"Loading dataset from HuggingFace: {self.dataset_name}")

        try:
@@ -145,17 +163,103 @@ def load_huggingface_dataset(self, max_samples=1000):

            # Extract questions and categories from the test split
            # Note: MMLU-Pro typically uses 'test' split for training data
-            texts = dataset["test"]["question"]
-            labels = dataset["test"]["category"]
+            all_texts = dataset["test"]["question"]
+            all_labels = dataset["test"]["category"]
+
+            logger.info(f"Total samples in dataset: {len(all_texts)}")
+
+            # Group samples by category
+            category_samples = {}
+            for text, label in zip(all_texts, all_labels):
+                if label not in category_samples:
+                    category_samples[label] = []
+                category_samples[label].append(text)
+
+            logger.info(
+                f"Available categories in dataset: {sorted(category_samples.keys())}"
+            )
+            logger.info(f"Required categories: {REQUIRED_CATEGORIES}")
+
+            # Check which required categories are missing
+            missing_categories = set(REQUIRED_CATEGORIES) - set(category_samples.keys())
+            if missing_categories:
+                logger.warning(f"Missing categories in dataset: {missing_categories}")
+
+            # Calculate samples per category for balanced sampling
+            available_required_categories = [
+                cat for cat in REQUIRED_CATEGORIES if cat in category_samples
+            ]

-            # Limit samples for faster training
-            if max_samples and len(texts) > max_samples:
-                texts = texts[:max_samples]
-                labels = labels[:max_samples]
-                logger.info(f"Limited dataset to {max_samples} samples")
+            # Ensure minimum samples per category for stable training
+            min_samples_per_category = max(
+                50, max_samples // (len(available_required_categories) * 2)
+            )
+            target_samples_per_category = max_samples // len(
+                available_required_categories
+            )

-            logger.info(f"Loaded {len(texts)} samples")
-            return texts, labels
+            logger.info(f"Available categories: {len(available_required_categories)}")
+            logger.info(f"Min samples per category: {min_samples_per_category}")
+            logger.info(f"Target samples per category: {target_samples_per_category}")
+
+            # Collect balanced samples from required categories with improved strategy
+            filtered_texts = []
+            filtered_labels = []
+            category_counts = {}
+            insufficient_categories = []
+
+            # First pass: collect available samples for each category
+            for category in available_required_categories:
+                if category in category_samples:
+                    available_samples = len(category_samples[category])
+
+                    if available_samples < min_samples_per_category:
+                        insufficient_categories.append(category)
+                        samples_to_take = available_samples  # Take all available
+                    else:
+                        samples_to_take = min(
+                            target_samples_per_category, available_samples
+                        )
+
+                    category_texts = category_samples[category][:samples_to_take]
+                    filtered_texts.extend(category_texts)
+                    filtered_labels.extend([category] * len(category_texts))
+                    category_counts[category] = len(category_texts)
+
+            # Log insufficient categories
+            if insufficient_categories:
+                logger.warning(
+                    f"Categories with insufficient samples: {insufficient_categories}"
+                )
+                for cat in insufficient_categories:
+                    logger.warning(
+                        f"  {cat}: only {category_counts.get(cat, 0)} samples available"
+                    )
+
+            logger.info(f"Final category distribution: {category_counts}")
+            logger.info(f"Total filtered samples: {len(filtered_texts)}")
+
+            # Ensure we have samples for all required categories
+            missing_categories = set(available_required_categories) - set(
+                category_counts.keys()
+            )
+            if missing_categories:
+                logger.error(
+                    f"CRITICAL: Categories with no samples: {missing_categories}"
+                )
+
+            # Validate minimum category coverage
+            if (
+                len(category_counts) < len(REQUIRED_CATEGORIES) * 0.8
+            ):  # At least 80% of categories
+                logger.error(
+                    f"CRITICAL: Only {len(category_counts)}/{len(REQUIRED_CATEGORIES)} categories have samples!"
+                )
+                logger.error(
+                    "This will result in poor model performance. Consider increasing max_samples or using a different dataset."
+                )
+
+            return filtered_texts, filtered_labels

        except Exception as e:
            logger.error(f"Error loading dataset: {e}")
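
The balanced-sampling change above is the core of this hunk. Below is a minimal, self-contained sketch of the same idea; the toy data and the helper name `balanced_sample` are illustrative, not part of the PR:

```python
from collections import defaultdict

def balanced_sample(texts, labels, required, max_samples):
    """Take an (approximately) equal number of samples per required category."""
    by_cat = defaultdict(list)
    for text, label in zip(texts, labels):
        by_cat[label].append(text)
    available = [c for c in required if c in by_cat]
    target = max_samples // len(available)
    out_texts, out_labels = [], []
    for cat in available:
        take = by_cat[cat][:target]  # truncate to the per-category budget
        out_texts.extend(take)
        out_labels.extend([cat] * len(take))
    return out_texts, out_labels

texts = ["q1", "q2", "q3", "q4"]
labels = ["math", "math", "law", "law"]
print(balanced_sample(texts, labels, ["math", "law"], 2))
# (['q1', 'q3'], ['math', 'law'])
```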
@@ -167,12 +271,20 @@ def prepare_datasets(self, max_samples=1000):
        # Load the dataset
        texts, labels = self.load_huggingface_dataset(max_samples)

-        # Create label mapping
+        # Create label mapping using required categories order for consistency
        unique_labels = sorted(list(set(labels)))
-        self.label2id = {label: idx for idx, label in enumerate(unique_labels)}
+
+        # Ensure we use the same order as legacy model for consistency
+        ordered_labels = [cat for cat in REQUIRED_CATEGORIES if cat in unique_labels]
+        # Add any extra categories that might exist
+        extra_labels = [cat for cat in unique_labels if cat not in REQUIRED_CATEGORIES]
+        final_labels = ordered_labels + sorted(extra_labels)
+
+        self.label2id = {label: idx for idx, label in enumerate(final_labels)}
        self.id2label = {idx: label for label, idx in self.label2id.items()}

-        logger.info(f"Found {len(unique_labels)} unique categories: {unique_labels}")
+        logger.info(f"Found {len(final_labels)} unique categories: {final_labels}")
+        logger.info(f"Label mapping: {self.label2id}")

        # Convert labels to IDs
        label_ids = [self.label2id[label] for label in labels]
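
The ordering trick keeps label IDs stable across runs: required categories come first in their fixed order, extras are appended sorted. A toy illustration (shortened values, assumed for the example):

```python
REQUIRED_CATEGORIES = ["biology", "business", "law", "math"]  # shortened for the demo
unique_labels = ["math", "law", "astronomy", "biology"]

ordered = [c for c in REQUIRED_CATEGORIES if c in unique_labels]
extra = sorted(c for c in unique_labels if c not in REQUIRED_CATEGORIES)
label2id = {label: idx for idx, label in enumerate(ordered + extra)}
print(label2id)  # {'biology': 0, 'law': 1, 'math': 2, 'astronomy': 3}
```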
@@ -245,9 +357,20 @@ def compute_loss(
            logits.view(-1, self.model.config.num_labels), labels.view(-1)
        )

-        # TODO: Add feature alignment loss when original model is available
+        # Feature alignment loss to improve LoRA adaptation
        total_loss = classification_loss

+        if self.enable_feature_alignment:
+            # Add L2 regularization on LoRA parameters to prevent overfitting
+            l2_reg = 0.0
+            for name, param in model.named_parameters():
+                if "lora_" in name and param.requires_grad:
+                    l2_reg += torch.norm(param, p=2)
+
+            # Add feature alignment loss
+            alignment_loss = self.alignment_weight * l2_reg
+            total_loss = classification_loss + alignment_loss
+
        return (total_loss, outputs) if return_outputs else total_loss

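
The penalty term sums the L2 norm of every trainable parameter whose name contains `lora_`. A standalone sketch with a toy module; the module, names, and weight value are assumptions for illustration only:

```python
import torch
import torch.nn as nn

class ToyLoRALayer(nn.Module):
    def __init__(self, dim=4, rank=2):
        super().__init__()
        # Parameter names contain "lora_" so the filter below picks them up
        self.lora_A = nn.Parameter(torch.randn(rank, dim))
        self.lora_B = nn.Parameter(torch.zeros(dim, rank))

model = ToyLoRALayer()
alignment_weight = 0.1  # assumed, mirrors the --alignment-weight default

l2_reg = sum(
    torch.norm(p, p=2)
    for n, p in model.named_parameters()
    if "lora_" in n and p.requires_grad
)
classification_loss = torch.tensor(1.0)  # placeholder loss for the demo
total_loss = classification_loss + alignment_weight * l2_reg
print(total_loss.item())
```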
@@ -321,7 +444,7 @@ def main(
    lora_dropout: float = 0.1,
    num_epochs: int = 3,
    batch_size: int = 8,
-    learning_rate: float = 1e-4,
+    learning_rate: float = 3e-5,  # Reduced from 1e-4 to prevent gradient explosion
    max_samples: int = 1000,
    output_dir: str = None,
    enable_feature_alignment: bool = False,
@@ -370,13 +493,12 @@ def main(

    logger.info(f"Model will be saved to: {output_dir}")

-    # Training arguments
+    # Training arguments optimized for LoRA sequence classification based on PEFT best practices
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
-        warmup_steps=100,
        weight_decay=0.01,
        logging_dir=f"{output_dir}/logs",
        logging_steps=10,
@@ -386,6 +508,13 @@ def main(
        metric_for_best_model="eval_f1",
        greater_is_better=True,
        learning_rate=learning_rate,
+        # PEFT optimization: Enhanced stability measures
+        max_grad_norm=1.0,  # Gradient clipping to prevent explosion
+        lr_scheduler_type="cosine",  # More stable learning rate schedule for LoRA
+        warmup_ratio=0.06,  # PEFT recommended warmup ratio for sequence classification
+        # Additional stability measures for intent classification
+        dataloader_drop_last=False,
+        eval_accumulation_steps=1,
    )

    # Create trainer
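
Switching from a fixed `warmup_steps=100` to `warmup_ratio=0.06` scales warmup with the run length. A sketch of the arithmetic and the equivalent explicit scheduler, with dataset and batch sizes assumed for the example:

```python
import torch
from transformers import get_cosine_schedule_with_warmup

num_samples, batch_size, num_epochs = 5000, 8, 3
steps_per_epoch = -(-num_samples // batch_size)  # ceil(5000 / 8) = 625
total_steps = steps_per_epoch * num_epochs       # 1875
warmup_steps = int(0.06 * total_steps)           # 112

optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=3e-5)
scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)
```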
@@ -419,18 +548,18 @@ def main(
        json.dump(label_mapping, f, indent=2)

    logger.info(f"LoRA intent classification model saved to: {output_dir}")
-    logger.info("✅ Saved both label_mapping.json and category_mapping.json")
+    logger.info("Saved both label_mapping.json and category_mapping.json")

    # Auto-merge LoRA adapter with base model for Rust compatibility
-    logger.info("🔄 Auto-merging LoRA adapter with base model for Rust inference...")
+    logger.info("Auto-merging LoRA adapter with base model for Rust inference...")
    try:
        merged_output_dir = f"{output_dir}_rust"
        merge_lora_adapter_to_full_model(output_dir, merged_output_dir, model_path)
-        logger.info(f"✅ Rust-compatible model saved to: {merged_output_dir}")
-        logger.info(f"   This model can be used with Rust candle-binding!")
+        logger.info(f"Rust-compatible model saved to: {merged_output_dir}")
+        logger.info(f"This model can be used with Rust candle-binding!")
    except Exception as e:
-        logger.warning(f"⚠️ Auto-merge failed: {e}")
-        logger.info(f"   You can manually merge using a merge script")
+        logger.warning(f"Auto-merge failed: {e}")
+        logger.info(f"You can manually merge using a merge script")

    # Final evaluation
    logger.info("Final evaluation on validation set...")
@@ -448,7 +577,7 @@ def merge_lora_adapter_to_full_model(
    This function is automatically called after training to generate Rust-compatible models.
    """

-    logger.info(f"🔄 Loading base model: {base_model_path}")
+    logger.info(f"Loading base model: {base_model_path}")

    # Load label mapping to get correct number of labels
    with open(os.path.join(lora_adapter_path, "label_mapping.json"), "r") as f:
@@ -463,17 +592,17 @@ def merge_lora_adapter_to_full_model(
    # Load tokenizer with model-specific configuration
    tokenizer = create_tokenizer_for_model(base_model_path, base_model_path)

-    logger.info(f"🔄 Loading LoRA adapter from: {lora_adapter_path}")
+    logger.info(f"Loading LoRA adapter from: {lora_adapter_path}")

    # Load LoRA model
    lora_model = PeftModel.from_pretrained(base_model, lora_adapter_path)

-    logger.info("🔄 Merging LoRA adapter with base model...")
+    logger.info("Merging LoRA adapter with base model...")

    # Merge and unload LoRA
    merged_model = lora_model.merge_and_unload()

-    logger.info(f"💾 Saving merged model to: {output_path}")
+    logger.info(f"Saving merged model to: {output_path}")

    # Create output directory
    os.makedirs(output_path, exist_ok=True)
@@ -496,7 +625,7 @@ def merge_lora_adapter_to_full_model(
        json.dump(config, f, indent=2)

    logger.info(
-        "✅ Updated config.json with correct intent classification label mappings"
+        "Updated config.json with correct intent classification label mappings"
    )

    # Copy important files from LoRA adapter
@@ -513,9 +642,9 @@ def merge_lora_adapter_to_full_model(
        shutil.copy(
            os.path.join(output_path, "label_mapping.json"), category_mapping_path
        )
-        logger.info("✅ Created category_mapping.json")
+        logger.info("Created category_mapping.json")

-    logger.info("✅ LoRA adapter merged successfully!")
+    logger.info("LoRA adapter merged successfully!")


def demo_inference(model_path: str, model_name: str = "modernbert-base"):
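
The merge path relies on PEFT's `merge_and_unload()`, which folds the adapter deltas into the base weights so the result loads without a PEFT dependency. A minimal sketch of that flow; the paths and label count are placeholders, not the PR's values:

```python
from peft import PeftModel
from transformers import AutoModelForSequenceClassification

# Placeholder paths/values; substitute the real adapter and output dirs.
base = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=14
)
lora = PeftModel.from_pretrained(base, "lora_adapter_dir")
merged = lora.merge_and_unload()  # folds LoRA deltas into the base weights
merged.save_pretrained("merged_model_dir")  # loads without PEFT at inference
```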
@@ -592,16 +721,11 @@ def demo_inference(model_path: str, model_name: str = "modernbert-base"):
    parser.add_argument(
        "--model",
        choices=[
-            "modernbert-base",
-            "modernbert-large",
-            "bert-base-uncased",
-            "bert-large-uncased",
-            "roberta-base",
-            "roberta-large",
-            "deberta-v3-base",
-            "deberta-v3-large",
+            "modernbert-base",  # ModernBERT base model - latest architecture
+            "bert-base-uncased",  # BERT base model - most stable and CPU-friendly
+            "roberta-base",  # RoBERTa base model - best intent classification performance
        ],
-        default="modernbert-base",
+        default="bert-base-uncased",
    )
    parser.add_argument("--lora-rank", type=int, default=8)
    parser.add_argument("--lora-alpha", type=int, default=16)
@@ -610,12 +734,12 @@ def demo_inference(model_path: str, model_name: str = "modernbert-base"):
    parser.add_argument("--alignment-weight", type=float, default=0.1)
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch-size", type=int, default=8)
-    parser.add_argument("--learning-rate", type=float, default=1e-4)
+    parser.add_argument("--learning-rate", type=float, default=3e-5)
    parser.add_argument(
        "--max-samples",
        type=int,
-        default=1000,
-        help="Maximum samples from MMLU-Pro dataset",
+        default=5000,
+        help="Maximum samples from MMLU-Pro dataset (recommended: 5000+ for all 14 categories)",
    )
    parser.add_argument(
        "--output-dir",