@@ -135,8 +135,8 @@ def _load_model_weights(stage_module, distribution, device, model_config):
 def _encode_strings(
     strings: List[str],
     tokenizer,
-    bos: bool = True,
-    device: torch.device = "cuda:0",
+    bos: bool,
+    device: torch.device,
     dtype=torch.int64,
 ) -> List[torch.Tensor]:
     """Encode a list of prompt strings into a list of tensor token ids."""
@@ -216,13 +216,13 @@ def _batch_decode_next_tokens(
 
 def _update_padded_sequence(
     padded_sequence: torch.Tensor,
-    x_recv: torch.Tensor,
-    res,
+    new_token: torch.Tensor,
     prompt_lengths: List[int],
 ) -> None:
+    # TODO: this is a hacky way to update the padded sequence: when there is
+    # more than one prompt, the for loop and the assignment are incompatible.
     for i in range(len(prompt_lengths)):
-        prompt_lengths[i] += 1
-        padded_sequence[i, prompt_lengths[i] - 1] = x_recv
+        padded_sequence[i, prompt_lengths[i]] = new_token
 
 
 def _cleanup():
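The new version writes the token at index `prompt_lengths[i]` and leaves the length increment to the caller, where the old version mutated `prompt_lengths` as a side effect and wrote at `prompt_lengths[i] - 1`. A toy illustration of the new semantics on dummy tensors:

```python
import torch

# Toy data, not the real model tensors: batch=1, seqlen=8, prompt fills slots 0..2.
padded_sequence = torch.zeros(1, 8, dtype=torch.int64)
prompt_lengths = [3]
new_token = torch.tensor([42], dtype=torch.int64)

padded_sequence[0, prompt_lengths[0]] = new_token  # token lands in slot 3
prompt_lengths[0] += 1                             # caller increments afterwards
```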
@@ -267,19 +267,15 @@ def main(args):
     device_mesh = _create_device_mesh(mesh_dimensions)
     tp_mesh = device_mesh["tp"]
     pp_mesh = device_mesh["pp"]
+    logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}")
+
     tp_rank = tp_mesh.get_local_rank()
     pp_rank = pp_mesh.get_local_rank()
     tp_group = tp_mesh.get_group()
     pp_group = pp_mesh.get_group()
-
-    logger.info(f"review: {pp_group=}, {tp_group=}")
-
-    logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}\n")
-    # TODO - this assumes 1D mesh, need to update for 2D+ mesh
-    pp_group_size = pp_mesh.size()
-    tp_group_size = tp_mesh.size()
-
-    logger.info(f"pp_group_size: {pp_group_size}, tp_group_size: {tp_group_size}")
+    pp_group_size = pp_group.size()
+    tp_group_size = tp_group.size()
+    logger.info(f"{pp_group_size=}, {tp_group_size=}")
 
     # Assuming same number of GPUs per node
     device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
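`_create_device_mesh` is not shown in this diff, but a 2D mesh with named `"pp"` and `"tp"` dimensions can presumably be built with `torch.distributed.device_mesh.init_device_mesh`. A hedged reconstruction, assuming `mesh_dimensions` carries the two degrees (the real helper's signature may differ):

```python
from torch.distributed.device_mesh import init_device_mesh

# Hypothetical reconstruction of _create_device_mesh, e.g. pp_degree=2,
# tp_degree=4 on 8 GPUs; ranks are laid out so device_mesh["tp"]/["pp"]
# slice out the per-dimension submeshes used above.
def _create_device_mesh(mesh_dimensions):
    pp_degree, tp_degree = mesh_dimensions
    return init_device_mesh(
        "cuda",
        mesh_shape=(pp_degree, tp_degree),
        mesh_dim_names=("pp", "tp"),
    )
```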
@@ -316,7 +312,7 @@ def main(args):
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
 
     with CUDATrackTime() as timer:
-        _load_model_weights(model, hf_model_name, device=device, model_config=config)
+        _load_model_weights(model, distribution, device=device, model_config=config)
 
     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
@@ -327,7 +323,7 @@ def main(args):
     stage_size_formatted = bytes_to_readable(stage_size)
     stage_num_params = get_num_params(model)
     logger.info(
-        f"Stage {rank} has {color.blue}{stage_num_params} params{color.reset}, Size: {color.blue}{stage_size_formatted}{color.reset}\n"
+        f"Stage {rank} has {color.blue}{stage_num_params} params{color.reset}, Size: {color.blue}{stage_size_formatted}{color.reset}"
     )
 
     # Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen
@@ -342,15 +338,9 @@ def main(args):
         pp_degree,
         device,
         input_args=(example_args,),
-        group=pp_mesh.get_group(),
+        group=pp_group,
     )
 
-    # this check confirms that there are no cpu tensors in the model..we expect this to be true.
-    cpu_tensors = find_cpu_tensors(stage.submod)
-    # logger.info(f"Found {len(cpu_tensors)} cpu tensors: {cpu_tensors}")
-    if len(cpu_tensors) > 0:
-        raise ValueError("Found cpu tensors in stage")
-
     prompt = [
         "What is snow?",
     ]
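For reference, the dropped guard against stray CPU tensors can be written by walking the stage module's parameters and buffers. A hedged sketch of what the removed `find_cpu_tensors` helper presumably did (its real implementation is not shown here):

```python
import torch.nn as nn
from typing import List

# Assumed shape of the removed find_cpu_tensors helper; illustrative only.
def find_cpu_tensors(module: nn.Module) -> List[str]:
    cpu_tensors = []
    for name, t in list(module.named_parameters()) + list(module.named_buffers()):
        if t.device.type == "cpu":
            cpu_tensors.append(name)  # collect names of tensors left on CPU
    return cpu_tensors
```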
@@ -374,7 +364,6 @@ def main(args):
     ]
     """
 
-
     start_pos = 0
 
     # encode the prompt
@@ -388,88 +377,74 @@ def main(args):
         input_ids, tokenizer, seqlen, start_pos, device
     )
     logger.info(f"{prompt_lengths=}")
-    logger.info(f"first prompt {padded_sequence[0, :prompt_lengths[0]+1]=}")
-    if len(prompt_lengths) > 1:
-        logger.info(f"second prompt {padded_sequence[1, :prompt_lengths[1]+1]=}")
 
+    # create schedule
     schedule = ScheduleGPipe(stage, mbs)
-    logger.info(f"Created schedule: {schedule}")
 
     # with CUDATrackTime() as timer:
-    first_pp_group = 0
-    last_pp_group = pp_group_size - 1
-
-    x_recv = torch.zeros(1, device=device, dtype=torch.int64)
-    logger.info(f"{x_recv.shape=}")
+    first_pp_rank = 0
+    last_pp_rank = pp_group_size - 1
 
-    last_global_rank = world_size - 1
+    # New token generated each iteration
+    new_token = torch.zeros(1, device=device, dtype=torch.int64)
     res = []
-    dst = None
-    src = None
-
-    if pp_rank == last_pp_group:
-        dst = dist.get_global_rank(pp_group, 0)
-    elif pp_rank == 0:
-        src = dist.get_global_rank(pp_group, last_pp_group)
-
-    # Decoding
     num_tokens = 40
 
+    # Decoding
     with torch.no_grad():
         for step in range(num_tokens):
-            # first
-            if pp_rank == 0:
-                schedule.step(padded_sequence)
-                # only receive if not last step
-                if step < num_tokens - 1:
-                    dist.recv(
-                        x_recv,
-                        src,
-                        group=pp_group,
-                    )
-                    _update_padded_sequence(
-                        padded_sequence, x_recv, res, prompt_lengths
-                    )
-
-            # last
-            elif pp_rank == last_pp_group:
+            # Run data through pipeline
+            if pp_rank == first_pp_rank:
+                output = schedule.step(padded_sequence)
+            elif pp_rank == last_pp_rank:
                 output = schedule.step()
-                # need to decode the output
+            else:  # middle pp ranks
+                schedule.step()
+
+            # Decode the output
+            if pp_rank == last_pp_rank:
                 decode_results = _batch_decode_next_tokens(
                     output=output, prompt_lengths=prompt_lengths, tokenizer=tokenizer
                 )
                 if tp_rank == 0:
                     logger.info(
-                        f"\n\n{color.green}{'Prefill' if step == 0 else '* Decode *'} responses ====>>>> {color.blue}{decode_results=}\n{color.reset}"
+                        f"{color.green}{'Prefill' if step == 0 else '* Decode *'} "
+                        f"responses ====>>>> {color.blue}{decode_results=}{color.reset}"
                     )
-
-                next_token = torch.tensor([decode_results[0][0]], device=device)
+                # decode results returns both token_id (int) and token_str (readable), hence [0] and [1]
+                new_token = torch.tensor([decode_results[0][0]], device=device)
                 res.append(decode_results[0][1])
 
-                # increment prompt lengths for next token
-                for i in range(len(prompt_lengths)):
-                    prompt_lengths[i] += 1
-                    # logger.info(
-                    #     f"output review {prompt_lengths[i]=}, {padded_sequence[i, prompt_lengths[i]-1]=}"
-                    # )
-
-                # only send if not last step
-                if step < (num_tokens - 1):
-                    dist.send(
-                        next_token,
-                        dst,
-                        pp_group,
-                    )
+            # sendrecv between last and first ranks, only if:
+            # first_pp_rank != last_pp_rank.
+            if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
+                dist.send(
+                    new_token,
+                    dst=dist.get_global_rank(pp_group, first_pp_rank),
+                    group=pp_group,
+                )
+            elif pp_rank == first_pp_rank and pp_rank != last_pp_rank:
+                dist.recv(
+                    new_token,
+                    src=dist.get_global_rank(pp_group, last_pp_rank),
+                    group=pp_group,
+                )
 
-            # middle pp ranks
-            else:
-                schedule.step()
+            # Update input sequence with new token
+            if pp_rank == first_pp_rank:
+                _update_padded_sequence(
+                    padded_sequence, new_token, prompt_lengths
+                )
+
+            # increment prompt lengths for next token
+            for i in range(len(prompt_lengths)):
+                prompt_lengths[i] += 1
 
     # output formatted response via last pp group and tp rank 0
-    if pp_rank == last_pp_group and tp_rank == 0:
-        logger.info(f"\nPrompt:{color.green}{prompt[0]}{color.reset}")
-        formatted_response = "".join(res)
-        logger.info(f"$$$$$$ {color.blue}{formatted_response}\n{color.reset} $$$$$")
+    if pp_rank == last_pp_rank and tp_rank == 0:
+        logger.info(f"Prompt:{color.green}{prompt[0]}{color.reset}")
+        formatted_response = " ".join(res)
+        logger.info(f"$$$$$$ {color.blue}{formatted_response}{color.reset} $$$$$")
 
     logger.info(
         f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}"