
Commit 17c38ae

fix: LoRA Model Training Configuration and Data Balance
Signed-off-by: OneZero-Y <[email protected]>
1 parent cb3d2d5 commit 17c38ae

File tree

10 files changed: +621 −76 lines changed


src/training/training_lora/OWNER

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+# lora training owners
+@OneZero-Y

src/training/training_lora/classifier_model_fine_tuning_lora/ft_linear_lora.py

Lines changed: 10 additions & 4 deletions
@@ -444,7 +444,7 @@ def main(
     lora_dropout: float = 0.1,
     num_epochs: int = 3,
     batch_size: int = 8,
-    learning_rate: float = 1e-4,
+    learning_rate: float = 3e-5,  # Reduced from 1e-4 to prevent gradient explosion
     max_samples: int = 1000,
     output_dir: str = None,
     enable_feature_alignment: bool = False,
@@ -493,13 +493,12 @@ def main(

     logger.info(f"Model will be saved to: {output_dir}")

-    # Training arguments
+    # Training arguments optimized for LoRA sequence classification based on PEFT best practices
     training_args = TrainingArguments(
         output_dir=output_dir,
         num_train_epochs=num_epochs,
         per_device_train_batch_size=batch_size,
         per_device_eval_batch_size=batch_size,
-        warmup_steps=100,
         weight_decay=0.01,
         logging_dir=f"{output_dir}/logs",
         logging_steps=10,
@@ -509,6 +508,13 @@ def main(
         metric_for_best_model="eval_f1",
         greater_is_better=True,
         learning_rate=learning_rate,
+        # PEFT optimization: Enhanced stability measures
+        max_grad_norm=1.0,  # Gradient clipping to prevent explosion
+        lr_scheduler_type="cosine",  # More stable learning rate schedule for LoRA
+        warmup_ratio=0.06,  # PEFT recommended warmup ratio for sequence classification
+        # Additional stability measures for intent classification
+        dataloader_drop_last=False,
+        eval_accumulation_steps=1,
     )

     # Create trainer
@@ -728,7 +734,7 @@ def demo_inference(model_path: str, model_name: str = "modernbert-base"):
     parser.add_argument("--alignment-weight", type=float, default=0.1)
     parser.add_argument("--epochs", type=int, default=3)
     parser.add_argument("--batch-size", type=int, default=8)
-    parser.add_argument("--learning-rate", type=float, default=1e-4)
+    parser.add_argument("--learning-rate", type=float, default=3e-5)
     parser.add_argument(
         "--max-samples",
         type=int,
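
Taken together, the hunks above change three things in the intent-classifier trainer: a lower learning rate (3e-5), a relative warmup (warmup_ratio=0.06 instead of warmup_steps=100), and explicit stability settings (gradient clipping, cosine schedule). The block below is a minimal sketch that assembles the resulting TrainingArguments in one place for reference; the output_dir value is a hypothetical placeholder, everything else mirrors the diff.

# Minimal sketch: the consolidated TrainingArguments after this commit.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="lora_intent_classifier_model",  # hypothetical placeholder path
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    logging_steps=10,
    metric_for_best_model="eval_f1",
    greater_is_better=True,
    learning_rate=3e-5,          # reduced from 1e-4 to avoid gradient spikes
    max_grad_norm=1.0,           # gradient clipping
    lr_scheduler_type="cosine",  # smoother decay than the default linear schedule
    warmup_ratio=0.06,           # replaces the fixed warmup_steps=100
    dataloader_drop_last=False,
    eval_accumulation_steps=1,
)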

src/training/training_lora/classifier_model_fine_tuning_lora/ft_linear_lora_verifier.go

Lines changed: 174 additions & 15 deletions
@@ -8,6 +8,7 @@ import (
     "log"
     "os"
     "path/filepath"
+    "strings"

     candle "github.com/vllm-project/semantic-router/candle-binding"
 )
@@ -209,34 +210,192 @@ func main() {
         log.Fatalf("Failed to initialize LoRA classifier: %v", err)
     }

-    // Test samples for intent classification (matching Python demo_inference)
-    testSamples := []string{
-        "What is the best strategy for corporate mergers and acquisitions?",
-        "How do antitrust laws affect business competition?",
-        "What are the psychological factors that influence consumer behavior?",
-        "Explain the legal requirements for contract formation",
-        "What is the difference between civil and criminal law?",
-        "How does cognitive bias affect decision making?",
+    // Test samples with expected intent categories for validation
+    testSamples := []struct {
+        text        string
+        description string
+        expected    string
+    }{
+        {
+            "What is the best strategy for corporate mergers and acquisitions?",
+            "Business strategy question",
+            "business",
+        },
+        {
+            "How do antitrust laws affect business competition?",
+            "Business law question",
+            "business",
+        },
+        {
+            "What are the psychological factors that influence consumer behavior?",
+            "Psychology and behavior question",
+            "psychology",
+        },
+        {
+            "Explain the legal requirements for contract formation",
+            "Legal concepts question",
+            "jurisprudence",
+        },
+        {
+            "What is the difference between civil and criminal law?",
+            "Legal system question",
+            "jurisprudence",
+        },
+        {
+            "How does cognitive bias affect decision making?",
+            "Psychology and cognition question",
+            "psychology",
+        },
+        {
+            "What is the derivative of e^x?",
+            "Mathematical calculus question",
+            "mathematics",
+        },
+        {
+            "Explain the concept of supply and demand in economics.",
+            "Economic principles question",
+            "economics",
+        },
+        {
+            "How does DNA replication work in eukaryotic cells?",
+            "Biology and genetics question",
+            "biology",
+        },
+        {
+            "What is the difference between a civil law and common law system?",
+            "Legal systems comparison",
+            "jurisprudence",
+        },
+        {
+            "Explain how transistors work in computer processors.",
+            "Computer engineering question",
+            "computer_science",
+        },
+        {
+            "Why do stars twinkle?",
+            "Astronomical physics question",
+            "physics",
+        },
+        {
+            "How do I create a balanced portfolio for retirement?",
+            "Financial planning question",
+            "economics",
+        },
+        {
+            "What causes mental illnesses?",
+            "Mental health and psychology question",
+            "psychology",
+        },
+        {
+            "How do computer algorithms work?",
+            "Computer science fundamentals",
+            "computer_science",
+        },
+        {
+            "Explain the historical significance of the Roman Empire.",
+            "Historical analysis question",
+            "history",
+        },
+        {
+            "What is the derivative of f(x) = x^3 + 2x^2 - 5x + 7?",
+            "Calculus problem",
+            "mathematics",
+        },
+        {
+            "Describe the process of photosynthesis in plants.",
+            "Biological processes question",
+            "biology",
+        },
+        {
+            "What are the principles of macroeconomic policy?",
+            "Economic policy question",
+            "economics",
+        },
+        {
+            "How does machine learning classification work?",
+            "Machine learning concepts",
+            "computer_science",
+        },
+        {
+            "What is the capital of France?",
+            "General knowledge question",
+            "other",
+        },
     }

     fmt.Println("\nTesting LoRA Intent Classification:")
     fmt.Println("===================================")

-    for i, sample := range testSamples {
-        fmt.Printf("\nTest %d: %s\n", i+1, sample)
+    // Statistics tracking
+    var (
+        totalTests     = len(testSamples)
+        correctTests   = 0
+        highConfidence = 0
+        lowConfidence  = 0
+    )
+
+    for i, test := range testSamples {
+        fmt.Printf("\nTest %d: %s\n", i+1, test.description)
+        fmt.Printf(" Text: \"%s\"\n", test.text)

-        result, err := classifyIntentText(sample, config)
+        result, err := classifyIntentText(test.text, config)
         if err != nil {
-            fmt.Printf("Error: %v\n", err)
+            fmt.Printf(" Classification failed: %v\n", err)
             continue
         }

+        // Get the predicted label name
+        labelName := "unknown"
         if label, exists := categoryLabels[result.Class]; exists {
-            fmt.Printf("Classification: %s (Class ID: %d, Confidence: %.4f)\n", label, result.Class, result.Confidence)
+            labelName = label
+        }
+
+        // Print the result
+        fmt.Printf(" Classified as: %s (Class ID: %d, Confidence: %.4f)\n",
+            labelName, result.Class, result.Confidence)
+
+        // Check correctness
+        isCorrect := labelName == test.expected
+        if isCorrect {
+            fmt.Printf(" ✓ CORRECT")
+            correctTests++
         } else {
-            fmt.Printf("Unknown category index: %d (Confidence: %.4f)\n", result.Class, result.Confidence)
+            fmt.Printf(" ✗ INCORRECT (Expected: %s)", test.expected)
         }
+
+        // Add confidence assessment
+        if result.Confidence > 0.7 {
+            fmt.Printf(" - HIGH CONFIDENCE\n")
+            highConfidence++
+        } else if result.Confidence > 0.5 {
+            fmt.Printf(" - MEDIUM CONFIDENCE\n")
+        } else {
+            fmt.Printf(" - LOW CONFIDENCE\n")
+            lowConfidence++
+        }
+    }
+
+    // Print comprehensive summary
+    fmt.Println("\n" + strings.Repeat("=", 50))
+    fmt.Println("INTENT CLASSIFICATION TEST SUMMARY")
+    fmt.Println(strings.Repeat("=", 50))
+    fmt.Printf("Total Tests: %d\n", totalTests)
+    fmt.Printf("Correct Predictions: %d/%d (%.1f%%)\n", correctTests, totalTests, float64(correctTests)/float64(totalTests)*100)
+    fmt.Printf("High Confidence (>0.7): %d/%d (%.1f%%)\n", highConfidence, totalTests, float64(highConfidence)/float64(totalTests)*100)
+    fmt.Printf("Low Confidence (<0.5): %d/%d (%.1f%%)\n", lowConfidence, totalTests, float64(lowConfidence)/float64(totalTests)*100)
+
+    // Overall assessment
+    accuracy := float64(correctTests) / float64(totalTests) * 100
+    fmt.Printf("\nOVERALL ASSESSMENT: ")
+    if accuracy >= 85.0 {
+        fmt.Printf("EXCELLENT (%.1f%% accuracy)\n", accuracy)
+    } else if accuracy >= 70.0 {
+        fmt.Printf("GOOD (%.1f%% accuracy)\n", accuracy)
+    } else if accuracy >= 50.0 {
+        fmt.Printf("FAIR (%.1f%% accuracy) - Consider retraining\n", accuracy)
+    } else {
+        fmt.Printf("POOR (%.1f%% accuracy) - Requires retraining\n", accuracy)
     }

-    fmt.Println("\nLoRA Intent Classification test completed!")
+    fmt.Println("\nLoRA Intent Classification verification completed!")
 }
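
The verifier above grades each prediction against an expected category, buckets confidences (>0.7 high, >0.5 medium, otherwise low), and grades overall accuracy (≥85% excellent, ≥70% good, ≥50% fair, else poor). For readers following the Python training side, here is an illustrative Python mirror of that scoring logic; it is not part of the commit, and the classify callable is a hypothetical placeholder for whatever inference entry point returns a (label, confidence) pair.

# Illustrative mirror of the Go verifier's scoring (not from the commit).
def summarize(samples, classify):
    """samples: list of (text, expected_label); classify: hypothetical callable."""
    correct = high_conf = low_conf = 0
    for text, expected in samples:
        label, confidence = classify(text)
        correct += label == expected
        if confidence > 0.7:
            high_conf += 1
        elif confidence <= 0.5:
            low_conf += 1
    total = len(samples)
    accuracy = 100.0 * correct / total
    grade = ("EXCELLENT" if accuracy >= 85 else
             "GOOD" if accuracy >= 70 else
             "FAIR - Consider retraining" if accuracy >= 50 else
             "POOR - Requires retraining")
    print(f"Correct: {correct}/{total} ({accuracy:.1f}%) - {grade}")
    print(f"High confidence (>0.7): {high_conf}, Low confidence (<0.5): {low_conf}")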

src/training/training_lora/classifier_model_fine_tuning_lora/train_cpu_optimized.sh

Lines changed: 5 additions & 3 deletions
@@ -17,7 +17,7 @@ LORA_RANK=8 # Optimal rank for stability and performance
 LORA_ALPHA=16 # Standard alpha (2x rank) for best results
 MAX_SAMPLES=7000 # Increased samples for better coverage of 14 categories
 BATCH_SIZE=2 # Small batch size for CPU
-LEARNING_RATE=1e-4 # Lower learning rate for more stable training
+LEARNING_RATE=3e-5 # Optimized learning rate based on PEFT best practices

 # CPU-friendly model set (smaller models only)
 CPU_MODELS=(
@@ -131,12 +131,14 @@ train_cpu_model() {

     # CPU-optimized training command
     local cmd="python ft_linear_lora.py \
+        --mode train \
         --model $model_name \
         --epochs $EPOCHS \
-        --max-samples $MAX_SAMPLES \
         --lora-rank $LORA_RANK \
+        --lora-alpha $LORA_ALPHA \
+        --max-samples $MAX_SAMPLES \
         --batch-size $BATCH_SIZE \
-        --output-dir lora_intent_classifier_${model_name}_r${LORA_RANK}_model"
+        --learning-rate $LEARNING_RATE"

     echo "📝 Command: $cmd"
     echo "📋 Log file: $log_file"

src/training/training_lora/pii_model_fine_tuning_lora/pii_bert_finetuning_lora.py

Lines changed: 17 additions & 10 deletions
@@ -423,7 +423,7 @@ def validate_bio_labels(texts, token_labels):

     # Show entity statistics
     if entity_stats:
-        logger.info(f"📈 Entity Statistics:")
+        logger.info(f"Entity Statistics:")
         for entity_type, stats in sorted(
             entity_stats.items(), key=lambda x: x[1]["count"], reverse=True
         )[:5]:
@@ -432,9 +432,9 @@ def validate_bio_labels(texts, token_labels):
         )

     if bio_violations > 0:
-        logger.warning(f"⚠️ Found {bio_violations} BIO labeling violations!")
+        logger.warning(f"Found {bio_violations} BIO labeling violations!")
     else:
-        logger.info("All BIO labels are consistent!")
+        logger.info("All BIO labels are consistent!")

     return {
         "total_samples": total_samples,
@@ -447,16 +447,16 @@ def validate_bio_labels(texts, token_labels):

 def analyze_data_quality(texts, token_labels, sample_size=5):
     """Analyze and display data quality with sample examples."""
-    logger.info(f"🔍 Data Quality Analysis:")
+    logger.info(f"Data Quality Analysis:")

     # Show sample examples with their labels
-    logger.info(f"📝 Sample Examples (showing first {sample_size}):")
+    logger.info(f"Sample Examples (showing first {sample_size}):")
     for i in range(min(sample_size, len(texts))):
         tokens = texts[i]
         labels = token_labels[i]

-        logger.info(f" Sample {i+1}:")
-        logger.info(f" Text: {' '.join(tokens)}")
+        logger.info(f"Sample {i+1}:")
+        logger.info(f"Text: {' '.join(tokens)}")

         # Show only non-O labels for clarity
         entities = []
@@ -633,7 +633,7 @@ def main(
     lora_dropout: float = 0.1,
     num_epochs: int = 3,
     batch_size: int = 8,
-    learning_rate: float = 1e-4,
+    learning_rate: float = 3e-5,  # Optimized for LoRA based on PEFT best practices
     max_samples: int = 1000,
 ):
     """Main training function for LoRA PII detection."""
@@ -682,13 +682,17 @@ def main(
     os.makedirs(output_dir, exist_ok=True)

     # Training arguments
+    # Training arguments optimized for LoRA token classification based on PEFT best practices
     training_args = TrainingArguments(
         output_dir=output_dir,
         num_train_epochs=num_epochs,
         per_device_train_batch_size=batch_size,
         per_device_eval_batch_size=batch_size,
         learning_rate=learning_rate,
-        warmup_steps=50,
+        # PEFT optimization: Enhanced stability measures
+        max_grad_norm=1.0,  # Gradient clipping to prevent explosion
+        lr_scheduler_type="cosine",  # More stable learning rate schedule for LoRA
+        warmup_ratio=0.06,  # PEFT recommended warmup ratio for token classification
         weight_decay=0.01,
         logging_dir=f"{output_dir}/logs",
         logging_steps=10,
@@ -697,6 +701,9 @@ def main(
         load_best_model_at_end=True,
         metric_for_best_model="f1",
         save_total_limit=2,
+        # Additional stability measures
+        dataloader_drop_last=False,
+        eval_accumulation_steps=1,
         report_to=[],
         fp16=torch.cuda.is_available(),
     )
@@ -968,7 +975,7 @@ def demo_inference(
         lora_dropout=args.lora_dropout,
         num_epochs=args.epochs,
         batch_size=args.batch_size,
-        learning_rate=args.learning_rate,
+        learning_rate=3e-5,  # Default optimized learning rate for LoRA token classification
         max_samples=args.max_samples,
     )
 elif args.mode == "test":
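
The PII hunks above touch only the Trainer configuration (3e-5 learning rate, gradient clipping, cosine schedule with warmup_ratio=0.06). For context, below is a minimal sketch of the LoRA token-classification setup that such arguments would drive; it is not taken from the commit. The base model name, num_labels, and target_modules are assumptions for illustration, while lora_dropout=0.1 matches the main() default shown in the diff and the rank/alpha mirror the values used in the companion training script.

# Minimal sketch (assumptions noted): LoRA adapters on a token-classification model.
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForTokenClassification

base = AutoModelForTokenClassification.from_pretrained(
    "bert-base-uncased",  # hypothetical base model
    num_labels=9,         # depends on the BIO label set used for PII
)
lora_config = LoraConfig(
    task_type=TaskType.TOKEN_CLS,
    r=8,                  # rank, mirroring LORA_RANK in the training script
    lora_alpha=16,        # mirroring LORA_ALPHA (2x rank)
    lora_dropout=0.1,     # matches the main() default in the diff
    target_modules=["query", "value"],  # assumed attention projections
)
model = get_peft_model(base, lora_config)
model.print_trainable_parameters()  # only the LoRA adapters are trainable

With this setup, the reduced 3e-5 learning rate and cosine warmup apply only to the small set of adapter weights, which is why the commit can tighten the schedule without slowing overall convergence much.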
