Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 54d895b

Browse files
committed
ensure 8B is default
1 parent 13bdcb3 commit 54d895b

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

dist_run.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
52  52
53  53    logger = SingletonLogger.get_logger()
54  54
55      - MODEL_NAME = "Transformer-2-7b-chat-hf"
    55  + MODEL_NAME = "Meta-Llama-3-8B"
56  56
57  57    NAME_TO_HF_MODEL_ID_AND_DTYPE = {
58  58    "Transformer-2-7b-chat-hf": ("meta-llama/Llama-2-7b-chat-hf", torch.float16),
@@ -216,7 +216,7 @@ def main():
216  216
217  217    # Load weights
218  218    logger.info(f"Loading weights for {pp_rank=} on {device=}")
219       - with TrackTime("cuda") as timer:
     219  + with TrackTime() as timer:
220  220    _load_model_weights(model, hf_model_name, device=device, model_config=config)
221  221    logger.info(
222  222    f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
@@ -261,7 +261,8 @@ def main():
261  261    # create a padded tensor for the input prompt
262  262    padded_sequence, prompt_len = _create_padded_prompt(input_ids, tokenizer, seqlen, start_pos, device)
263  263    logger.info(f"{prompt_len=}")
264       -
     264  + logger.info(f"{padded_sequence[0, :prompt_len+1]=}")
     265  +
265  266    schedule = ScheduleGPipe(stage, mbs)
266  267    logger.info(f"Created schedule: {schedule}")
267  268
0 commit comments

Comments
 (0)