@@ -64,7 +64,7 @@
 import shutil
 import sys
 from pathlib import Path
-from typing import Dict, List
+from typing import Dict, List, Optional

 import torch
 import torch.nn as nn
@@ -90,9 +90,12 @@
 from common_lora_utils import (
     clear_gpu_memory,
     create_lora_config,
+    find_free_gpu,
+    get_all_gpu_info,
     get_device_info,
     log_memory_usage,
     resolve_model_path,
+    set_gpu_device,
     setup_logging,
     validate_lora_config,
 )
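
The three helpers added to this import list (find_free_gpu, get_all_gpu_info, set_gpu_device) live in common_lora_utils, whose implementation is not part of this file's diff. A minimal sketch of what they plausibly look like, inferred from the call sites below and assuming standard torch.cuda APIs (the dict keys match the logging code; the GiB divisor is a guess):

from typing import Dict, List, Optional, Tuple

import torch


def get_all_gpu_info() -> List[Dict]:
    """Return id, name, and free/total memory (GB) for every visible GPU."""
    gpus = []
    for i in range(torch.cuda.device_count()):
        # mem_get_info returns (free_bytes, total_bytes) for the device
        free_bytes, total_bytes = torch.cuda.mem_get_info(i)
        gpus.append({
            "id": i,
            "name": torch.cuda.get_device_name(i),
            "free_memory_gb": free_bytes / 1024**3,
            "total_memory_gb": total_bytes / 1024**3,
        })
    return gpus


def find_free_gpu() -> Optional[int]:
    """Return the id of the GPU with the most free memory, or None without CUDA."""
    gpus = get_all_gpu_info()
    if not gpus:
        return None
    return max(gpus, key=lambda g: g["free_memory_gb"])["id"]


def set_gpu_device(gpu_id: Optional[int], auto_select: bool) -> Tuple[str, Optional[int]]:
    """Bind the process to one GPU and return (device string, selected id)."""
    # Sketch only: the real common_lora_utils implementation is not shown in this diff.
    if not torch.cuda.is_available():
        return "cpu", None
    selected = find_free_gpu() if auto_select else gpu_id
    torch.cuda.set_device(selected)  # subsequent CUDA allocations default to this GPU
    return f"cuda:{selected}", selected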
@@ -449,12 +452,31 @@ def main(
     output_dir: str = None,
     enable_feature_alignment: bool = False,
     alignment_weight: float = 0.1,
+    gpu_id: Optional[int] = None,
 ):
     """Main training function for LoRA intent classification."""
     logger.info("Starting Enhanced LoRA Intent Classification Training")

-    # Device configuration and memory management
-    device, device_info = get_device_info()
+    # GPU selection and device configuration
+    if gpu_id is not None:
+        logger.info(f"Using specified GPU: {gpu_id}")
+        device_str, selected_gpu = set_gpu_device(gpu_id=gpu_id, auto_select=False)
+    else:
+        logger.info("Auto-selecting best available GPU...")
+        device_str, selected_gpu = set_gpu_device(gpu_id=None, auto_select=True)
+
+    # Log all GPU info
+    all_gpus = get_all_gpu_info()
+    if all_gpus:
+        logger.info(f"Available GPUs: {len(all_gpus)}")
+        for gpu in all_gpus:
+            status = "SELECTED" if gpu["id"] == selected_gpu else "available"
+            logger.info(
+                f"  GPU {gpu['id']} ({status}): {gpu['name']} - "
+                f"{gpu['free_memory_gb']:.2f}GB free / {gpu['total_memory_gb']:.2f}GB total"
+            )
+
+    # Clear memory on selected device
     clear_gpu_memory()
     log_memory_usage("Pre-training")

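On a multi-GPU box, the logging block above produces output of this shape, following its f-string format exactly (device names and memory figures here are illustrative, not taken from the PR):

Available GPUs: 4
  GPU 0 (available): NVIDIA L4 - 10.52GB free / 23.99GB total
  GPU 1 (SELECTED): NVIDIA L4 - 22.30GB free / 23.99GB total
  GPU 2 (available): NVIDIA L4 - 15.84GB free / 23.99GB total
  GPU 3 (available): NVIDIA L4 - 3.17GB free / 23.99GB total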
@@ -753,6 +775,12 @@ def demo_inference(model_path: str, model_name: str = "modernbert-base"):
         default="lora_intent_classifier_modernbert-base_r8",
         help="Path to saved model for inference (default: lora_intent_classifier_modernbert-base_r8)",
     )
+    parser.add_argument(
+        "--gpu-id",
+        type=int,
+        default=None,
+        help="Specific GPU ID to use (e.g. 0-3 on a 4-GPU machine). If omitted, the GPU with the most free memory is selected automatically",
+    )

     args = parser.parse_args()

@@ -769,6 +797,7 @@ def demo_inference(model_path: str, model_name: str = "modernbert-base"):
         enable_feature_alignment=args.enable_feature_alignment,
         alignment_weight=args.alignment_weight,
         output_dir=args.output_dir,
+        gpu_id=args.gpu_id,
     )
 elif args.mode == "test":
     demo_inference(args.model_path, args.model)
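
Taken together, the new flag flows from argparse into main(). Typical invocations look like the following (the script filename is a placeholder; this diff does not show it):

# Pin training to GPU 2
python ft_lora_training.py --mode train --gpu-id 2

# Omit --gpu-id to auto-select the GPU with the most free memory
python ft_lora_training.py --mode train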