@@ -332,7 +332,6 @@ def main(args):
332332
333333 # Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen
334334 input_pos = torch .arange (seqlen_prefill , device = device )
335- model .setup_input_pos (input_pos )
336335 model .eval ()
337336
338337 # Helper function to get example inputs and outputs for the stages.
@@ -410,13 +409,15 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
410409 # Prefill phase
411410 # Run context input through pipeline
412411 # TODO: we need to pass `input_pos` and `cache_lane` to each stage.
412+ lane = 0
413+ kwargs = {"input_pos" : input_pos , "cache_lane" : lane }
413414 with torch .no_grad ():
414415 if pp_rank == first_pp_rank :
415- output = prefiller .step (padded_sequence )
416+ output = prefiller .step (padded_sequence , ** kwargs )
416417 elif pp_rank == last_pp_rank :
417- output = prefiller .step ()
418+ output = prefiller .step (** kwargs )
418419 else : # middle pp ranks
419- prefiller .step ()
420+ prefiller .step (** kwargs )
420421
421422 # Decode the output -- first generated token
422423 if pp_rank == last_pp_rank :
@@ -438,7 +439,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
438439 # seqlen = 1 now
439440 seqlen_decode = 1
440441 input_pos = torch .tensor ([prompt_lengths [0 ]], device = device )
441- model .setup_input_pos (input_pos )
442442
443443 # Create decode stage
444444 logger .info (f"Creating pipeline stage for decode { pp_rank = } , { pp_degree = } " )
@@ -458,6 +458,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
458458 # Decoding
459459 with torch .no_grad ():
460460 for step in range (num_tokens - 1 ):
461+ kwargs = {"input_pos" : input_pos , "cache_lane" : lane }
461462 # sendrecv between last and first ranks, only if:
462463 # first_pp_rank != last_pp_rank.
463464 if pp_rank == last_pp_rank and pp_rank != first_pp_rank :
@@ -475,11 +476,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
475476
476477 # Run data through pipeline
477478 if pp_rank == first_pp_rank :
478- output = decorder .step (new_token )
479+ output = decorder .step (new_token , ** kwargs )
479480 elif pp_rank == last_pp_rank :
480- output = decorder .step ()
481+ output = decorder .step (** kwargs )
481482 else : # middle pp ranks
482- decorder .step ()
483+ decorder .step (** kwargs )
483484
484485 # Decode the output
485486 if pp_rank == last_pp_rank :
@@ -499,7 +500,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
499500 ) # decode_results[i][0]
500501
501502 input_pos += 1
502- model .setup_input_pos (input_pos )
503503
504504 # Display the decoding results
505505
0 commit comments