Commit 06f717a

review feedback

Signed-off-by: Huamin Chen <[email protected]>

1 parent effbaa3

File tree

1 file changed: +33 -18 lines
src/training/training_lora/classifier_model_fine_tuning_lora/ft_qwen3_generative_lora.py

Lines changed: 33 additions & 18 deletions
@@ -16,17 +16,17 @@
 causal language model!

 Usage:
-    # Train with recommended parameters
-    python ft_qwen3_generative_lora.py --mode train --epochs 8 --lora-rank 16 --max-samples 2000
+    # Train with recommended parameters (150 samples per category = ~2100 total)
+    python ft_qwen3_generative_lora.py --mode train --epochs 8 --lora-rank 16 --max-samples-per-category 150

     # Test with specific GPU
     python ft_qwen3_generative_lora.py --mode train --epochs 8 --gpu-id 2

     # Adjust batch size based on GPU memory (default: 4)
     python ft_qwen3_generative_lora.py --mode train --batch-size 8 --epochs 5

-    # Quick test
-    python ft_qwen3_generative_lora.py --mode train --epochs 1 --max-samples 50
+    # Quick test (10 samples per category = ~140 total)
+    python ft_qwen3_generative_lora.py --mode train --epochs 1 --max-samples-per-category 10

     # Inference
     python ft_qwen3_generative_lora.py --mode test --model-path qwen3_generative_classifier
@@ -147,8 +147,13 @@ def __init__(self, dataset_name="TIGER-Lab/MMLU-Pro"):
         self.label2id = {}
         self.id2label = {}

-    def load_huggingface_dataset(self, max_samples=1000):
-        """Load the MMLU-Pro dataset from HuggingFace with balanced sampling."""
+    def load_huggingface_dataset(self, max_samples_per_category=150):
+        """Load the MMLU-Pro dataset from HuggingFace with balanced sampling.
+
+        Args:
+            max_samples_per_category: Maximum number of samples to take from each category.
+                Default: 150 per category (14 categories = ~2100 total)
+        """
         logger.info(f"Loading dataset from HuggingFace: {self.dataset_name}")

         try:
@@ -169,14 +174,12 @@ def load_huggingface_dataset(self, max_samples=1000):

             logger.info(f"Available categories: {sorted(category_samples.keys())}")

-            # Calculate samples per category for balanced sampling
+            # Use samples per category directly
             available_required_categories = [
                 cat for cat in REQUIRED_CATEGORIES if cat in category_samples
             ]

-            target_samples_per_category = max_samples // len(
-                available_required_categories
-            )
+            target_samples_per_category = max_samples_per_category

             # Collect balanced samples
             filtered_texts = []
@@ -202,9 +205,13 @@ def load_huggingface_dataset(self, max_samples=1000):
             logger.error(f"Error loading dataset: {e}")
             raise

-    def prepare_datasets(self, max_samples=1000):
-        """Prepare train/validation/test datasets."""
-        texts, labels = self.load_huggingface_dataset(max_samples)
+    def prepare_datasets(self, max_samples_per_category=150):
+        """Prepare train/validation/test datasets.
+
+        Args:
+            max_samples_per_category: Maximum samples per category (default: 150)
+        """
+        texts, labels = self.load_huggingface_dataset(max_samples_per_category)

         # Create label mapping
         unique_labels = sorted(list(set(labels)))
@@ -341,11 +348,16 @@ def main(
     num_epochs: int = 8,  # More epochs for 0.6B
     batch_size: int = 4,  # Configurable batch size (adjust based on GPU memory)
     learning_rate: float = 3e-4,  # Higher LR for small model
-    max_samples: int = 2000,
+    max_samples_per_category: int = 150,  # Samples per category for balanced dataset
     output_dir: str = None,
     gpu_id: Optional[int] = None,
 ):
-    """Main training function for generative Qwen3 classification."""
+    """Main training function for generative Qwen3 classification.
+
+    Args:
+        max_samples_per_category: Maximum samples per category (default: 150).
+            With 14 categories, this gives ~2100 total samples.
+    """
     logger.info("Starting Qwen3 Generative Classification Fine-tuning")
     logger.info("Training Qwen3 to GENERATE category labels (instruction-following)")

@@ -360,7 +372,7 @@ def main(

     # Load dataset
     dataset_loader = MMLU_Dataset()
-    datasets = dataset_loader.prepare_datasets(max_samples)
+    datasets = dataset_loader.prepare_datasets(max_samples_per_category)

     train_texts, train_labels = datasets["train"]
     val_texts, val_labels = datasets["validation"]
@@ -706,7 +718,10 @@ def demo_inference(model_path: str, model_name: str = "Qwen/Qwen3-0.6B"):
         "--learning-rate", type=float, default=3e-4, help="Learning rate"
     )
     parser.add_argument(
-        "--max-samples", type=int, default=2000, help="Maximum training samples"
+        "--max-samples-per-category",
+        type=int,
+        default=150,
+        help="Maximum samples per category for balanced training (default: 150 per category = ~2100 total with 14 categories)",
     )
     parser.add_argument("--output-dir", type=str, default=None)
     parser.add_argument("--gpu-id", type=int, default=None)
@@ -731,7 +746,7 @@ def demo_inference(model_path: str, model_name: str = "Qwen/Qwen3-0.6B"):
         num_epochs=args.epochs,
         batch_size=args.batch_size,
         learning_rate=args.learning_rate,
-        max_samples=args.max_samples,
+        max_samples_per_category=args.max_samples_per_category,
         output_dir=args.output_dir,
         gpu_id=args.gpu_id,
     )
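
Note on the semantic change: before this commit, --max-samples was a global budget split evenly across categories (target_samples_per_category = max_samples // len(available_required_categories), i.e. 2000 // 14 = ~142 each); after it, --max-samples-per-category caps each category directly (150 each, ~2100 total with 14 categories). The "Collect balanced samples" loop itself is elided from the hunk above, so the following is only a minimal sketch of that behavior under assumed names, not the file's actual code; the record fields ("category", "question") and the short REQUIRED_CATEGORIES stand-in are illustrative assumptions.

from collections import defaultdict

# Hypothetical stand-in; the real script's REQUIRED_CATEGORIES covers
# the 14 MMLU-Pro categories.
REQUIRED_CATEGORIES = ["math", "physics", "law"]

def balanced_sample(records, max_samples_per_category=150):
    """Take up to max_samples_per_category records from each required category."""
    # Bucket records by category.
    category_samples = defaultdict(list)
    for rec in records:
        category_samples[rec["category"]].append(rec)  # field name assumed

    texts, labels = [], []
    for cat in REQUIRED_CATEGORIES:
        if cat not in category_samples:
            continue  # mirrors the available_required_categories filtering
        # Direct per-category cap, as introduced by this commit.
        for rec in category_samples[cat][:max_samples_per_category]:
            texts.append(rec["question"])  # field name assumed
            labels.append(cat)
    return texts, labels

With this shape, shrinking or growing the dataset (e.g. --max-samples-per-category 10 for a quick test) keeps the class balance exact instead of depending on integer division of a global budget.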
