1616 causal language model!
1717
1818Usage:
19- # Train with recommended parameters
20- python ft_qwen3_generative_lora.py --mode train --epochs 8 --lora-rank 16 --max-samples 2000
19+ # Train with recommended parameters (150 samples per category = ~2100 total)
20+ python ft_qwen3_generative_lora.py --mode train --epochs 8 --lora-rank 16 --max-samples-per-category 150
2121
2222 # Test with specific GPU
2323 python ft_qwen3_generative_lora.py --mode train --epochs 8 --gpu-id 2
2424
2525 # Adjust batch size based on GPU memory (default: 4)
2626 python ft_qwen3_generative_lora.py --mode train --batch-size 8 --epochs 5
2727
28- # Quick test
29- python ft_qwen3_generative_lora.py --mode train --epochs 1 --max-samples 50
28+ # Quick test (10 samples per category = ~140 total)
29+ python ft_qwen3_generative_lora.py --mode train --epochs 1 --max-samples-per-category 10
3030
3131 # Inference
3232 python ft_qwen3_generative_lora.py --mode test --model-path qwen3_generative_classifier
@@ -147,8 +147,13 @@ def __init__(self, dataset_name="TIGER-Lab/MMLU-Pro"):
147147 self .label2id = {}
148148 self .id2label = {}
149149
150- def load_huggingface_dataset (self , max_samples = 1000 ):
151- """Load the MMLU-Pro dataset from HuggingFace with balanced sampling."""
150+ def load_huggingface_dataset (self , max_samples_per_category = 150 ):
151+ """Load the MMLU-Pro dataset from HuggingFace with balanced sampling.
152+
153+ Args:
154+ max_samples_per_category: Maximum number of samples to take from each category.
155+ Default: 150 per category (14 categories = ~2100 total)
156+ """
152157 logger .info (f"Loading dataset from HuggingFace: { self .dataset_name } " )
153158
154159 try :
@@ -169,14 +174,12 @@ def load_huggingface_dataset(self, max_samples=1000):
169174
170175 logger .info (f"Available categories: { sorted (category_samples .keys ())} " )
171176
172- # Calculate samples per category for balanced sampling
177+ # Use samples per category directly
173178 available_required_categories = [
174179 cat for cat in REQUIRED_CATEGORIES if cat in category_samples
175180 ]
176181
177- target_samples_per_category = max_samples // len (
178- available_required_categories
179- )
182+ target_samples_per_category = max_samples_per_category
180183
181184 # Collect balanced samples
182185 filtered_texts = []
@@ -202,9 +205,13 @@ def load_huggingface_dataset(self, max_samples=1000):
202205 logger .error (f"Error loading dataset: { e } " )
203206 raise
204207
205- def prepare_datasets (self , max_samples = 1000 ):
206- """Prepare train/validation/test datasets."""
207- texts , labels = self .load_huggingface_dataset (max_samples )
208+ def prepare_datasets (self , max_samples_per_category = 150 ):
209+ """Prepare train/validation/test datasets.
210+
211+ Args:
212+ max_samples_per_category: Maximum samples per category (default: 150)
213+ """
214+ texts , labels = self .load_huggingface_dataset (max_samples_per_category )
208215
209216 # Create label mapping
210217 unique_labels = sorted (list (set (labels )))
@@ -341,11 +348,16 @@ def main(
341348 num_epochs : int = 8 , # More epochs for 0.6B
342349 batch_size : int = 4 , # Configurable batch size (adjust based on GPU memory)
343350 learning_rate : float = 3e-4 , # Higher LR for small model
344- max_samples : int = 2000 ,
351+ max_samples_per_category : int = 150 , # Samples per category for balanced dataset
345352 output_dir : str = None ,
346353 gpu_id : Optional [int ] = None ,
347354):
348- """Main training function for generative Qwen3 classification."""
355+ """Main training function for generative Qwen3 classification.
356+
357+ Args:
358+ max_samples_per_category: Maximum samples per category (default: 150).
359+ With 14 categories, this gives ~2100 total samples.
360+ """
349361 logger .info ("Starting Qwen3 Generative Classification Fine-tuning" )
350362 logger .info ("Training Qwen3 to GENERATE category labels (instruction-following)" )
351363
@@ -360,7 +372,7 @@ def main(
360372
361373 # Load dataset
362374 dataset_loader = MMLU_Dataset ()
363- datasets = dataset_loader .prepare_datasets (max_samples )
375+ datasets = dataset_loader .prepare_datasets (max_samples_per_category )
364376
365377 train_texts , train_labels = datasets ["train" ]
366378 val_texts , val_labels = datasets ["validation" ]
@@ -706,7 +718,10 @@ def demo_inference(model_path: str, model_name: str = "Qwen/Qwen3-0.6B"):
706718 "--learning-rate" , type = float , default = 3e-4 , help = "Learning rate"
707719 )
708720 parser .add_argument (
709- "--max-samples" , type = int , default = 2000 , help = "Maximum training samples"
721+ "--max-samples-per-category" ,
722+ type = int ,
723+ default = 150 ,
724+ help = "Maximum samples per category for balanced training (default: 150 per category = ~2100 total with 14 categories)" ,
710725 )
711726 parser .add_argument ("--output-dir" , type = str , default = None )
712727 parser .add_argument ("--gpu-id" , type = int , default = None )
@@ -731,7 +746,7 @@ def demo_inference(model_path: str, model_name: str = "Qwen/Qwen3-0.6B"):
731746 num_epochs = args .epochs ,
732747 batch_size = args .batch_size ,
733748 learning_rate = args .learning_rate ,
734- max_samples = args .max_samples ,
749+ max_samples_per_category = args .max_samples_per_category ,
735750 output_dir = args .output_dir ,
736751 gpu_id = args .gpu_id ,
737752 )