Merged
127 changes: 81 additions & 46 deletions dist_run.py
@@ -16,6 +16,8 @@

import torch
import torch.distributed as dist
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs

from torchchat.distributed.logging_utils import SingletonLogger

@@ -33,8 +35,6 @@
get_num_params,
GPUMemoryMonitor,
)
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs
from torchchat.model import ModelArgs, Transformer, TransformerArgs
from torchchat.utils.build_utils import set_precision

@@ -189,23 +189,49 @@ def _create_padded_prompts(

def _batch_decode_next_tokens(
output: torch.Tensor,
pos: int,
pos: List[int],
step: int = -1,
temperature: float = 1.0,
topk: int = 10,
) -> torch.Tensor:
"""
Decode the next token for each prompt in the batch.
Decode the next token for each prompt in the batch. Adds temperature option for non-deterministic decoding.
Comment (Contributor): I wonder if torchchat's generate also has the temperature option? Shall we think about how to connect with generate in next steps?
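For discussion, one possible shape for that connection (a hypothetical sketch only: the generate signature and the run_prefill/run_decode helpers are assumed names, not torchchat's actual API):

```python
# Hypothetical sketch: thread temperature/topk from a generate() entry point
# down to _batch_decode_next_tokens. run_prefill/run_decode are placeholders.
def generate(tokens, prompt_lengths, max_new_tokens, temperature=1.0, topk=10):
    res = []
    output = run_prefill(tokens)  # placeholder helper
    next_token = _batch_decode_next_tokens(
        output, prompt_lengths, temperature=temperature, topk=topk
    )
    res.append(next_token)
    for step in range(max_new_tokens - 1):
        output = run_decode(next_token)  # placeholder helper
        next_token = _batch_decode_next_tokens(
            output, prompt_lengths, step, temperature=temperature, topk=topk
        )
        res.append(next_token)
    return res
```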


Args:
output (torch.Tensor): The output tensor to decode.
pos: the position of the `output` to decode in the sequence length dimension.
pos (List[int]): The positions of the `output` to decode in the sequence length dimension.
step (int): Step indicator. If -1, use positions from `pos`. Otherwise, use the first token.
temperature (float): Sampling temperature for non-deterministic decoding.

Returns:
Decoded token ids.
torch.Tensor: Decoded token ids.
"""
# Take the next token logits for each prompt
next_token_logits = output[:, pos, :]
# Argmax (deterministic) TODO: add temperature
next_token = torch.argmax(next_token_logits, dim=-1)
# Token ids in int tensor form
return next_token
batch_size, seq_len, vocab_size = output.shape

if step != -1:
Comment (kwen2501, Contributor, Oct 1, 2024): nit: can step == -1 be represented by pos = [] or pos = [0, 0, ...]? (saving one argument)
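A minimal sketch of that suggestion, assuming decode-step callers can pass pos = [0] * batch_size so the extra argument disappears:

```python
from typing import List

import torch

# Sketch: fold the step flag into pos. Decode steps pass [0] * batch_size
# (seqlen is 1 there); prefill passes the prompt-end positions directly.
def _select_next_token_logits(output: torch.Tensor, pos: List[int]) -> torch.Tensor:
    batch_size = output.shape[0]
    rows = torch.arange(batch_size, device=output.device)
    cols = torch.tensor(pos, device=output.device)
    return output[rows, cols]  # (batch_size, vocab_size)
```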

next_token_logits = output[:, 0, :]
else:
# get the logits for each prompt at the specified positions
next_token_logits = output[torch.arange(batch_size), torch.tensor(pos) - 1]
Comment (Contributor): Hmm, why "-1"? From this function's perspective, if the caller has given the position, should it just faithfully decode that position? (I understand that this runs correctly if the callsite provides prompt_length instead of prompt_length - 1.)
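For illustration, a faithful-position variant would move the subtraction to the callsite (a sketch, not the PR's code):

```python
# Inside the function: decode exactly the positions the caller gives.
next_token_logits = output[torch.arange(batch_size), torch.tensor(pos)]

# At the prefill callsite: make the "last prompt token" offset explicit.
new_token = _batch_decode_next_tokens(output, [l - 1 for l in prompt_lengths])
```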


if temperature != 1.0:
next_token_logits = next_token_logits / temperature
Comment on lines +217 to +218 (Contributor): nit: can we do the division unconditionally?
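A sketch of that simplification; dividing by a temperature of 1.0 is a no-op, so the guard can be dropped:

```python
# Unconditional division: same behavior, one less branch.
next_token_logits = next_token_logits / temperature
```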


    # Use top-k sampling if temperature is not 1.0; otherwise use argmax
if temperature != 1.0:
top_k = min(topk, vocab_size) # Ensure top-k is not greater than vocab size
top_k_logits, top_k_indices = torch.topk(next_token_logits, k=top_k, dim=-1)
probs = torch.softmax(top_k_logits, dim=-1)
next_token_indices = torch.multinomial(probs, num_samples=1).squeeze(-1)
next_tokens = top_k_indices.gather(
-1, next_token_indices.unsqueeze(-1)
).squeeze(-1)
Comment on lines +222 to +228 (Contributor): nit: do you mind adding more comments here for the multinomial, gather, squeeze and unsqueeze ops?
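A sketch of the same block with the requested comments added:

```python
# Keep only the k largest logits per prompt; both tensors are (batch_size, k).
top_k_logits, top_k_indices = torch.topk(next_token_logits, k=top_k, dim=-1)
# Normalize the k logits into a probability distribution per prompt.
probs = torch.softmax(top_k_logits, dim=-1)
# multinomial draws one sample per row, returning indices *into the top-k
# list* with shape (batch_size, 1); squeeze(-1) drops the samples dimension.
next_token_indices = torch.multinomial(probs, num_samples=1).squeeze(-1)
# gather maps those top-k positions back to vocabulary ids: unsqueeze(-1)
# restores the column shape gather expects, and the final squeeze(-1)
# leaves a flat (batch_size,) tensor of token ids.
next_tokens = top_k_indices.gather(-1, next_token_indices.unsqueeze(-1)).squeeze(-1)
```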

else:
# Argmax (deterministic)
next_tokens = torch.argmax(next_token_logits, dim=-1)

logger.info(f"{color.yellow}Next tokens: {color.blue}{next_tokens}{color.reset}")
return next_tokens


def _update_padded_sequence(
@@ -218,11 +244,32 @@
# logger.info(f"updated prompt {i} with new token {new_token[i, 0]}")


# Decode token id into string and print it
def _decode_in_flight(token, tokenizer, tp_rank):
"""decode token ids for all prompts in the batch and log them"""
token_str = tokenizer.decode(token.tolist())
# print the token string on tp rank 0
if tp_rank == 0:
logger.info(
f"{color.green} responses ====>>>> "
f"{color.blue} {token_str} {color.reset}"
)


def _cleanup():
dist.barrier()
dist.destroy_process_group()


prompt = [
"What is Snow?",
"Who is Santa Claus?",
"Where does Santa live?",
# "Who is Abraham Lincoln?",
# "How are models trained?",
]


def main(args):
model_name = args.model_name
pp_degree = args.pp
@@ -293,7 +340,7 @@ def main(args):
# Batch size. Since we push batches dynamically through the pipeline rather
# than chunking them, this is effectively micro-batch size in pipeline
# sense. Thus it is interchangeable with micro-batch size below.
batch_size = 4
batch_size = len(prompt)
seqlen_prefill = 1024 # sequence length
dim = 4096 # embedding dimension

@@ -331,7 +378,9 @@

# Helper function to get example inputs and outputs for the stages.
def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
mb_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), device=device)
mb_ids = torch.randint(
0, config.vocab_size, (batch_size, seqlen), device=device
)
activation = torch.rand(
batch_size, seqlen, dim, device=device, dtype=model_dtype
)
@@ -362,13 +411,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
# pipelining effect.
prefiller = ScheduleGPipe(prefill_stage, 1)

prompt = [
"What is a computer?",
"Where does Santa live?",
"Who is Abraham Lincoln?",
"How are models trained?",
]

start_pos = 0

# Need these global ids due to the API definition of dist.send and recv
@@ -384,10 +426,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
padded_sequence, prompt_lengths = _create_padded_prompts(
input_ids, tokenizer, seqlen_prefill, start_pos, device
)
# TODO: figure out how to set input_pos for each prompt in the batch then we
# can remove this limitation.
s = set(prompt_lengths)
assert len(s) == 1, f"prompt_lengths should be the same, got {s}"

# Need these global ids due to the API definition of dist.send and recv
first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
@@ -396,6 +434,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
# New token generated each iteration
# need a row dimension for each prompt in the batch
new_token = torch.zeros(batch_size, 1, device=device, dtype=torch.int64)
logger.info(f"{color.green}{new_token.shape=}, {new_token=}{color.reset}")
Comment (Contributor): nit: for debugging only?

# Store the generated tokens
res = []

@@ -416,23 +455,13 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
f"{color.green}Prefilling time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
)

# Decode token id into string and print it
def decode_in_flight(token):
# Make a 2D tensor with ids on row dimension
unsqueezed = torch.unsqueeze(token, 1)
token_str = tokenizer.decode(unsqueezed.tolist())
if tp_rank == 0:
logger.info(
f"{color.green} responses ====>>>> "
f"{color.blue} {token_str} {color.reset}"
)

# Decode the output -- first generated token
if pp_rank == last_pp_rank:
new_token = _batch_decode_next_tokens(output, prompt_lengths[0] - 1)
logger.info(f"{color.green}Decoding...{prompt_lengths=}{color.reset}")
new_token = _batch_decode_next_tokens(output, prompt_lengths)
res.append(new_token)
if not args.disable_in_flight_decode:
decode_in_flight(new_token)
_decode_in_flight(new_token, tokenizer, tp_rank)
Comment on lines 463 to +464 (Contributor): nit: put tp_rank into the if condition?
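A sketch of that suggestion:

```python
# Fold the rank check into the condition so non-zero TP ranks skip the call.
if not args.disable_in_flight_decode and tp_rank == 0:
    _decode_in_flight(new_token, tokenizer, tp_rank)
```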


# seqlen = 1 now
seqlen_decode = 1
@@ -482,10 +511,11 @@ def decode_in_flight(token):

# Decode the output
if pp_rank == last_pp_rank:
new_token = _batch_decode_next_tokens(output, 0)
# logger.info(f"{color.red}Decoding...{output.shape=}{color.reset}")
Comment (Contributor): nit: remove log?

new_token = _batch_decode_next_tokens(output, prompt_lengths, step)
res.append(new_token)
if not args.disable_in_flight_decode:
decode_in_flight(new_token)
_decode_in_flight(new_token, tokenizer, tp_rank)

# Increment input position
input_pos += 1
@@ -499,12 +529,17 @@ def decode_in_flight(token):
# output formatted response via last pp group and tp rank 0
if pp_rank == last_pp_rank and tp_rank == 0:
# `res` is a list of tensors, each being a batch of generated token ids
res = torch.stack(res, dim=1)
res_list = res.tolist()
response = tokenizer.decode(res_list)
for i in range(len(response)):
logger.info(f"Prompt: {color.green}{prompt[i]} {color.reset}")
logger.info(f"Response: {color.red}{response[i]} {color.reset}")

res_stacked = torch.stack(res, dim=1)
res_list = res_stacked.tolist()

        # Decode the output as a comprehension instead of a loop
responses = [tokenizer.decode(sequence) for sequence in res_list]
Comment on lines +536 to +537 (Contributor): Hmm, did the previous code not work in case of variable length? Just curious.
> response = tokenizer.decode(res_list)
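For context, a sketch of the data flow at this point (shapes inferred from the loop above; whether decode() also accepts a whole batch is tokenizer-backend dependent):

```python
# res is a list of T tensors, each (batch_size,): one token per prompt per step.
res_stacked = torch.stack(res, dim=1)   # (batch_size, T)
res_list = res_stacked.tolist()         # batch_size lists of T token ids
# Decoding per sequence keeps each prompt's tokens separate regardless of
# whether the tokenizer backend supports batched (list-of-lists) decode.
responses = [tokenizer.decode(seq) for seq in res_list]
```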


# Show prompts and responses
for prompt_text, response_text in zip(prompt, responses):
logger.info(f"Prompt: {color.green}{prompt_text} {color.reset}")
logger.info(f"Response: {color.red}{response_text} {color.reset}")

# Cleanup
_cleanup()