This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 221ea94

remove logging, update formatting for display

1 parent 511b5b8

File tree

1 file changed: +11 −11 lines


dist_run.py

Lines changed: 11 additions & 11 deletions
@@ -223,7 +223,7 @@ def _update_padded_sequence(
     for i in range(len(prompt_lengths)):
         prompt_lengths[i] += 1
         padded_sequence[i, prompt_lengths[i] - 1] = new_token[i, 0]
-        logger.info(f"updated prompt {i} with new token {new_token[i, 0]}")
+        # logger.info(f"updated prompt {i} with new token {new_token[i, 0]}")
 
 
 def _cleanup():
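
For context, the helper touched in this first hunk does in-place decode bookkeeping: it bumps each prompt's filled length and writes the freshly decoded token at the new final position. A minimal standalone sketch of that update (shapes and token values here are invented for illustration, not taken from dist_run.py):

import torch

# two prompts padded to seqlen 8; values are illustrative only
padded_sequence = torch.zeros(2, 8, dtype=torch.int64)
prompt_lengths = [3, 5]                  # current filled length per prompt
new_token = torch.tensor([[11], [22]])   # one decoded token per prompt

# same update performed by _update_padded_sequence above
for i in range(len(prompt_lengths)):
    prompt_lengths[i] += 1
    padded_sequence[i, prompt_lengths[i] - 1] = new_token[i, 0]

print(prompt_lengths)      # [4, 6]
print(padded_sequence[0])  # token 11 now sits at index 3
print(padded_sequence[1])  # token 22 now sits at index 5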
@@ -296,7 +296,7 @@ def main(args):
     logger.info(f"Model: {model}")
 
     mbs = 1  # number of micro-batches
-    mb_size = 2  # micro-batch size
+    mb_size = 5  # micro-batch size
     batch_size = mbs * mb_size  # total batch size
 
     seqlen = 4096  # sequence length
@@ -345,6 +345,9 @@ def main(args):
     prompt = [
         "What is snow?",
         "Where does Santa Claus live?",
+        "What is PyTorch?",
+        "Write a poem about the beauty of the night sky.",
+        "What is the capital of France, Germany and Switzerland?",
     ]
 
     """
@@ -379,13 +382,11 @@ def main(args):
     input_ids = _encode_strings(
         prompt, tokenizer, bos=True, device=device, dtype=torch.int64
     )
-    logger.info(f"{input_ids[0][0:8]=}")
 
     # create a padded tensor for the input prompt
     padded_sequence, prompt_lengths = _create_padded_prompts(
         input_ids, tokenizer, seqlen, start_pos, device
     )
-    logger.info(f"length of each prompt in the batch: {prompt_lengths=}")
 
     # create schedule
     schedule = ScheduleGPipe(stage, mbs)
@@ -397,7 +398,7 @@ def main(args):
     # need a new token dimension (row) for each prompt in the batch
     new_token = torch.zeros(total_prompts, 1, device=device, dtype=torch.int64)
     res = [[] for _ in range(total_prompts)]
-    num_tokens = 20
+    num_tokens = 40
 
     # Decoding
     with torch.no_grad():
@@ -449,18 +450,17 @@ def main(args):
             # Update input sequence with new token
             if pp_rank == first_pp_rank:
                 _update_padded_sequence(padded_sequence, new_token, prompt_lengths)
-                for i in range(len(prompt_lengths)):
-                    logger.info(
-                        f"next submission: {padded_sequence[i, prompt_lengths[i]-4:prompt_lengths[i]+4]}"
-                    )
+
+    # Display the decoding results
 
     # output formatted response via last pp group and tp rank 0
     if pp_rank == last_pp_rank and tp_rank == 0:
         for i in range(len(prompt_lengths)):
-            logger.info(f"Prompt:{color.green} {prompt[i]} {color.reset}")
+            logger.info(f"\nPrompt:{color.green} {prompt[i]} {color.reset}")
             formatted_response = "".join(res[i])
-            logger.info(f"$$ {color.red}{formatted_response} {color.reset} $$")
+            logger.info(f"$$ {color.red}{formatted_response} {color.reset} $$\n")
 
+    # Cleanup
     logger.info(
         f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}"
     )
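
Worth noting: the prompt list grows from 2 to 5 entries in the same commit that raises mb_size from 2 to 5, because batch_size = mbs * mb_size has to match the number of prompts fed into the pipeline schedule. A minimal sketch of that sizing invariant, with values mirroring this commit (the assert is illustrative only, not part of dist_run.py):

mbs = 1       # number of micro-batches
mb_size = 5   # micro-batch size, raised from 2 in this commit
batch_size = mbs * mb_size  # total batch size

prompt = [
    "What is snow?",
    "Where does Santa Claus live?",
    "What is PyTorch?",
    "Write a poem about the beauty of the night sky.",
    "What is the capital of France, Germany and Switzerland?",
]

# each prompt occupies one row of the padded batch, so the prompt
# count must line up with the configured batch size
assert len(prompt) == batch_size, (
    f"got {len(prompt)} prompts but batch_size={batch_size}"
)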
