Skip to content

Commit 55b0ebd

Browse files
committed
Fix train-to-test transition issue
Signed-off-by: Huamin Chen <[email protected]>
1 parent 14bbc70 commit 55b0ebd

File tree

1 file changed

+23
-4
lines changed

1 file changed

+23
-4
lines changed

src/training/training_lora/mmlu_pro_solver_lora/ft_qwen3_mmlu_solver_lora_no_leakage.py

Lines changed: 23 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1683,11 +1683,27 @@ def apply_chat_template(example):
16831683

16841684
# Delete trainer and model to free GPU memory for evaluation
16851685
logger.info("🧹 Cleaning up training resources to free GPU memory...")
1686-
if "trainer" in locals():
1686+
try:
16871687
del trainer
1688-
if "model" in locals():
1688+
logger.info(" ✓ Trainer deleted")
1689+
except:
1690+
pass
1691+
try:
16891692
del model
1693+
logger.info(" ✓ Model deleted")
1694+
except:
1695+
pass
1696+
1697+
# Force garbage collection and GPU memory cleanup
1698+
import gc
1699+
1700+
gc.collect()
16901701
clear_gpu_memory()
1702+
1703+
# Give CUDA a moment to release memory
1704+
import time
1705+
1706+
time.sleep(2)
16911707
logger.info("✓ GPU memory cleared for evaluation\n")
16921708
else:
16931709
logger.info(
@@ -1703,7 +1719,9 @@ def apply_chat_template(example):
17031719
model_name,
17041720
trust_remote_code=True,
17051721
low_cpu_mem_usage=True,
1706-
).to(eval_device)
1722+
torch_dtype=torch.bfloat16, # Load in BF16 to save memory
1723+
device_map=eval_device, # Directly load to device instead of .to()
1724+
)
17071725
base_model_for_baseline.eval()
17081726

17091727
logger.info("\n" + "🔍" * 40)
@@ -1738,11 +1756,12 @@ def apply_chat_template(example):
17381756
model_name,
17391757
trust_remote_code=True,
17401758
low_cpu_mem_usage=True,
1759+
torch_dtype=torch.bfloat16, # Load in BF16 to save memory
1760+
device_map=eval_device, # Directly load to device
17411761
)
17421762
from peft import PeftModel
17431763

17441764
eval_model = PeftModel.from_pretrained(eval_base_model, output_dir)
1745-
eval_model = eval_model.to(eval_device)
17461765
eval_model.eval()
17471766

17481767
post_training_results = evaluate_model_on_mmlu_pro(

0 commit comments

Comments
 (0)