@@ -188,43 +188,6 @@ def _create_padded_prompts(
     return batch_seq, prompt_lengths
 
 
-def _batch_decode_next_tokens_new(
-    output: torch.Tensor, pos: torch.Tensor, step: int = -1, temperature: float = 1.0
-) -> torch.Tensor:
-    """
-    Decode the next token for each prompt in the batch.
-
-    Args:
-        output (torch.Tensor): The output tensor to decode. Shape: (batch_size, seq_len, vocab_size)
-        pos (torch.Tensor): The current position for each sequence in the batch. Shape: (batch_size,)
-        step (int): If not -1, use this fixed position for all sequences. Default: -1
-        temperature (float): Sampling temperature. Higher values increase randomness. Default: 1.0
-
-    Returns:
-        torch.Tensor: Decoded token ids. Shape: (batch_size,)
-    """
-    batch_size, seq_len, vocab_size = output.shape
-
-    # Determine the token position to use for each sequence
-    token_pos = torch.full_like(pos, step) if step != -1 else pos - 1
-
-    # Extract the relevant logits for each sequence
-    next_token_logits = output[torch.arange(batch_size), token_pos]
-
-    if temperature != 1.0:
-        next_token_logits = next_token_logits / temperature
-
-    # Sample from the distribution (or use argmax if temperature is very low)
-    if temperature < 1e-5:
-        next_tokens = torch.argmax(next_token_logits, dim=-1)
-    else:
-        next_tokens = torch.multinomial(
-            F.softmax(next_token_logits, dim=-1), num_samples=1
-        ).squeeze(-1)
-
-    return next_tokens
-
-
 def _batch_decode_next_tokens(
     output: torch.Tensor, pos: List[int], step: int = -1, temperature: float = 1.0
 ) -> torch.Tensor:
@@ -297,6 +260,15 @@ def _cleanup():
     dist.destroy_process_group()
 
 
+prompt = [
+    "What is Snow?",
+    "Who is Santa Claus?",
+    "Where does Santa live?",
+    # "Who is Abraham Lincoln?",
+    # "How are models trained?",
+]
+
+
 def main(args):
     model_name = args.model_name
     pp_degree = args.pp
@@ -367,7 +339,7 @@ def main(args):
     # Batch size. Since we push batches dynamically through the pipeline rather
     # than chunking them, this is effectively micro-batch size in pipeline
     # sense. Thus it is interchangeable with micro-batch size below.
-    batch_size = 3
+    batch_size = len(prompt)
     seqlen_prefill = 1024  # sequence length
     dim = 4096  # embedding dimension
 
@@ -438,14 +410,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     # pipelining effect.
     prefiller = ScheduleGPipe(prefill_stage, 1)
 
-    prompt = [
-        "What is Snow?",
-        "Who is Santa Claus?",
-        "Where does Santa live?",
-        # "Who is Abraham Lincoln?",
-        # "How are models trained?",
-    ]
-
    start_pos = 0
 
     # Need these global ids due to the API definition of dist.send and recv
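For reference, the temperature-based decoding that the removed _batch_decode_next_tokens_new helper performed can be reproduced as a small standalone sketch. The logic (greedy argmax fallback for near-zero temperature, otherwise a multinomial draw over softmax-scaled logits) is taken directly from the deleted body above; the helper name sample_next_tokens and the toy tensor shapes are illustrative assumptions, not part of this change.

import torch
import torch.nn.functional as F

def sample_next_tokens(next_token_logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    # next_token_logits: (batch_size, vocab_size), already gathered at each
    # sequence's decoding position
    if temperature != 1.0:
        next_token_logits = next_token_logits / temperature
    # Near-zero temperature degenerates to greedy (argmax) decoding,
    # mirroring the removed helper
    if temperature < 1e-5:
        return torch.argmax(next_token_logits, dim=-1)
    # Otherwise draw one token per sequence from the softmax distribution
    return torch.multinomial(
        F.softmax(next_token_logits, dim=-1), num_samples=1
    ).squeeze(-1)

# Illustrative shapes only: 3 prompts over a 32000-token vocabulary
dummy_logits = torch.randn(3, 32000)
next_tokens = sample_next_tokens(dummy_logits, temperature=0.8)  # shape: (3,)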