 
 import torch
 import torch.distributed as dist
+from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
+from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs
 
 from torchchat.distributed.logging_utils import SingletonLogger
 
@@ -33,8 +35,6 @@
     get_num_params,
     GPUMemoryMonitor,
 )
-from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
-from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs
 from torchchat.model import ModelArgs, Transformer, TransformerArgs
 from torchchat.utils.build_utils import set_precision
 
@@ -189,23 +189,49 @@ def _create_padded_prompts(
 
 def _batch_decode_next_tokens(
     output: torch.Tensor,
-    pos: int,
+    pos: List[int],
+    step: int = -1,
+    temperature: float = 1.0,
+    topk: int = 10,
 ) -> torch.Tensor:
     """
-    Decode the next token for each prompt in the batch.
+    Decode the next token for each prompt in the batch. Adds a temperature option for non-deterministic decoding.
+
     Args:
         output (torch.Tensor): The output tensor to decode.
-        pos: the position of the `output` to decode in the sequence length dimension.
+        pos (List[int]): The positions of the `output` to decode in the sequence length dimension.
+        step (int): Step indicator. If -1, use the positions from `pos`; otherwise, use the first token.
+        temperature (float): Sampling temperature for non-deterministic decoding.
 
     Returns:
-        Decoded token ids.
+        torch.Tensor: Decoded token ids.
     """
-    # Take the next token logits for each prompt
-    next_token_logits = output[:, pos, :]
-    # Argmax (deterministic) TODO: add temperature
-    next_token = torch.argmax(next_token_logits, dim=-1)
-    # Token ids in int tensor form
-    return next_token
+    batch_size, seq_len, vocab_size = output.shape
+
+    if step != -1:
+        next_token_logits = output[:, 0, :]
+    else:
+        # Get the logits for each prompt at its specified position
+        next_token_logits = output[torch.arange(batch_size), torch.tensor(pos) - 1]
+
+    if temperature != 1.0:
+        next_token_logits = next_token_logits / temperature
+
+    # Use top-k sampling if temperature is not 1.0, otherwise use argmax
+    if temperature != 1.0:
+        top_k = min(topk, vocab_size)  # Ensure top-k is not greater than vocab size
+        top_k_logits, top_k_indices = torch.topk(next_token_logits, k=top_k, dim=-1)
+        probs = torch.softmax(top_k_logits, dim=-1)
+        next_token_indices = torch.multinomial(probs, num_samples=1).squeeze(-1)
+        next_tokens = top_k_indices.gather(
+            -1, next_token_indices.unsqueeze(-1)
+        ).squeeze(-1)
+    else:
+        # Argmax (deterministic)
+        next_tokens = torch.argmax(next_token_logits, dim=-1)
+
+    logger.info(f"{color.yellow}Next tokens: {color.blue}{next_tokens}{color.reset}")
+    return next_tokens
 
 
 def _update_padded_sequence(
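For reference, a minimal standalone sketch of the temperature / top-k path introduced in _batch_decode_next_tokens above; the batch size, vocabulary size, and random logits are made-up values for illustration only:

    import torch

    batch_size, vocab_size = 2, 8              # toy sizes assumed for this sketch
    temperature, topk = 0.8, 10
    next_token_logits = torch.randn(batch_size, vocab_size)

    if temperature != 1.0:
        next_token_logits = next_token_logits / temperature
        top_k = min(topk, vocab_size)          # clamp k to the vocabulary size
        top_k_logits, top_k_indices = torch.topk(next_token_logits, k=top_k, dim=-1)
        probs = torch.softmax(top_k_logits, dim=-1)
        sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)
        next_tokens = top_k_indices.gather(-1, sampled.unsqueeze(-1)).squeeze(-1)
    else:
        next_tokens = torch.argmax(next_token_logits, dim=-1)

    print(next_tokens.shape)                   # torch.Size([2]): one token id per prompt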
@@ -218,11 +244,32 @@ def _update_padded_sequence(
         # logger.info(f"updated prompt {i} with new token {new_token[i, 0]}")
 
 
+# Decode token id into string and print it
+def _decode_in_flight(token, tokenizer, tp_rank):
+    """decode token ids for all prompts in the batch and log them"""
+    token_str = tokenizer.decode(token.tolist())
+    # print the token string on tp rank 0
+    if tp_rank == 0:
+        logger.info(
+            f"{color.green} responses ====>>>> "
+            f"{color.blue} {token_str} {color.reset}"
+        )
+
+
 def _cleanup():
     dist.barrier()
     dist.destroy_process_group()
 
 
+prompt = [
+    "What is Snow?",
+    "Who is Santa Claus?",
+    "Where does Santa live?",
+    # "Who is Abraham Lincoln?",
+    # "How are models trained?",
+]
+
+
 def main(args):
     model_name = args.model_name
     pp_degree = args.pp
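A rough usage sketch of the new _decode_in_flight helper; the real tokenizer comes from _initialize_tokenizer, so the stand-in class below is hypothetical and only mimics the single .decode(list) call the helper relies on:

    import torch

    class ToyTokenizer:                        # hypothetical stand-in, not a torchchat class
        def decode(self, ids):
            return " ".join(f"<tok{i}>" for i in ids)

    new_token = torch.tensor([11, 42, 7])      # one made-up token id per prompt
    _decode_in_flight(new_token, ToyTokenizer(), tp_rank=0)  # logs only on tp rank 0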
@@ -293,7 +340,7 @@ def main(args):
     # Batch size. Since we push batches dynamically through the pipeline rather
     # than chunking them, this is effectively micro-batch size in pipeline
     # sense. Thus it is interchangeable with micro-batch size below.
-    batch_size = 4
+    batch_size = len(prompt)
     seqlen_prefill = 1024  # sequence length
     dim = 4096  # embedding dimension
 
@@ -331,7 +378,9 @@ def main(args):
 
     # Helper function to get example inputs and outputs for the stages.
     def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
-        mb_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), device=device)
+        mb_ids = torch.randint(
+            0, config.vocab_size, (batch_size, seqlen), device=device
+        )
         activation = torch.rand(
             batch_size, seqlen, dim, device=device, dtype=model_dtype
         )
@@ -362,13 +411,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     # pipelining effect.
     prefiller = ScheduleGPipe(prefill_stage, 1)
 
-    prompt = [
-        "What is a computer?",
-        "Where does Santa live?",
-        "Who is Abraham Lincoln?",
-        "How are models trained?",
-    ]
-
     start_pos = 0
 
     # Need these global ids due to the API definition of dist.send and recv
@@ -384,10 +426,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     padded_sequence, prompt_lengths = _create_padded_prompts(
         input_ids, tokenizer, seqlen_prefill, start_pos, device
     )
-    # TODO: figure out how to set input_pos for each prompt in the batch then we
-    # can remove this limitation.
-    s = set(prompt_lengths)
-    assert len(s) == 1, f"prompt_lengths should be the same, got {s}"
 
     # Need these global ids due to the API definition of dist.send and recv
     first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
@@ -396,6 +434,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     # New token generated each iteration
     # need a row dimension for each prompt in the batch
     new_token = torch.zeros(batch_size, 1, device=device, dtype=torch.int64)
+    logger.info(f"{color.green}{new_token.shape=}, {new_token=}{color.reset}")
     # Store the generated tokens
     res = []
 
@@ -416,23 +455,13 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         f"{color.green}Prefilling time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
     )
 
-    # Decode token id into string and print it
-    def decode_in_flight(token):
-        # Make a 2D tensor with ids on row dimension
-        unsqueezed = torch.unsqueeze(token, 1)
-        token_str = tokenizer.decode(unsqueezed.tolist())
-        if tp_rank == 0:
-            logger.info(
-                f"{color.green} responses ====>>>> "
-                f"{color.blue} {token_str} {color.reset}"
-            )
-
     # Decode the output -- first generated token
     if pp_rank == last_pp_rank:
-        new_token = _batch_decode_next_tokens(output, prompt_lengths[0] - 1)
+        logger.info(f"{color.green}Decoding...{prompt_lengths=}{color.reset}")
+        new_token = _batch_decode_next_tokens(output, prompt_lengths)
         res.append(new_token)
         if not args.disable_in_flight_decode:
-            decode_in_flight(new_token)
+            _decode_in_flight(new_token, tokenizer, tp_rank)
 
     # seqlen = 1 now
     seqlen_decode = 1
@@ -482,10 +511,11 @@ def decode_in_flight(token):
 
         # Decode the output
         if pp_rank == last_pp_rank:
-            new_token = _batch_decode_next_tokens(output, 0)
+            # logger.info(f"{color.red}Decoding...{output.shape=}{color.reset}")
+            new_token = _batch_decode_next_tokens(output, prompt_lengths, step)
             res.append(new_token)
             if not args.disable_in_flight_decode:
-                decode_in_flight(new_token)
+                _decode_in_flight(new_token, tokenizer, tp_rank)
 
         # Increment input position
         input_pos += 1
@@ -499,12 +529,17 @@ def decode_in_flight(token):
     # output formatted response via last pp group and tp rank 0
     if pp_rank == last_pp_rank and tp_rank == 0:
         # `res` is a list of tensors, each being a batch of generated token ids
-        res = torch.stack(res, dim=1)
-        res_list = res.tolist()
-        response = tokenizer.decode(res_list)
-        for i in range(len(response)):
-            logger.info(f"Prompt: {color.green}{prompt[i]} {color.reset}")
-            logger.info(f"Response: {color.red}{response[i]} {color.reset}")
+
+        res_stacked = torch.stack(res, dim=1)
+        res_list = res_stacked.tolist()
+
+        # Decode the output with a comprehension instead of a loop
+        responses = [tokenizer.decode(sequence) for sequence in res_list]
+
+        # Show prompts and responses
+        for prompt_text, response_text in zip(prompt, responses):
+            logger.info(f"Prompt: {color.green}{prompt_text} {color.reset}")
+            logger.info(f"Response: {color.red}{response_text} {color.reset}")
 
     # Cleanup
     _cleanup()
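To make the shapes in the response assembly above concrete, a small sketch with made-up token ids for two prompts over three decoding steps:

    import torch

    # one [batch]-shaped tensor of token ids per decoding step
    res = [torch.tensor([1, 4]), torch.tensor([2, 5]), torch.tensor([3, 6])]

    res_stacked = torch.stack(res, dim=1)      # shape [batch, steps] == [2, 3]
    res_list = res_stacked.tolist()            # [[1, 2, 3], [4, 5, 6]]: one id sequence per prompt
    # each inner list is what tokenizer.decode(sequence) receives in the comprehension above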