Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 4ecb951

Browse files
committed
Remove setup_input_pos
1 parent 2bec61c commit 4ecb951

File tree

2 files changed

+9
-16
lines changed

2 files changed

+9
-16
lines changed

dist_run.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,6 @@ def main(args):
332332

333333
# Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen
334334
input_pos = torch.arange(seqlen_prefill, device=device)
335-
model.setup_input_pos(input_pos)
336335
model.eval()
337336

338337
# Helper function to get example inputs and outputs for the stages.
@@ -410,13 +409,15 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
410409
# Prefill phase
411410
# Run context input through pipeline
412411
# TODO: we need to pass `input_pos` and `cache_lane` to each stage.
412+
lane = 0
413+
kwargs = {"input_pos": input_pos, "cache_lane": lane}
413414
with torch.no_grad():
414415
if pp_rank == first_pp_rank:
415-
output = prefiller.step(padded_sequence)
416+
output = prefiller.step(padded_sequence, **kwargs)
416417
elif pp_rank == last_pp_rank:
417-
output = prefiller.step()
418+
output = prefiller.step(**kwargs)
418419
else: # middle pp ranks
419-
prefiller.step()
420+
prefiller.step(**kwargs)
420421

421422
# Decode the output -- first generated token
422423
if pp_rank == last_pp_rank:
@@ -438,7 +439,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
438439
# seqlen = 1 now
439440
seqlen_decode = 1
440441
input_pos = torch.tensor([prompt_lengths[0]], device=device)
441-
model.setup_input_pos(input_pos)
442442

443443
# Create decode stage
444444
logger.info(f"Creating pipeline stage for decode {pp_rank=}, {pp_degree=}")
@@ -458,6 +458,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
458458
# Decoding
459459
with torch.no_grad():
460460
for step in range(num_tokens - 1):
461+
kwargs = {"input_pos": input_pos, "cache_lane": lane}
461462
# sendrecv between last and first ranks, only if:
462463
# first_pp_rank != last_pp_rank.
463464
if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
@@ -475,11 +476,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
475476

476477
# Run data through pipeline
477478
if pp_rank == first_pp_rank:
478-
output = decorder.step(new_token)
479+
output = decorder.step(new_token, **kwargs)
479480
elif pp_rank == last_pp_rank:
480-
output = decorder.step()
481+
output = decorder.step(**kwargs)
481482
else: # middle pp ranks
482-
decorder.step()
483+
decorder.step(**kwargs)
483484

484485
# Decode the output
485486
if pp_rank == last_pp_rank:
@@ -499,7 +500,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
499500
) # decode_results[i][0]
500501

501502
input_pos += 1
502-
model.setup_input_pos(input_pos)
503503

504504
# Display the decoding results
505505

torchchat/model.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -653,15 +653,8 @@ def distribute(self, device_mesh: DeviceMesh):
653653
ColwiseParallel(output_layouts=Replicate()),
654654
)
655655

656-
# This is a temporary solution to pass input_pos to non-0 pipeline stages
657-
# TODO: make `step()` function of dist.pipelining accept args for non-0 stages
658-
def setup_input_pos(self, input_pos: Tensor) -> None:
659-
self._input_pos = input_pos
660-
661656
def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, cache_lane: int = 1) -> Tensor:
662657
assert self.freqs_cis is not None, "Caches must be initialized first"
663-
# TODO: find a better way to pass input_pos to non-0 pipeline stages
664-
input_pos = input_pos if input_pos is not None else self._input_pos
665658
mask = self.causal_mask[None, None, input_pos]
666659
freqs_cis = self.freqs_cis[input_pos]
667660
if self.tok_embeddings:

0 commit comments

Comments (0)