
Commit ab9a24d: add refined output, update force_download to 3-8B

Parent: 2bf85bf

File tree: 2 files changed (+9, -29 lines)

dist_run.py

Lines changed: 5 additions & 25 deletions
@@ -494,32 +494,12 @@ def main():
     else:
         schedule.step()

-    # logger.info(f"REVIEW {padded_sequence[0,:15]=}")
+    # output formatted response via last pp group and tp rank 0
+    if pp_rank == last_pp_group and tp_rank == 0:
+        logger.info(f"Prompt:{color.green} {prompt[0]} {color.reset}")
+        formatted_response = "".join(res)
+        logger.info(f"$$$$$$ {color.blue}{formatted_response}{color.reset} $$$$$")

-    # logger.info(f"{color.green}Total prefill time: {timer.get_time()} {timer.unit}{color.reset}")
-
-    # Decoding
-    """
-    if pp_rank == pp_degree - 1 and tp_rank == 0:
-        decode_results = _batch_decode_next_tokens(
-            output=output, prompt_lengths=prompt_lengths, tokenizer=tokenizer
-        )
-
-        logger.info(
-            f"\n\n{color.green} Prefill responses ====>>>> {color.blue} {decode_results=} \n{color.reset}"
-        )
-    """
-    """
-    # show peak memory stats for this stage
-    res_mem_gib, res_mem_pct = gpu_memory_monitor.get_peak_stats()
-    logger.info(
-        f"{color.blue} Memory used: {color.green}{res_mem_pct:.3f} %, {color.magenta}{res_mem_gib:.3f} GB{color.reset}"
-    )
-
-    logger.info(
-        f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}"
-    )
-    """
     logger.info(f"$$$$$$ {color.red}{res=}{color.reset} $$$$$")
     logger.info(
         f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}"
distributed/force_download.py

Lines changed: 4 additions & 4 deletions
@@ -1,5 +1,5 @@
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, AutoTokenizer

-tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
-model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
-print("Model weights downloaded")
+tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
+model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
+print("Model weights and tokenizer downloaded")
