145 changes: 60 additions & 85 deletions dist_run.py
@@ -135,8 +135,8 @@ def _load_model_weights(stage_module, distribution, device, model_config):
def _encode_strings(
    strings: List[str],
    tokenizer,
-    bos: bool = True,
-    device: torch.device = "cuda:0",
+    bos: bool,
+    device: torch.device,
    dtype=torch.int64,
) -> List[torch.Tensor]:
    """Encode a list of prompt strings into a list of tensor token ids."""
@@ -216,13 +216,13 @@ def _batch_decode_next_tokens(

def _update_padded_sequence(
    padded_sequence: torch.Tensor,
-    x_recv: torch.Tensor,
-    res,
+    new_token: torch.Tensor,
    prompt_lengths: List[int],
) -> None:
    # TODO: this is a hacky way to update the padded sequence: when there is
    # more than one prompt, the for loop and the assignment is incompatible.
    for i in range(len(prompt_lengths)):
-        prompt_lengths[i] += 1
-        padded_sequence[i, prompt_lengths[i] - 1] = x_recv
+        padded_sequence[i, prompt_lengths[i]] = new_token
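For clarity, a minimal standalone sketch of what the reworked helper now does (buffer sizes and values are assumed for illustration only, not taken from this PR):

    import torch

    # Two prompts padded into one batch; prompt_lengths[i] is the index of the
    # next free slot in row i.
    padded_sequence = torch.zeros(2, 8, dtype=torch.int64)
    prompt_lengths = [3, 5]
    new_token = torch.tensor([42], dtype=torch.int64)

    # Same loop body as _update_padded_sequence: write the newly generated
    # token at the current end of each prompt.
    for i in range(len(prompt_lengths)):
        padded_sequence[i, prompt_lengths[i]] = new_token

    # Incrementing prompt_lengths now happens in the caller
    # (see the decode loop in main below).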


def _cleanup():
@@ -267,19 +267,15 @@ def main(args):
    device_mesh = _create_device_mesh(mesh_dimensions)
    tp_mesh = device_mesh["tp"]
    pp_mesh = device_mesh["pp"]
-    logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}")

    tp_rank = tp_mesh.get_local_rank()
    pp_rank = pp_mesh.get_local_rank()
    tp_group = tp_mesh.get_group()
    pp_group = pp_mesh.get_group()

-    logger.info(f"review: {pp_group=}, {tp_group= }")
-
+    logger.info(f"Created device mesh: {device_mesh}\n {tp_mesh=}, {pp_mesh=}\n")
    # TODO - this assumes 1D mesh, need to update for 2D+ mesh
-    pp_group_size = pp_mesh.size()
-    tp_group_size = tp_mesh.size()
-
-    logger.info(f"pp_group_size: {pp_group_size}, tp_group_size: {tp_group_size}")
+    pp_group_size = pp_group.size()
+    tp_group_size = tp_group.size()
+    logger.info(f"{pp_group_size=}, {tp_group_size=}")

    # Assuming same number of GPUs per node
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
Expand Down Expand Up @@ -316,7 +312,7 @@ def main(args):
logger.info(f"Loading weights for {pp_rank=} on {device=}")

with CUDATrackTime() as timer:
_load_model_weights(model, hf_model_name, device=device, model_config=config)
_load_model_weights(model, distribution, device=device, model_config=config)

logger.info(
f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
Expand All @@ -327,7 +323,7 @@ def main(args):
    stage_size_formatted = bytes_to_readable(stage_size)
    stage_num_params = get_num_params(model)
    logger.info(
-        f"Stage {rank} has {color.blue}{stage_num_params} params{color.reset}, Size: {color.blue}{stage_size_formatted}{color.reset}\n"
+        f"Stage {rank} has {color.blue}{stage_num_params} params{color.reset}, Size: {color.blue}{stage_size_formatted}{color.reset}"
    )

    # Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen

@@ -342,15 +338,9 @@ def main(args):
        pp_degree,
        device,
        input_args=(example_args,),
-        group=pp_mesh.get_group(),
+        group=pp_group,
    )

-    # this check confirms that there are no cpu tensors in the model..we expect this to be true.
-    cpu_tensors = find_cpu_tensors(stage.submod)
-    # logger.info(f"Found {len(cpu_tensors)} cpu tensors: {cpu_tensors}")
-    if len(cpu_tensors) > 0:
-        raise ValueError("Found cpu tensors in stage")

    prompt = [
        "What is snow?",
    ]

@@ -374,7 +364,6 @@ def main(args):
    ]
    """

-
    start_pos = 0

    # encode the prompt

@@ -388,88 +377,74 @@ def main(args):
        input_ids, tokenizer, seqlen, start_pos, device
    )
    logger.info(f"{prompt_lengths=}")
-    logger.info(f"first prompt {padded_sequence[0, :prompt_lengths[0]+1]=}")
-    if len(prompt_lengths) > 1:
-        logger.info(f"second prompt {padded_sequence[1, :prompt_lengths[1]+1]=}")

    # create schedule
    schedule = ScheduleGPipe(stage, mbs)
    logger.info(f"Created schedule: {schedule}")

-    # with CUDATrackTime() as timer:
-    first_pp_group = 0
-    last_pp_group = pp_group_size - 1
-
-    x_recv = torch.zeros(1, device=device, dtype=torch.int64)
-    logger.info(f"{x_recv.shape=}")
+    first_pp_rank = 0
+    last_pp_rank = pp_group_size - 1

-    last_global_rank = world_size - 1
+    # New token generated each iteration
+    new_token = torch.zeros(1, device=device, dtype=torch.int64)
    res = []
-    dst = None
-    src = None
-
-    if pp_rank == last_pp_group:
-        dst = dist.get_global_rank(pp_group, 0)
-    elif pp_rank == 0:
-        src = dist.get_global_rank(pp_group, last_pp_group)

-    # Decoding
    num_tokens = 40

+    # Decoding
    with torch.no_grad():
        for step in range(num_tokens):
-            # first
-            if pp_rank == 0:
-                schedule.step(padded_sequence)
-                # only receive if not last step
-                if step < num_tokens - 1:
-                    dist.recv(
-                        x_recv,
-                        src,
-                        group=pp_group,
-                    )
-                    _update_padded_sequence(
-                        padded_sequence, x_recv, res, prompt_lengths
-                    )
-
-            # last
-            elif pp_rank == last_pp_group:
+            # Run data through pipeline
+            if pp_rank == first_pp_rank:
+                output = schedule.step(padded_sequence)
+            elif pp_rank == last_pp_rank:
                output = schedule.step()
-                # need to decode the output
+            else:  # middle pp ranks
+                schedule.step()

+            # Decode the output
+            if pp_rank == last_pp_rank:
                decode_results = _batch_decode_next_tokens(
                    output=output, prompt_lengths=prompt_lengths, tokenizer=tokenizer
                )
                if tp_rank == 0:
                    logger.info(
-                        f"\n\n{color.green} {'Prefill' if step == 0 else '* Decode *'} responses ====>>>> {color.blue} {decode_results=} \n{color.reset}"
+                        f"{color.green} {'Prefill' if step == 0 else '* Decode *'} "
+                        f"responses ====>>>> {color.blue} {decode_results=}{color.reset}"
                    )

-                next_token = torch.tensor([decode_results[0][0]], device=device)
+                # decode results returns both token_id (int) and token_str (readable), hence [0] and [1]
+                new_token = torch.tensor([decode_results[0][0]], device=device)
                res.append(decode_results[0][1])

-                # increment prompt lengths for next token
-                for i in range(len(prompt_lengths)):
-                    prompt_lengths[i] += 1
-                    # logger.info(
-                    #     f"output review {prompt_lengths[i]=}, {padded_sequence[i, prompt_lengths[i]-1]=}"
-                    # )
-
-                # only send if not last step
-                if step < (num_tokens - 1):
-                    dist.send(
-                        next_token,
-                        dst,
-                        pp_group,
-                    )
+            # sendrecv between last and first ranks, only if:
+            # first_pp_rank != last_pp_rank.
+            if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
+                dist.send(
+                    new_token,
+                    dst=dist.get_global_rank(pp_group, first_pp_rank),
Contributor (inline review comment):
dst and src shouldn't be getting recreated every iter, since they don't change on a per-iter basis. This is why I had moved them out of the loop previously. Not sure how expensive the dist.get_global_rank call is, but no need IMO to be calling it over and over here.

Contributor (PR author):
Makes sense. Will upload a PR to improve it. Thanks!

+                    group=pp_group,
+                )
+            elif pp_rank == first_pp_rank and pp_rank != last_pp_rank:
+                dist.recv(
+                    new_token,
+                    src=dist.get_global_rank(pp_group, last_pp_rank),
Contributor (inline review comment):
Same as above with dst: why do we call this API over and over to get src within the loop instead of once outside the loop?
+                    group=pp_group,
+                )

-            # middle pp ranks
-            else:
-                schedule.step()
+            # Update input sequence with new token
+            if pp_rank == first_pp_rank:
+                _update_padded_sequence(
+                    padded_sequence, new_token, prompt_lengths
+                )

+            # increment prompt lengths for next token
+            for i in range(len(prompt_lengths)):
+                prompt_lengths[i] += 1

    # output formatted response via last pp group and tp rank 0
-    if pp_rank == last_pp_group and tp_rank == 0:
-        logger.info(f"\nPrompt:{color.green} {prompt[0]} {color.reset}")
-        formatted_response = "".join(res)
-        logger.info(f"$$$$$$ {color.blue}{formatted_response}\n{color.reset} $$$$$")
+    if pp_rank == last_pp_rank and tp_rank == 0:
+        logger.info(f"Prompt:{color.green} {prompt[0]} {color.reset}")
+        formatted_response = " ".join(res)
+        logger.info(f"$$$$$$ {color.blue}{formatted_response} {color.reset} $$$$$")

    logger.info(
        f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}"