@@ -209,6 +209,7 @@ def _batch_decode_next_tokens(
     batch_size, seq_len, vocab_size = output.shape
 
     if step != -1:
+        # `pos` is not provided, so we can use the first token
         next_token_logits = output[:, 0, :]
     else:
         # get the logits for each prompt at the specified positions
@@ -228,9 +229,9 @@ def _batch_decode_next_tokens(
         ).squeeze(-1)
     else:
         # Argmax (deterministic)
-        next_tokens = torch.argmax(next_token_logits, dim=-1)
+        next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True)
 
-    logger.info(f"{color.yellow}Next tokens: {color.blue}{next_tokens}{color.reset}")
+    # Token ids in int tensor form
     return next_tokens
 
 
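For context on the `keepdim=True` change: a minimal sketch (hypothetical shapes, not this repo's code) of how the extra keyword keeps one column of token ids per prompt, so downstream code can keep treating the result as `(batch_size, 1)`.

```python
# Minimal sketch, assuming a (batch_size, vocab_size) logits tensor with made-up sizes.
import torch

next_token_logits = torch.randn(4, 32000)                     # (batch_size, vocab_size)
flat = torch.argmax(next_token_logits, dim=-1)                # shape: (4,)
kept = torch.argmax(next_token_logits, dim=-1, keepdim=True)  # shape: (4, 1)
print(flat.shape, kept.shape)                                 # torch.Size([4]) torch.Size([4, 1])
```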
@@ -247,6 +248,11 @@ def _update_padded_sequence(
 # Decode token id into string and print it
 def _decode_in_flight(token, tokenizer, tp_rank):
     """decode token ids for all prompts in the batch and log them"""
+    # `token` is a tensor of shape (batch_size, 1).
+    # For TiktokenTokenizer, we need to squeeze it to 1D.
+    # For SentencePieceProcessor, we don't.
+    if isinstance(tokenizer, TiktokenTokenizer):
+        token = torch.squeeze(token, dim=1)
     token_str = tokenizer.decode(token.tolist())
     # print the token string on tp rank 0
     if tp_rank == 0:
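A small illustration of the shape handling above, with made-up token ids: it only shows what `.tolist()` yields before and after the squeeze, per the comment in the diff about the two tokenizer classes.

```python
# Sketch only: `token` holds one generated id per prompt, as described above.
import torch

token = torch.tensor([[101], [102], [103]])    # (batch_size, 1)
print(token.tolist())                          # [[101], [102], [103]] (nested per-prompt lists)
print(torch.squeeze(token, dim=1).tolist())    # [101, 102, 103] (flat list of ids)
```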
@@ -328,15 +334,26 @@ def main(args):
     config.stage_idx = pp_rank
     config.n_stages = pp_degree
 
-    with device:
+    with torch.device("meta"):
         # TODO: we should create model instead of Transformer
         model = Transformer(config)
 
     # Distribute model on TP mesh
+    # (Surprisingly, this works even though model is on meta device and mesh is of
+    # cuda devices)
     model.distribute(tp_mesh)
     if rank == 0:
         logger.info(f"Model: {model}")
 
+    # Load weights
+    logger.info(f"Loading weights for {pp_rank=} on {device=}")
+    with CUDATrackTime() as timer:
+        _load_model_weights(model, distribution, device=device, model_config=config)
+
+    logger.info(
+        f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
+    )
+
     # Batch size. Since we push batches dynamically through the pipeline rather
     # than chunking them, this is effectively micro-batch size in pipeline
     # sense. Thus it is interchangeable with micro-batch size below.
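The `torch.device("meta")` context here is the standard PyTorch deferred-initialization pattern: parameters are created as shapes only, and real storage appears later when checkpoint weights are loaded onto the target device. A generic sketch of that pattern (plain `nn.Linear` and random tensors standing in for this repo's `Transformer` and `_load_model_weights`):

```python
# Generic meta-device sketch; load_state_dict(..., assign=True) needs PyTorch >= 2.1.
import torch
import torch.nn as nn

with torch.device("meta"):
    layer = nn.Linear(1024, 1024)          # no memory allocated for weights yet
print(layer.weight.device)                 # meta

# Materialize later by assigning real tensors (random here, a checkpoint in practice):
state = {"weight": torch.randn(1024, 1024), "bias": torch.zeros(1024)}
layer.load_state_dict(state, assign=True)
print(layer.weight.device)                 # cpu
```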
@@ -352,17 +369,8 @@ def main(args):
     # lanes.
     # TODO: bump up the lane count
     pipeline_lanes = 1
-    model.setup_caches(batch_size, seqlen_prefill, cache_lanes=pipeline_lanes)
-
-    # Load weights
-    logger.info(f"Loading weights for {pp_rank=} on {device=}")
-    with CUDATrackTime() as timer:
-        _load_model_weights(model, distribution, device=device, model_config=config)
-        model.to(device)
-
-    logger.info(
-        f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
-    )
+    with device:
+        model.setup_caches(batch_size, seqlen_prefill, cache_lanes=pipeline_lanes)
 
     # info on stage size and params
     stage_size = get_module_size(model)
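Moving `setup_caches` under `with device:` relies on `torch.device` acting as a context manager that sets the default device, so the cache tensors are allocated directly on the target GPU once the weights are in place. A minimal sketch of that behavior, with a hypothetical cache shape:

```python
# Sketch of the `with device:` allocation behavior (PyTorch >= 2.0).
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
with device:
    kv_cache = torch.zeros(1, 8, 4096, 128)   # hypothetical (lanes, heads, seq, head_dim)
print(kv_cache.device)                        # cuda:0 on a GPU machine, else cpu
```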
@@ -528,14 +536,12 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
     # output formatted response via last pp group and tp rank 0
     if pp_rank == last_pp_rank and tp_rank == 0:
-        # `res` is a list of tensors, each being a batch of generated token ids
-
-        res_stacked = torch.stack(res, dim=1)
-        res_list = res_stacked.tolist()
-
-        # Decode the output as comprehension instead of loop
-        responses = [tokenizer.decode(sequence) for sequence in res_list]
-
+        # `res` is a list of tensors, each being a batch of generated token ids.
+        # We need to concatenate them to get the full sequence of generated
+        # token ids. Thus cat'ing along dim 1.
+        res = torch.cat(res, dim=1)
+        res_list = res.tolist()
+        responses = tokenizer.decode(res_list)
         # Show prompts and responses
         for prompt_text, response_text in zip(prompt, responses):
             logger.info(f"Prompt: {color.green}{prompt_text}{color.reset}")
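For the final decode, each generation step contributes a `(batch_size, 1)` column of token ids, so concatenating along dim 1 rebuilds one full id sequence per prompt; `tolist()` then yields the nested list that a batch-aware tokenizer `decode` can consume. A toy sketch with made-up ids:

```python
# Toy sketch: three decode steps for a batch of two prompts; ids are made up.
import torch

res = [torch.tensor([[5], [9]]), torch.tensor([[6], [10]]), torch.tensor([[7], [11]])]
res = torch.cat(res, dim=1)   # shape (2, 3): one row of generated ids per prompt
print(res.tolist())           # [[5, 6, 7], [9, 10, 11]]
```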