This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 315a023

move prompt outside of main, auto-update batch size based on prompt

1 parent 41d61a8

File tree

1 file changed, +10 -46 lines changed


dist_run.py

Lines changed: 10 additions & 46 deletions
@@ -188,43 +188,6 @@ def _create_padded_prompts(
     return batch_seq, prompt_lengths
 
 
-def _batch_decode_next_tokens_new(
-    output: torch.Tensor, pos: torch.Tensor, step: int = -1, temperature: float = 1.0
-) -> torch.Tensor:
-    """
-    Decode the next token for each prompt in the batch.
-
-    Args:
-        output (torch.Tensor): The output tensor to decode. Shape: (batch_size, seq_len, vocab_size)
-        pos (torch.Tensor): The current position for each sequence in the batch. Shape: (batch_size,)
-        step (int): If not -1, use this fixed position for all sequences. Default: -1
-        temperature (float): Sampling temperature. Higher values increase randomness. Default: 1.0
-
-    Returns:
-        torch.Tensor: Decoded token ids. Shape: (batch_size,)
-    """
-    batch_size, seq_len, vocab_size = output.shape
-
-    # Determine the token position to use for each sequence
-    token_pos = torch.full_like(pos, step) if step != -1 else pos - 1
-
-    # Extract the relevant logits for each sequence
-    next_token_logits = output[torch.arange(batch_size), token_pos]
-
-    if temperature != 1.0:
-        next_token_logits = next_token_logits / temperature
-
-    # Sample from the distribution (or use argmax if temperature is very low)
-    if temperature < 1e-5:
-        next_tokens = torch.argmax(next_token_logits, dim=-1)
-    else:
-        next_tokens = torch.multinomial(
-            F.softmax(next_token_logits, dim=-1), num_samples=1
-        ).squeeze(-1)
-
-    return next_tokens
-
-
 def _batch_decode_next_tokens(
     output: torch.Tensor, pos: List[int], step: int = -1, temperature: float = 1.0
 ) -> torch.Tensor:
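For reference, the unused helper deleted above reduces to temperature-scaled sampling over the logits at each sequence's current position. Below is a minimal standalone sketch of that decode step, using dummy shapes and values rather than anything from dist_run.py:

import torch
import torch.nn.functional as F

# Dummy logits of shape (batch_size, seq_len, vocab_size) and one current position per sequence
batch_size, seq_len, vocab_size = 3, 8, 32
output = torch.randn(batch_size, seq_len, vocab_size)
pos = torch.tensor([5, 3, 7])
temperature = 0.8

# step == -1 case: read the logits at the last filled position of each sequence
token_pos = pos - 1
next_token_logits = output[torch.arange(batch_size), token_pos]

if temperature != 1.0:
    next_token_logits = next_token_logits / temperature

# Near-zero temperature degenerates to greedy argmax; otherwise sample from the softmax
if temperature < 1e-5:
    next_tokens = torch.argmax(next_token_logits, dim=-1)
else:
    next_tokens = torch.multinomial(
        F.softmax(next_token_logits, dim=-1), num_samples=1
    ).squeeze(-1)

print(next_tokens.shape)  # torch.Size([3]), one token id per prompt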
@@ -297,6 +260,15 @@ def _cleanup():
     dist.destroy_process_group()
 
 
+prompt = [
+    "What is Snow?",
+    "Who is Santa Claus?",
+    "Where does Santa live?",
+    # "Who is Abraham Lincoln?",
+    # "How are models trained?",
+]
+
+
 def main(args):
     model_name = args.model_name
     pp_degree = args.pp
@@ -367,7 +339,7 @@ def main(args):
     # Batch size. Since we push batches dynamically through the pipeline rather
     # than chunking them, this is effectively micro-batch size in pipeline
     # sense. Thus it is interchangeable with micro-batch size below.
-    batch_size = 3
+    batch_size = len(prompt)
     seqlen_prefill = 1024  # sequence length
     dim = 4096  # embedding dimension
 
@@ -438,14 +410,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     # pipelining effect.
     prefiller = ScheduleGPipe(prefill_stage, 1)
 
-    prompt = [
-        "What is Snow?",
-        "Who is Santa Claus?",
-        "Where does Santa live?",
-        # "Who is Abraham Lincoln?",
-        # "How are models trained?",
-    ]
-
     start_pos = 0
 
     # Need these global ids due to the API definition of dist.send and recv
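Taken together, the hunks above move the prompt list to module scope and derive the batch size from it, so commenting a prompt in or out resizes everything shaped by the batch with no second edit. A small self-contained sketch of that pattern (the buffer and whitespace token counts below are illustrative stand-ins, not the actual _create_padded_prompts helper from dist_run.py):

import torch

prompt = [
    "What is Snow?",
    "Who is Santa Claus?",
    "Where does Santa live?",
    # "Who is Abraham Lincoln?",  # uncommenting this grows the batch automatically
]
batch_size = len(prompt)  # was hard-coded to 3 before this commit
seqlen_prefill = 1024

# Any batch-shaped buffer now tracks the prompt list without a manual update.
batch_seq = torch.zeros(batch_size, seqlen_prefill, dtype=torch.long)
prompt_lengths = [len(p.split()) for p in prompt]  # stand-in for real token counts
assert batch_seq.shape[0] == len(prompt_lengths) == batch_size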

0 commit comments
