
Commit 56cdc08

OneZero-Y and rootfs authored
fix: LoRA Model Training Configuration and Data Balance (#233)
* Fix LoRA Model Training Configuration and Data Balance

  Signed-off-by: OneZero-Y <[email protected]>

* fix: LoRA Model Training Configuration and Data Balance

  Signed-off-by: OneZero-Y <[email protected]>

---------

Signed-off-by: OneZero-Y <[email protected]>
Co-authored-by: Huamin Chen <[email protected]>
1 parent 55792fd commit 56cdc08

File tree

11 files changed (+1212, -388 lines)

src/training/training_lora/OWNER

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+# lora training owners
+@OneZero-Y

src/training/training_lora/classifier_model_fine_tuning_lora/ft_linear_lora.py

Lines changed: 166 additions & 42 deletions
@@ -100,6 +100,24 @@
 # Setup logging
 logger = setup_logging()

+# Required categories to match legacy model (14 categories)
+REQUIRED_CATEGORIES = [
+    "biology",
+    "business",
+    "chemistry",
+    "computer science",
+    "economics",
+    "engineering",
+    "health",
+    "history",
+    "law",
+    "math",
+    "other",
+    "philosophy",
+    "physics",
+    "psychology",
+]
+

 def create_tokenizer_for_model(model_path: str, base_model_name: str = None):
     """
@@ -135,7 +153,7 @@ def __init__(self, dataset_name="TIGER-Lab/MMLU-Pro"):
         self.id2label = {}

     def load_huggingface_dataset(self, max_samples=1000):
-        """Load the MMLU-Pro dataset from HuggingFace."""
+        """Load the MMLU-Pro dataset from HuggingFace with balanced category sampling."""
         logger.info(f"Loading dataset from HuggingFace: {self.dataset_name}")

         try:
@@ -145,17 +163,103 @@ def load_huggingface_dataset(self, max_samples=1000):

             # Extract questions and categories from the test split
             # Note: MMLU-Pro typically uses 'test' split for training data
-            texts = dataset["test"]["question"]
-            labels = dataset["test"]["category"]
+            all_texts = dataset["test"]["question"]
+            all_labels = dataset["test"]["category"]
+
+            logger.info(f"Total samples in dataset: {len(all_texts)}")
+
+            # Group samples by category
+            category_samples = {}
+            for text, label in zip(all_texts, all_labels):
+                if label not in category_samples:
+                    category_samples[label] = []
+                category_samples[label].append(text)
+
+            logger.info(
+                f"Available categories in dataset: {sorted(category_samples.keys())}"
+            )
+            logger.info(f"Required categories: {REQUIRED_CATEGORIES}")
+
+            # Check which required categories are missing
+            missing_categories = set(REQUIRED_CATEGORIES) - set(category_samples.keys())
+            if missing_categories:
+                logger.warning(f"Missing categories in dataset: {missing_categories}")
+
+            # Calculate samples per category for balanced sampling
+            available_required_categories = [
+                cat for cat in REQUIRED_CATEGORIES if cat in category_samples
+            ]

-            # Limit samples for faster training
-            if max_samples and len(texts) > max_samples:
-                texts = texts[:max_samples]
-                labels = labels[:max_samples]
-                logger.info(f"Limited dataset to {max_samples} samples")
+            # Ensure minimum samples per category for stable training
+            min_samples_per_category = max(
+                50, max_samples // (len(available_required_categories) * 2)
+            )
+            target_samples_per_category = max_samples // len(
+                available_required_categories
+            )

-            logger.info(f"Loaded {len(texts)} samples")
-            return texts, labels
+            logger.info(f"Available categories: {len(available_required_categories)}")
+            logger.info(f"Min samples per category: {min_samples_per_category}")
+            logger.info(f"Target samples per category: {target_samples_per_category}")
+
+            # Collect balanced samples from required categories with improved strategy
+            filtered_texts = []
+            filtered_labels = []
+            category_counts = {}
+            insufficient_categories = []
+
+            # First pass: collect available samples for each category
+            for category in available_required_categories:
+                if category in category_samples:
+                    available_samples = len(category_samples[category])
+
+                    if available_samples < min_samples_per_category:
+                        insufficient_categories.append(category)
+                        samples_to_take = available_samples  # Take all available
+                    else:
+                        samples_to_take = min(
+                            target_samples_per_category, available_samples
+                        )
+
+                    category_texts = category_samples[category][:samples_to_take]
+                    filtered_texts.extend(category_texts)
+                    filtered_labels.extend([category] * len(category_texts))
+                    category_counts[category] = len(category_texts)
+
+            # Log insufficient categories
+            if insufficient_categories:
+                logger.warning(
+                    f"Categories with insufficient samples: {insufficient_categories}"
+                )
+                for cat in insufficient_categories:
+                    logger.warning(
+                        f"  {cat}: only {category_counts.get(cat, 0)} samples available"
+                    )
+
+            logger.info(f"Final category distribution: {category_counts}")
+            logger.info(f"Total filtered samples: {len(filtered_texts)}")
+
+            # Ensure we have samples for all required categories
+            missing_categories = set(available_required_categories) - set(
+                category_counts.keys()
+            )
+            if missing_categories:
+                logger.error(
+                    f"CRITICAL: Categories with no samples: {missing_categories}"
+                )
+
+            # Validate minimum category coverage
+            if (
+                len(category_counts) < len(REQUIRED_CATEGORIES) * 0.8
+            ):  # At least 80% of categories
+                logger.error(
+                    f"CRITICAL: Only {len(category_counts)}/{len(REQUIRED_CATEGORIES)} categories have samples!"
+                )
+                logger.error(
+                    "This will result in poor model performance. Consider increasing max_samples or using a different dataset."
+                )
+
+            return filtered_texts, filtered_labels

         except Exception as e:
             logger.error(f"Error loading dataset: {e}")
@@ -167,12 +271,20 @@ def prepare_datasets(self, max_samples=1000):
         # Load the dataset
         texts, labels = self.load_huggingface_dataset(max_samples)

-        # Create label mapping
+        # Create label mapping using required categories order for consistency
         unique_labels = sorted(list(set(labels)))
-        self.label2id = {label: idx for idx, label in enumerate(unique_labels)}
+
+        # Ensure we use the same order as legacy model for consistency
+        ordered_labels = [cat for cat in REQUIRED_CATEGORIES if cat in unique_labels]
+        # Add any extra categories that might exist
+        extra_labels = [cat for cat in unique_labels if cat not in REQUIRED_CATEGORIES]
+        final_labels = ordered_labels + sorted(extra_labels)
+
+        self.label2id = {label: idx for idx, label in enumerate(final_labels)}
         self.id2label = {idx: label for label, idx in self.label2id.items()}

-        logger.info(f"Found {len(unique_labels)} unique categories: {unique_labels}")
+        logger.info(f"Found {len(final_labels)} unique categories: {final_labels}")
+        logger.info(f"Label mapping: {self.label2id}")

         # Convert labels to IDs
         label_ids = [self.label2id[label] for label in labels]
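
A minimal sketch of why the reordering matters: label IDs now follow REQUIRED_CATEGORIES order with any unexpected labels appended at the end, so the IDs of the required categories never shift and stay aligned with the legacy model. The snippet below abbreviates the category list and invents an "astronomy" label purely to illustrate an extra category:

```python
# Abbreviated illustration of the deterministic label ordering.
REQUIRED_CATEGORIES = ["biology", "business", "chemistry"]
unique_labels = {"chemistry", "business", "biology", "astronomy"}  # "astronomy" is a made-up extra label

ordered_labels = [c for c in REQUIRED_CATEGORIES if c in unique_labels]
extra_labels = sorted(l for l in unique_labels if l not in REQUIRED_CATEGORIES)
label2id = {label: idx for idx, label in enumerate(ordered_labels + extra_labels)}
print(label2id)  # {'biology': 0, 'business': 1, 'chemistry': 2, 'astronomy': 3}
```

With the old plain sorted(unique_labels), "astronomy" would land at index 0 and shift every required category's ID.
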
@@ -245,9 +357,20 @@ def compute_loss(
             logits.view(-1, self.model.config.num_labels), labels.view(-1)
         )

-        # TODO: Add feature alignment loss when original model is available
+        # Feature alignment loss to improve LoRA adaptation
        total_loss = classification_loss

+        if self.enable_feature_alignment:
+            # Add L2 regularization on LoRA parameters to prevent overfitting
+            l2_reg = 0.0
+            for name, param in model.named_parameters():
+                if "lora_" in name and param.requires_grad:
+                    l2_reg += torch.norm(param, p=2)
+
+            # Add feature alignment loss
+            alignment_loss = self.alignment_weight * l2_reg
+            total_loss = classification_loss + alignment_loss
+
         return (total_loss, outputs) if return_outputs else total_loss


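
The added term amounts to total_loss = cross-entropy + alignment_weight * sum of L2 norms over the trainable LoRA matrices. A toy sketch of just that term, using a stand-in module (not the project's PEFT-wrapped model) whose parameter names happen to contain "lora_":

```python
import torch
import torch.nn as nn

# Stand-in for a PEFT-wrapped model: parameter names contain "lora_".
class ToyLoRALayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.lora_A = nn.Linear(8, 4, bias=False)
        self.lora_B = nn.Linear(4, 8, bias=False)

model = ToyLoRALayer()
alignment_weight = 0.1
classification_loss = torch.tensor(1.25)  # placeholder for the cross-entropy loss

l2_reg = sum(
    torch.norm(param, p=2)
    for name, param in model.named_parameters()
    if "lora_" in name and param.requires_grad
)
total_loss = classification_loss + alignment_weight * l2_reg
print(total_loss)  # cross-entropy plus a penalty that grows with the LoRA weight norms
```
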
@@ -321,7 +444,7 @@ def main(
     lora_dropout: float = 0.1,
     num_epochs: int = 3,
     batch_size: int = 8,
-    learning_rate: float = 1e-4,
+    learning_rate: float = 3e-5,  # Reduced from 1e-4 to prevent gradient explosion
     max_samples: int = 1000,
     output_dir: str = None,
     enable_feature_alignment: bool = False,
@@ -370,13 +493,12 @@ def main(

     logger.info(f"Model will be saved to: {output_dir}")

-    # Training arguments
+    # Training arguments optimized for LoRA sequence classification based on PEFT best practices
     training_args = TrainingArguments(
         output_dir=output_dir,
         num_train_epochs=num_epochs,
         per_device_train_batch_size=batch_size,
         per_device_eval_batch_size=batch_size,
-        warmup_steps=100,
         weight_decay=0.01,
         logging_dir=f"{output_dir}/logs",
         logging_steps=10,
@@ -386,6 +508,13 @@
         metric_for_best_model="eval_f1",
         greater_is_better=True,
         learning_rate=learning_rate,
+        # PEFT optimization: Enhanced stability measures
+        max_grad_norm=1.0,  # Gradient clipping to prevent explosion
+        lr_scheduler_type="cosine",  # More stable learning rate schedule for LoRA
+        warmup_ratio=0.06,  # PEFT recommended warmup ratio for sequence classification
+        # Additional stability measures for intent classification
+        dataloader_drop_last=False,
+        eval_accumulation_steps=1,
     )

     # Create trainer
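
One consequence of switching from the fixed warmup_steps=100 to warmup_ratio=0.06 is that warmup now scales with the run length. A rough calculation under assumed settings (5000 samples with a 90/10 train/eval split, batch size 8, 3 epochs; the split ratio is an assumption, not something fixed in this diff):

```python
import math

# Assumed run shape for illustration only.
train_samples = 4500   # ~90% of max_samples=5000
batch_size = 8
num_epochs = 3
warmup_ratio = 0.06

steps_per_epoch = math.ceil(train_samples / batch_size)  # 563
total_steps = steps_per_epoch * num_epochs               # 1689
warmup_steps = int(total_steps * warmup_ratio)           # ~101
print(steps_per_epoch, total_steps, warmup_steps)
```

So at the new defaults the effective warmup is close to the old fixed 100 steps, but it grows or shrinks automatically with dataset size, batch size, and epoch count.
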
@@ -419,18 +548,18 @@ def main(
         json.dump(label_mapping, f, indent=2)

     logger.info(f"LoRA intent classification model saved to: {output_dir}")
-    logger.info("Saved both label_mapping.json and category_mapping.json")
+    logger.info("Saved both label_mapping.json and category_mapping.json")

     # Auto-merge LoRA adapter with base model for Rust compatibility
-    logger.info("🔄 Auto-merging LoRA adapter with base model for Rust inference...")
+    logger.info("Auto-merging LoRA adapter with base model for Rust inference...")
     try:
         merged_output_dir = f"{output_dir}_rust"
         merge_lora_adapter_to_full_model(output_dir, merged_output_dir, model_path)
-        logger.info(f"Rust-compatible model saved to: {merged_output_dir}")
-        logger.info(f"   This model can be used with Rust candle-binding!")
+        logger.info(f"Rust-compatible model saved to: {merged_output_dir}")
+        logger.info(f"This model can be used with Rust candle-binding!")
     except Exception as e:
-        logger.warning(f"⚠️ Auto-merge failed: {e}")
-        logger.info(f"   You can manually merge using a merge script")
+        logger.warning(f"Auto-merge failed: {e}")
+        logger.info(f"You can manually merge using a merge script")

     # Final evaluation
     logger.info("Final evaluation on validation set...")
@@ -448,7 +577,7 @@ def merge_lora_adapter_to_full_model(
     This function is automatically called after training to generate Rust-compatible models.
     """

-    logger.info(f"🔄 Loading base model: {base_model_path}")
+    logger.info(f"Loading base model: {base_model_path}")

     # Load label mapping to get correct number of labels
     with open(os.path.join(lora_adapter_path, "label_mapping.json"), "r") as f:
@@ -463,17 +592,17 @@
     # Load tokenizer with model-specific configuration
     tokenizer = create_tokenizer_for_model(base_model_path, base_model_path)

-    logger.info(f"🔄 Loading LoRA adapter from: {lora_adapter_path}")
+    logger.info(f"Loading LoRA adapter from: {lora_adapter_path}")

     # Load LoRA model
     lora_model = PeftModel.from_pretrained(base_model, lora_adapter_path)

-    logger.info("🔄 Merging LoRA adapter with base model...")
+    logger.info("Merging LoRA adapter with base model...")

     # Merge and unload LoRA
     merged_model = lora_model.merge_and_unload()

-    logger.info(f"💾 Saving merged model to: {output_path}")
+    logger.info(f"Saving merged model to: {output_path}")

     # Create output directory
     os.makedirs(output_path, exist_ok=True)
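
For context, the merge helper follows the standard PEFT merge-and-save flow; a minimal sketch with hypothetical directory names (the real function additionally rewrites config.json with the label mappings and copies label_mapping.json alongside the weights):

```python
from transformers import AutoModelForSequenceClassification
from peft import PeftModel

# Directory names here are hypothetical; the script derives them from output_dir.
base = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=14)
lora = PeftModel.from_pretrained(base, "lora_intent_classifier")
merged = lora.merge_and_unload()                       # fold LoRA deltas into the base weights
merged.save_pretrained("lora_intent_classifier_rust")  # plain HF checkpoint, no peft needed at load time
```
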
@@ -496,7 +625,7 @@ def merge_lora_adapter_to_full_model(
         json.dump(config, f, indent=2)

     logger.info(
-        "Updated config.json with correct intent classification label mappings"
+        "Updated config.json with correct intent classification label mappings"
     )

     # Copy important files from LoRA adapter
@@ -513,9 +642,9 @@
         shutil.copy(
             os.path.join(output_path, "label_mapping.json"), category_mapping_path
         )
-        logger.info("Created category_mapping.json")
+        logger.info("Created category_mapping.json")

-    logger.info("LoRA adapter merged successfully!")
+    logger.info("LoRA adapter merged successfully!")


 def demo_inference(model_path: str, model_name: str = "modernbert-base"):
@@ -592,16 +721,11 @@ def demo_inference(model_path: str, model_name: str = "modernbert-base"):
     parser.add_argument(
         "--model",
         choices=[
-            "modernbert-base",
-            "modernbert-large",
-            "bert-base-uncased",
-            "bert-large-uncased",
-            "roberta-base",
-            "roberta-large",
-            "deberta-v3-base",
-            "deberta-v3-large",
+            "modernbert-base",  # ModernBERT base model - latest architecture
+            "bert-base-uncased",  # BERT base model - most stable and CPU-friendly
+            "roberta-base",  # RoBERTa base model - best intent classification performance
         ],
-        default="modernbert-base",
+        default="bert-base-uncased",
     )
     parser.add_argument("--lora-rank", type=int, default=8)
     parser.add_argument("--lora-alpha", type=int, default=16)
@@ -610,12 +734,12 @@ def demo_inference(model_path: str, model_name: str = "modernbert-base"):
     parser.add_argument("--alignment-weight", type=float, default=0.1)
     parser.add_argument("--epochs", type=int, default=3)
     parser.add_argument("--batch-size", type=int, default=8)
-    parser.add_argument("--learning-rate", type=float, default=1e-4)
+    parser.add_argument("--learning-rate", type=float, default=3e-5)
     parser.add_argument(
         "--max-samples",
         type=int,
-        default=1000,
-        help="Maximum samples from MMLU-Pro dataset",
+        default=5000,
+        help="Maximum samples from MMLU-Pro dataset (recommended: 5000+ for all 14 categories)",
     )
     parser.add_argument(
         "--output-dir",
