This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 13bdcb3

committed
update prefill functions
1 parent 28598e7 commit 13bdcb3

File tree

1 file changed: +20 -22 lines changed


dist_run.py

Lines changed: 20 additions & 22 deletions
@@ -7,7 +7,7 @@
import os
from pathlib import Path
from types import SimpleNamespace
-from typing import Any, Dict
+from typing import Any, Dict, Tuple

# Run command:
# torchrun --nproc-per-node 4 dist_run.py
@@ -124,19 +124,19 @@ def _load_model_weights(stage_module, hf_model_name, device, model_config):

def _encode_string(string, tokenizer, bos=True, device="cuda", dtype=torch.int64)-> torch.Tensor:
    """Encode a prompt string into a tensor of token ids."""
-    tokens = tokenizer.encode(string)
-    if bos:
-        tokens = [tokenizer.bos_id()] + tokens
-    return torch.tensor(tokens, dtype=dtype, device=device)
-
-def _create_padded_prompt(input_ids, seqlen, start_pos, device) -> torch.Tensor:
-    """Create a padded tensor for the encoded input prompt."""
-    prompt_len = input_ids.size(0)
-    max_new_tokens = min(seqlen, seqlen - start_pos - prompt_len)
-    token_buffer_size = prompt_len + max_new_tokens
-    seq = torch.full((1, token_buffer_size), tokenizer.eos_id(), dtype=torch.int64, device=device)
-    seq[0, :prompt_len] = input_ids
-    return seq
+    tokens = tokenizer.encode(string)
+    if bos:
+        tokens = [tokenizer.bos_id()] + tokens
+    return torch.tensor(tokens, dtype=dtype, device=device)
+
+def _create_padded_prompt(input_ids, tokenizer, seqlen, start_pos, device) -> Tuple[torch.Tensor, int]:
+    """Create a padded tensor for the encoded input prompt. Returns the padded tensor and the prompt length."""
+    prompt_len = input_ids.size(0)
+    max_new_tokens = min(seqlen, seqlen - start_pos - prompt_len)
+    token_buffer_size = prompt_len + max_new_tokens
+    seq = torch.full((1, token_buffer_size), tokenizer.eos_id(), dtype=torch.int64, device=device)
+    seq[0, :prompt_len] = input_ids
+    return seq, prompt_len

def _cleanup():
    dist.barrier()
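
For context, a minimal standalone sketch (not part of this commit) of the updated _create_padded_prompt contract. The stub tokenizer and the example values below are hypothetical; the real script uses the model's own tokenizer and the seqlen/start_pos values from main(). It illustrates the buffer-size arithmetic and the new (padded tensor, prompt length) return value:

import torch
from typing import Tuple


class _StubTokenizer:
    # hypothetical stand-in; only eos_id() is needed for padding
    def eos_id(self) -> int:
        return 2


def _create_padded_prompt(input_ids, tokenizer, seqlen, start_pos, device) -> Tuple[torch.Tensor, int]:
    # same logic as the version added in this commit
    prompt_len = input_ids.size(0)
    max_new_tokens = min(seqlen, seqlen - start_pos - prompt_len)
    token_buffer_size = prompt_len + max_new_tokens
    # fill the whole buffer with EOS, then copy the real prompt into the front
    seq = torch.full((1, token_buffer_size), tokenizer.eos_id(), dtype=torch.int64, device=device)
    seq[0, :prompt_len] = input_ids
    return seq, prompt_len


input_ids = torch.arange(3, 17, dtype=torch.int64)  # 14 fake prompt token ids
padded_sequence, prompt_len = _create_padded_prompt(
    input_ids, _StubTokenizer(), seqlen=32, start_pos=0, device="cpu"
)
print(padded_sequence.shape, prompt_len)  # torch.Size([1, 32]) 14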
@@ -251,32 +251,30 @@ def main():
    if len(cpu_tensors) > 0:
        raise ValueError("Found cpu tensors in stage")

-
-    prompt = "What is the capital of France?"
+    prompt = "What is snow?"
    start_pos = 0

    # encode the prompt
    input_ids = _encode_string(prompt, tokenizer, bos=True, device=device, dtype=torch.int64)
-
-    # create a padded tensor for the input prompt
-    seq = _create_padded_prompt(input_ids, seqlen, start_pos, device)
+    logger.info(f"{input_ids[0:8]=}")

+    # create a padded tensor for the input prompt
+    padded_sequence, prompt_len = _create_padded_prompt(input_ids, tokenizer, seqlen, start_pos, device)
+    logger.info(f"{prompt_len=}")

    schedule = ScheduleGPipe(stage, mbs)
    logger.info(f"Created schedule: {schedule}")

    with torch.no_grad(): # .inference_mode():
        if pp_rank == 0:
-            schedule.step(seq)
+            schedule.step(padded_sequence)
        else:
            output = schedule.step()

    # Decoding
    if pp_rank == pp_degree - 1 and tp_rank == 0:
-
        next_token_logits = output[:,prompt_len-1, :]
        next_token = torch.argmax(next_token_logits, dim=-1)
-
        next_token_decoded = tokenizer.decode((next_token.tolist()))

        logger.info(f"\n\n{color.green} Prefill response ====>>>> {color.blue} {next_token_decoded=}, {next_token}\n{color.reset}")
