This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Merged
dist_run.py: 190 changes (112 additions, 78 deletions)
@@ -187,8 +187,8 @@ def _create_padded_prompts(

def _batch_decode_next_tokens(
output: torch.Tensor,
prompt_lengths: List[int],
tokenizer,
prompt_lengths: Optional[List[int]] = None,
) -> List[Tuple[int, str]]:
"""
Decode the next token for each prompt in the batch.
@@ -201,7 +201,8 @@ def _batch_decode_next_tokens(
results = []

for i in range(batch_size):
next_token_logits = output[i, prompt_lengths[i] - 1, :]
pos = prompt_lengths[i] - 1 if prompt_lengths is not None else 0
next_token_logits = output[i, pos, :]

# Argmax (deterministic) TODO: add temperature
next_token = torch.argmax(next_token_logits, dim=-1)
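The "TODO: add temperature" above refers to replacing the deterministic argmax with temperature-scaled sampling. A minimal sketch of that direction follows; it is not part of this PR, and the temperature parameter and greedy fallback are assumptions:

import torch

# Hedged sketch, not part of this PR: temperature-scaled sampling as an
# alternative to deterministic argmax. `temperature` is an assumed parameter.
def sample_next_token(next_token_logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    if temperature == 0.0:
        # A temperature of 0 is treated as greedy decoding.
        return torch.argmax(next_token_logits, dim=-1)
    probs = torch.softmax(next_token_logits / temperature, dim=-1)
    # multinomial returns a tensor of shape (1,); squeeze it to a scalar token id.
    return torch.multinomial(probs, num_samples=1).squeeze(-1)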
@@ -276,6 +277,10 @@ def main(args):
tp_group_size = tp_group.size()
logger.info(f"{pp_group_size=}, {tp_group_size=}")

# Convenience variables
first_pp_rank = 0
last_pp_rank = pp_group_size - 1

# Assuming same number of GPUs per node
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")

@@ -293,29 +298,23 @@
logger.info(f"Model: {model}")

mbs = 1 # number of micro-batches
mb_size = 5 # micro-batch size
mb_size = 4 # micro-batch size
batch_size = mbs * mb_size # total batch size

seqlen = 4096 # sequence length
seqlen_prefill = 1024 # sequence length
dim = 4096 # embedding dimension

# Setup KV caches (after model distribution)
# TODO: the setting below only works for 1 micro-batch case. To support
# multiple micro-batches, we need the KV cache in the model to be aware of
# the number of micro-batches and the current micro-batch index.
model.setup_caches(mb_size, seqlen)

mb_ids = torch.randint(0, config.vocab_size, (mb_size, seqlen), device=device)
activation = torch.rand(
mb_size, seqlen, dim, device=device, dtype=model_dtype
)
example_args = mb_ids if pp_rank == 0 else activation
model.setup_caches(mb_size, seqlen_prefill)

# Load weights
logger.info(f"Loading weights for {pp_rank=} on {device=}")

with CUDATrackTime() as timer:
_load_model_weights(model, distribution, device=device, model_config=config)
model.to(device)

logger.info(
f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
@@ -330,53 +329,47 @@
)

# Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen
input_pos = torch.arange(seqlen, device=device)
input_pos = torch.arange(seqlen_prefill, device=device)
model.setup_input_pos(input_pos)
model.eval()

logger.info(f"Creating pipeline stage {pp_rank=}, {pp_degree=}")
stage = PipelineStage(
# Helper function to get example inputs and outputs for the stages.
def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
mb_ids = torch.randint(0, config.vocab_size, (mb_size, seqlen), device=device)
activation = torch.rand(
mb_size, seqlen, dim, device=device, dtype=model_dtype
)
logits = torch.rand(
mb_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
)
example_inputs = (mb_ids if pp_rank == first_pp_rank else activation,)
example_outputs = (logits if pp_rank == last_pp_rank else activation,)
return example_inputs, example_outputs

# Create prefill stage
logger.info(f"Creating pipeline stage for prefill {pp_rank=}, {pp_degree=}")
example_inputs, example_outputs = get_example_ins_outs(seqlen_prefill)
prefill_stage = PipelineStage(
model,
pp_rank,
pp_degree,
device,
input_args=(example_args,),
input_args=example_inputs,
output_args=example_outputs,
group=pp_group,
)
# create schedule
prefill_schedule = ScheduleGPipe(prefill_stage, mbs)

prompt = [
"What is snow?",
"Where does Santa Claus live?",
"What is PyTorch?",
"Write a poem about the beauty of the night sky.",
"What is the capital of France, Germany and Switzerland?",
]

"""
"What is the capital of France?",
"What is your name?",
"What is the capital of Japan?",
"When is Christmas?",
"Where does Santa Claus live?",
"What is the capital of the United States?",
"What is the capital of China?",
"What is the capital of Russia?",
"What is PyTorch?",
"What is the capital of India?",
"What is an LLM?",
"What is the capital of Brazil?",
"What is the capital of Mexico?",
"What is the capital of Argentina?",
"What is the capital of Canada?",
"What is a computer?",
"Where does Santa live?",
"Who is Abraham Lincoln?",
"How are models trained?",
]
"""

start_pos = 0

# pipeline comms setup
first_pp_rank = 0
last_pp_rank = pp_group_size - 1

# Need these global ids due to the API definition of dist.send and recv
first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
@@ -388,15 +381,14 @@ def main(args):

# create a padded tensor for the input prompt
padded_sequence, prompt_lengths = _create_padded_prompts(
input_ids, tokenizer, seqlen, start_pos, device
input_ids, tokenizer, seqlen_prefill, start_pos, device
)

# create schedule
schedule = ScheduleGPipe(stage, mbs)
# TODO: figure out how to set input_pos for each prompt in the batch then we
# can remove this limitation.

[Review comment from a Contributor]: Removing this limitation is probably our most important next step.
s = set(prompt_lengths)
assert len(s) == 1, f"prompt_lengths should be the same, got {s}"
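As the review comment above notes, lifting the equal-length restriction would require per-prompt positions rather than a single shared input_pos. A rough, assumption-heavy sketch of that direction, not supported by the code in this PR (it presumes setup_input_pos and the KV cache could be extended to index positions per batch row):

# Hypothetical sketch only: one position per prompt, shape (batch_size,).
# The current model/KV-cache code does not accept per-row positions.
per_prompt_pos = torch.tensor(prompt_lengths, device=device)
# model.setup_input_pos(per_prompt_pos)  # would need per-row cache indexing to work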

# with CUDATrackTime() as timer:
first_pp_rank = 0
last_pp_rank = pp_group_size - 1
# Need these global ids due to the API definition of dist.send and recv
first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
@@ -408,25 +400,87 @@ def main(args):
res = [[] for _ in range(total_prompts)]
num_tokens = 40

# Prefill phase
# Run context input through pipeline, in 1 step
with torch.no_grad():
if pp_rank == first_pp_rank:
output = prefill_schedule.step(padded_sequence)
elif pp_rank == last_pp_rank:
output = prefill_schedule.step()
else: # middle pp ranks
prefill_schedule.step()

# Decode the output -- first generated token
if pp_rank == last_pp_rank:
decode_results = _batch_decode_next_tokens(
output=output,
tokenizer=tokenizer,
prompt_lengths=prompt_lengths,
)
for i in range(len(decode_results)):
new_token[i, 0] = torch.tensor(
[decode_results[i][0]], device=device
) # token_id in int form
if tp_rank == 0:
logger.info(
f"{color.green} {'* Prefill *'} "
f"responses ====>>>> {color.blue} {decode_results=}{color.reset}"
)

# seqlen = 1 now
seqlen_decode = 1
input_pos = torch.tensor([prompt_lengths[0]], device=device)
model.setup_input_pos(input_pos)

# Create decode stage
logger.info(f"Creating pipeline stage for decode {pp_rank=}, {pp_degree=}")
example_inputs, example_outputs = get_example_ins_outs(seqlen_decode)
decode_stage = PipelineStage(
model,
pp_rank,
pp_degree,
device,
input_args=example_inputs,
output_args=example_outputs,
group=pp_group,
)
# create schedule
decode_schedule = ScheduleGPipe(decode_stage, mbs)

# Decoding
with torch.no_grad():
for step in range(num_tokens):
for step in range(num_tokens - 1):
# sendrecv between last and first ranks, only if:
# first_pp_rank != last_pp_rank.
if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
dist.send(
new_token,
dst=first_pp_rank_global_id,
group=pp_group,
)
elif pp_rank == first_pp_rank and pp_rank != last_pp_rank:
dist.recv(
new_token,
src=last_pp_rank_global_id,
group=pp_group,
)

# Run data through pipeline
if pp_rank == first_pp_rank:
output = schedule.step(padded_sequence)
output = decode_schedule.step(new_token)
elif pp_rank == last_pp_rank:
output = schedule.step()
output = decode_schedule.step()
else: # middle pp ranks
schedule.step()
decode_schedule.step()

# Decode the output
if pp_rank == last_pp_rank:
decode_results = _batch_decode_next_tokens(
output=output, prompt_lengths=prompt_lengths, tokenizer=tokenizer
output=output, tokenizer=tokenizer
)
if tp_rank == 0:
logger.info(
f"{color.green} {'Prefill' if step == 0 else '* Decode *'} "
f"{color.green} {'* Decode *'} "
f"responses ====>>>> {color.blue} {decode_results=}{color.reset}"
)
# decode results returns both token_id (int) and token_str (readable), hence [0] and [1]
@@ -436,28 +490,8 @@ def main(args):
[decode_results[i][0]], device=device
) # decode_results[i][0]

# sendrecv between last and first ranks, only if:
# first_pp_rank != last_pp_rank.
if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
dist.send(
new_token,
dst=first_pp_rank_global_id,
group=pp_group,
)
elif pp_rank == first_pp_rank and pp_rank != last_pp_rank:
dist.recv(
new_token,
src=last_pp_rank_global_id,
group=pp_group,
)

# Update input sequence with new token
if pp_rank == first_pp_rank:
_update_padded_sequence(padded_sequence, new_token, prompt_lengths)

# increment prompt lengths for next token
for i in range(len(prompt_lengths)):
prompt_lengths[i] += 1
input_pos += 1
model.setup_input_pos(input_pos)

# Display the decoding results

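For reference, here is a self-contained illustration of the two call modes of the revised _batch_decode_next_tokens above: during prefill it is called with prompt_lengths and reads the logits at each prompt's last real token, while during decode it is called without prompt_lengths and reads position 0 of the length-1 output. The small re-implementation below is simplified for illustration and is not the repository's exact function; FakeTokenizer is a stand-in assumption for the real tokenizer.

import torch
from typing import List, Optional, Tuple

class FakeTokenizer:
    # Stand-in for the real tokenizer; only what this illustration needs.
    def decode(self, ids: List[int]) -> str:
        return f"<tok {ids[0]}>"

def batch_decode_next_tokens_sketch(
    output: torch.Tensor,                        # (batch, seq, vocab) logits
    tokenizer,
    prompt_lengths: Optional[List[int]] = None,  # given in prefill, omitted in decode
) -> List[Tuple[int, str]]:
    results = []
    for i in range(output.shape[0]):
        # Prefill: read logits at the last real prompt token.
        # Decode: the output sequence length is 1, so read position 0.
        pos = prompt_lengths[i] - 1 if prompt_lengths is not None else 0
        next_token = torch.argmax(output[i, pos, :], dim=-1)
        results.append((next_token.item(), tokenizer.decode([next_token.item()])))
    return results

tok = FakeTokenizer()
prefill_logits = torch.rand(2, 8, 32)  # batch=2, seqlen=8, vocab=32
decode_logits = torch.rand(2, 1, 32)   # batch=2, seqlen=1, vocab=32
print(batch_decode_next_tokens_sketch(prefill_logits, tok, prompt_lengths=[5, 7]))
print(batch_decode_next_tokens_sketch(decode_logits, tok))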