This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 1ea7960

enable multi-batch prefill
1 parent 50d451a commit 1ea7960

File tree

1 file changed: +104 −50 lines

dist_run.py

Lines changed: 104 additions & 50 deletions
@@ -7,34 +7,23 @@
 import os
 from pathlib import Path
 from types import SimpleNamespace
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 # Run command:
 # torchrun --nproc-per-node 4 dist_run.py
 import torch
 import torch.distributed as dist
 from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
 
-
 from distributed.logging_utils import SingletonLogger
-
 # TODO - these are not distributed specific, consider moving to new package
-from distributed.safetensor_utils import (
-    get_hf_config_file,
-    get_hf_weight_map_and_path,
-    load_safetensor_weights,
-)
-
-from distributed.utils import (
-    Color as color,
-    GPUMemoryMonitor,
-    get_module_size,
-    get_num_params,
-    bytes_to_readable,
-    TrackTime,
-    CUDATrackTime,
-)
-
+from distributed.safetensor_utils import (get_hf_config_file,
+                                          get_hf_weight_map_and_path,
+                                          load_safetensor_weights)
+from distributed.utils import Color as color
+from distributed.utils import (GPUMemoryMonitor, TrackTime,
+                               bytes_to_readable, get_module_size,
+                               get_num_params)
 from distributed.verification_utils import find_cpu_tensors
 from torchchat.cli.builder import TokenizerArgs, _initialize_tokenizer
 from torchchat.model import ModelArgs, Transformer
@@ -123,28 +112,86 @@ def _load_model_weights(stage_module, hf_model_name, device, model_config):
         raise ValueError(f"Missing {num_missing_weights} weights")
 
 
-def _encode_string(
-    string: str, tokenizer, bos: bool = True, device: str = "cuda", dtype=torch.int64
-) -> torch.Tensor:
-    """Encode a prompt string into a tensor of token ids."""
-    tokens = tokenizer.encode(string)
-    if bos:
-        tokens = [tokenizer.bos_id()] + tokens
-    return torch.tensor(tokens, dtype=dtype, device=device)
-
-
-def _create_padded_prompt(
-    input_ids: torch.Tensor, tokenizer, seqlen: int, start_pos: int, device: str
-) -> Tuple[torch.Tensor, int]:
-    """Create a padded tensor for the encoded input prompt. Returns the padded tensor and the prompt length."""
-    prompt_len = input_ids.size(0)
-    max_new_tokens = min(seqlen, seqlen - start_pos - prompt_len)
-    token_buffer_size = prompt_len + max_new_tokens
-    seq = torch.full(
-        (1, token_buffer_size), tokenizer.eos_id(), dtype=torch.int64, device=device
+def _encode_strings(
+    strings: List[str],
+    tokenizer,
+    bos: bool = True,
+    device: str = "cuda",
+    dtype=torch.int64,
+) -> List[torch.Tensor]:
+    """Encode a list of prompt strings into a list of tensor token ids."""
+    encoded_list = []
+    for string in strings:
+        tokens = tokenizer.encode(string)
+        if bos:
+            tokens = [tokenizer.bos_id()] + tokens
+        encoded_list.append(torch.tensor(tokens, dtype=dtype, device=device))
+    return encoded_list
+
+
+def _create_padded_prompts(
+    input_ids_list: List[torch.Tensor],
+    tokenizer,
+    seqlen: int,
+    start_pos: int,
+    device: str,
+    pad_token_id: Optional[int] = None,
+) -> Tuple[torch.Tensor, List[int]]:
+    """
+    Create a padded tensor for multiple encoded input prompts.
+
+    Returns:
+        Tuple[torch.Tensor, List[int]]: A tuple containing the padded tensor and a list of prompt lengths.
+    """
+    pad_token_id = pad_token_id if pad_token_id is not None else tokenizer.eos_id()
+
+    # Find the maximum prompt length
+    max_prompt_len = max(ids.size(0) for ids in input_ids_list)
+
+    # Calculate the buffer size
+    max_new_tokens = max(0, min(seqlen - start_pos, seqlen - max_prompt_len))
+    token_buffer_size = max_prompt_len + max_new_tokens
+
+    # Create the padded batch tensor
+    batch_size = len(input_ids_list)
+    batch_seq = torch.full(
+        (batch_size, token_buffer_size), pad_token_id, dtype=torch.int64, device=device
     )
-    seq[0, :prompt_len] = input_ids
-    return seq, prompt_len
+
+    prompt_lengths = []
+    for i, input_ids in enumerate(input_ids_list):
+        prompt_len = input_ids.size(0)
+        batch_seq[i, :prompt_len] = input_ids
+        prompt_lengths.append(prompt_len)
+
+    return batch_seq, prompt_lengths
+
+
+def _batch_decode_next_tokens(
+    output: torch.Tensor,
+    prompt_lengths: List[int],
+    tokenizer,
+) -> List[Tuple[int, str]]:
+    """
+    Decode the next token for each prompt in the batch.
+
+    Returns:
+        List[Tuple[int, str]]: List of tuples containing the next token id and its
+        decoded string for each prompt in the batch.
+    """
+    batch_size = output.shape[0]
+    results = []
+
+    for i in range(batch_size):
+        next_token_logits = output[i, prompt_lengths[i] - 1, :]
+
+        # Argmax (deterministic) TODO: add temperature
+        next_token = torch.argmax(next_token_logits, dim=-1)
+
+        next_token_decoded = tokenizer.decode([next_token.item()])
+        results.append((next_token.item(), next_token_decoded))
+
+    return results
 
 
 def _cleanup():
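The three helpers added above depend only on torch and a tokenizer object, so their behavior can be exercised outside of torchrun. Below is a minimal sketch (not part of this commit), assuming the new functions are in scope and substituting a toy tokenizer and random logits for the real tokenizer and model output; both stand-ins are illustrative assumptions.

from typing import List

import torch


class ToyTokenizer:
    """Illustrative stand-in for the real tokenizer; ids and decoding are fake."""

    def encode(self, s: str) -> List[int]:
        # map each whitespace-separated word to a small fake token id
        return [3 + len(w) for w in s.split()]

    def decode(self, ids: List[int]) -> str:
        return " ".join(f"<tok{i}>" for i in ids)

    def bos_id(self) -> int:
        return 1

    def eos_id(self) -> int:
        return 2


tok = ToyTokenizer()
prompts = ["What is the capital of France?", "What is snow?"]

# encode each prompt, pack them into one padded (batch, seq) buffer,
# then pick the next token for each row from (fake) logits
input_ids = _encode_strings(prompts, tok, bos=True, device="cpu")
padded, prompt_lengths = _create_padded_prompts(
    input_ids, tok, seqlen=32, start_pos=0, device="cpu"
)
fake_logits = torch.randn(padded.size(0), padded.size(1), 2048)  # stand-in for model output
print(padded.shape, prompt_lengths)  # torch.Size([2, 32]) [7, 4]
print(_batch_decode_next_tokens(fake_logits, prompt_lengths, tok))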
@@ -209,7 +256,7 @@ def main():
     model.distribute(tp_mesh)
     # logger.info(f"Model: {model}")
 
-    mbs = 1  # number of micro-batches TODO: move to multibatch
+    mbs = 4  # number of micro-batches
     mb_size = 1  # micro-batch size
     batch_size = mbs * mb_size  # total batch size
 
@@ -260,21 +307,28 @@ def main():
     if len(cpu_tensors) > 0:
         raise ValueError("Found cpu tensors in stage")
 
-    prompt = "What is the capital of France?"
+    prompt = [
+        "What is the capital of France?",
+        "What is snow?",
+        "What is your name?",
+        "What is the capital of Japan?",
+    ]
     start_pos = 0
 
     # encode the prompt
-    input_ids = _encode_string(
+    input_ids = _encode_strings(
         prompt, tokenizer, bos=True, device=device, dtype=torch.int64
     )
     logger.info(f"{input_ids[0:8]=}")
 
     # create a padded tensor for the input prompt
-    padded_sequence, prompt_len = _create_padded_prompt(
+    padded_sequence, prompt_lengths = _create_padded_prompts(
         input_ids, tokenizer, seqlen, start_pos, device
     )
-    logger.info(f"{prompt_len=}")
-    logger.info(f"{padded_sequence[0, :prompt_len+1]=}")
+    logger.info(f"{prompt_lengths=}")
+    logger.info(f"first prompt {padded_sequence[0, :prompt_lengths[0]+1]=}")
+    if len(prompt_lengths) > 1:
+        logger.info(f"second prompt {padded_sequence[1, :prompt_lengths[1]+1]=}")
 
     schedule = ScheduleGPipe(stage, mbs)
     logger.info(f"Created schedule: {schedule}")
@@ -287,12 +341,12 @@ def main():
 
     # Decoding
     if pp_rank == pp_degree - 1 and tp_rank == 0:
-        next_token_logits = output[:, prompt_len - 1, :]
-        next_token = torch.argmax(next_token_logits, dim=-1)
-        next_token_decoded = tokenizer.decode((next_token.tolist()))
+        decode_results = _batch_decode_next_tokens(
+            output=output, prompt_lengths=prompt_lengths, tokenizer=tokenizer
+        )
 
         logger.info(
-            f"\n\n{color.green} Prefill response ====>>>> {color.blue} {next_token_decoded=}, {next_token}\n{color.reset}"
+            f"\n\n{color.green} Prefill responses ====>>>> {color.blue} {decode_results=} \n{color.reset}"
         )
 
         # show peak memory stats for this stage
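As a cross-check on the micro-batch settings changed above: mbs = 4 with mb_size = 1 gives batch_size = 4, one row per prompt in the new prompt list, and a GPipe-style schedule consumes that padded batch as mbs chunks along dim 0. A small sketch of that arithmetic, with an arbitrary illustrative sequence length (this code is an assumption for illustration, not from the commit):

import torch

mbs, mb_size, seqlen = 4, 1, 16  # seqlen chosen arbitrarily for the sketch
batch_size = mbs * mb_size       # 4 rows, matching the four prompts

padded_sequence = torch.zeros(batch_size, seqlen, dtype=torch.int64)
# conceptually how a GPipe schedule divides the batch, one chunk per pipeline step
micro_batches = torch.chunk(padded_sequence, mbs, dim=0)

assert len(micro_batches) == mbs
assert micro_batches[0].shape == (mb_size, seqlen)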
