@@ -1683,11 +1683,27 @@ def apply_chat_template(example):
 
     # Delete trainer and model to free GPU memory for evaluation
     logger.info("🧹 Cleaning up training resources to free GPU memory...")
-    if "trainer" in locals():
+    try:
         del trainer
-    if "model" in locals():
+        logger.info("  ✓ Trainer deleted")
+    except NameError:
+        pass
+    try:
         del model
+        logger.info("  ✓ Model deleted")
+    except NameError:
+        pass
+
+    # Force garbage collection and GPU memory cleanup
+    import gc
+
+    gc.collect()
     clear_gpu_memory()
+
+    # Give CUDA a moment to release memory
+    import time
+
+    time.sleep(2)
     logger.info("✓ GPU memory cleared for evaluation\n")
 else:
     logger.info(
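The hunk above swaps the `locals()` checks for `try`/`except` deletion and forces a `gc.collect()` pass before the CUDA cache is cleared. `clear_gpu_memory()` itself is defined elsewhere in the script; for reference, a minimal sketch of what such a helper typically does (the body below is an assumption, not the script's actual implementation):

```python
import gc

import torch


def clear_gpu_memory() -> None:
    """Free cached CUDA memory after large objects have been deleted.

    Sketch only -- the script's real helper may differ. Deleting the
    trainer/model drops the last Python references; collecting garbage then
    destroys the tensors, and empty_cache() hands the cached blocks back to
    the driver so the evaluation model can use them.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for in-flight kernels to finish
        torch.cuda.empty_cache()  # release unused cached blocks to the driver
```

The ordering matters: PyTorch's caching allocator can only return memory after Python has dropped the last references to the tensors, which is why the `del` statements and `gc.collect()` come before the cache is cleared.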
@@ -1703,7 +1719,9 @@ def apply_chat_template(example):
     model_name,
     trust_remote_code=True,
     low_cpu_mem_usage=True,
-).to(eval_device)
+    torch_dtype=torch.bfloat16,  # Load in BF16 to save memory
+    device_map=eval_device,  # Directly load to device instead of .to()
+)
 base_model_for_baseline.eval()
 
 logger.info("\n" + "🔍" * 40)
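Both loading changes in this commit follow the same pattern: pass `torch_dtype` and `device_map` to `from_pretrained()` instead of loading in float32 and moving the model with `.to()`. A self-contained sketch of that pattern (the model name and device below are placeholders, not values from the script):

```python
import torch
from transformers import AutoModelForCausalLM

# Placeholder values -- the script defines model_name and eval_device elsewhere.
model_name = "Qwen/Qwen2.5-0.5B"
eval_device = "cuda:0"

# device_map places the weights on the target device as they are loaded, so a
# full float32 copy is never materialized on CPU and then moved with .to();
# torch_dtype=torch.bfloat16 halves the footprint relative to float32.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
    device_map=eval_device,
)
model.eval()
```

Note that passing `device_map` requires the `accelerate` package to be installed.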
@@ -1738,11 +1756,12 @@ def apply_chat_template(example):
     model_name,
     trust_remote_code=True,
     low_cpu_mem_usage=True,
+    torch_dtype=torch.bfloat16,  # Load in BF16 to save memory
+    device_map=eval_device,  # Directly load to device
 )
 from peft import PeftModel
 
 eval_model = PeftModel.from_pretrained(eval_base_model, output_dir)
-eval_model = eval_model.to(eval_device)
 eval_model.eval()
 
 post_training_results = evaluate_model_on_mmlu_pro(
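The post-training path gets the same treatment, with one extra wrinkle: once `device_map` has placed the base model, the LoRA adapter attached by `PeftModel.from_pretrained()` lands on the same device, which is why the explicit `eval_model.to(eval_device)` line can be dropped. A sketch under the same placeholder assumptions:

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

# Placeholder values -- in the script, output_dir is the LoRA training output.
model_name = "Qwen/Qwen2.5-0.5B"
output_dir = "./lora-adapter"  # must contain adapter_config.json
eval_device = "cuda:0"

eval_base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
    device_map=eval_device,
)

# The adapter weights are loaded on top of the already-placed base model,
# so no separate .to(eval_device) call is needed.
eval_model = PeftModel.from_pretrained(eval_base_model, output_dir)
eval_model.eval()
```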