|
73 | 73 | clear_gpu_memory, |
74 | 74 | get_device_info, |
75 | 75 | log_memory_usage, |
| 76 | + set_gpu_device, |
76 | 77 | setup_logging, |
77 | 78 | ) |
78 | 79 |
|
@@ -340,10 +341,11 @@ def main( |
340 | 341 | logger.info("Starting Qwen3 Generative Classification Fine-tuning") |
341 | 342 | logger.info("Training Qwen3 to GENERATE category labels (instruction-following)") |
342 | 343 |
|
343 | | - # GPU selection |
344 | | - if gpu_id is not None: |
345 | | - os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) |
346 | | - logger.info(f"Set CUDA_VISIBLE_DEVICES={gpu_id}") |
| 344 | + # GPU selection using utility function |
| 345 | + device_str, selected_gpu = set_gpu_device( |
| 346 | + gpu_id=gpu_id, auto_select=(gpu_id is None) |
| 347 | + ) |
| 348 | + logger.info(f"Using device: {device_str} (GPU {selected_gpu})") |
347 | 349 |
|
348 | 350 | clear_gpu_memory() |
349 | 351 | log_memory_usage("Pre-training") |
@@ -375,9 +377,8 @@ def main( |
375 | 377 | low_cpu_mem_usage=True, |
376 | 378 | ) |
377 | 379 |
|
378 | | - # Move to GPU |
379 | | - device = "cuda" if torch.cuda.is_available() else "cpu" |
380 | | - model = model.to(device) |
| 380 | + # Move to GPU using device from set_gpu_device utility |
| 381 | + model = model.to(device_str) |
381 | 382 |
|
382 | 383 | # Prepare model for training |
383 | 384 | model.config.use_cache = False # Required for training |
@@ -428,9 +429,9 @@ def main( |
428 | 429 | training_args = TrainingArguments( |
429 | 430 | output_dir=output_dir, |
430 | 431 | num_train_epochs=num_epochs, |
431 | | - per_device_train_batch_size=1, # Minimal batch size to fit in memory |
432 | | - per_device_eval_batch_size=1, |
433 | | - gradient_accumulation_steps=16, # Effective batch size = 1 * 16 = 16 |
| 432 | + per_device_train_batch_size=4, # Increased batch size for better GPU utilization |
| 433 | + per_device_eval_batch_size=4, |
| 434 | + gradient_accumulation_steps=4, # Effective batch size = 4 * 4 = 16 |
434 | 435 | learning_rate=learning_rate, |
435 | 436 | weight_decay=0.01, |
436 | 437 | logging_dir=f"{output_dir}/logs", |
@@ -684,10 +685,8 @@ def demo_inference(model_path: str, model_name: str = "Qwen/Qwen3-0.6B"): |
684 | 685 |
|
685 | 686 | args = parser.parse_args() |
686 | 687 |
|
687 | | - # Set CUDA device early |
688 | | - if args.gpu_id is not None: |
689 | | - os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id) |
690 | | - print(f"INFO: Set CUDA_VISIBLE_DEVICES={args.gpu_id}") |
| 688 | + # GPU device selection is handled in main() and demo_inference() functions |
| 689 | + # using the set_gpu_device() utility function for consistency |
691 | 690 |
|
692 | 691 | if args.mode == "train": |
693 | 692 | main( |
|
0 commit comments