This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit a645f8e

lessw2020 and Jack-Khuu authored
[distributed] prefill (single and multi-prompt) and single prompt generation and decoding (#1133)
* add encode string and create_padded_input functions
* update prefill functions
* ensure 8B is default
* add typing to added functions
* ruff formatting
* enable multi-batch prefill
* decoding start
* decoding comms working
* decoding comms working next token send/receive
* first decoded token
* second decoded token
* single prompt prefill + decoding all working
* add _update_padded_sequence
* add refined output, update force_download to 3-8B
* pr_feedback, ruff formatting
* remove debug related logging, leaving it commented out as it will be reused for multi-prompt work

---------

Co-authored-by: Jack-Khuu <[email protected]>
1 parent 7a4f0d1 commit a645f8e

File tree

3 files changed: +236 -46 lines changed

  dist_run.py
  distributed/force_download.py
  run_dist.sh

dist_run.py

Lines changed: 231 additions & 41 deletions
@@ -8,14 +8,12 @@
 import os
 from pathlib import Path
 from types import SimpleNamespace
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 # Run command:
 # torchrun --nproc-per-node 4 dist_run.py
 import torch
 import torch.distributed as dist
-from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
-
 
 from distributed.logging_utils import SingletonLogger
 
@@ -25,19 +23,17 @@
     get_hf_weight_map_and_path,
     load_safetensor_weights,
 )
-
 from distributed.utils import (
+    bytes_to_readable,
     Color as color,
-    GPUMemoryMonitor,
+    CUDATrackTime,
     get_module_size,
     get_num_params,
-    bytes_to_readable,
-    TrackTime,
-    CUDATrackTime,
+    GPUMemoryMonitor,
 )
-
 from distributed.verification_utils import find_cpu_tensors
-from torchchat.cli.builder import TokenizerArgs, _initialize_tokenizer
+from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
+from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs
 from torchchat.model import ModelArgs, Transformer
 from torchchat.utils.build_utils import set_precision
 
@@ -136,6 +132,99 @@ def _load_model_weights(stage_module, distribution, device, model_config):
         raise ValueError(f"Missing {num_missing_weights} weights")
 
 
+def _encode_strings(
+    strings: List[str],
+    tokenizer,
+    bos: bool = True,
+    device: torch.device = "cuda:0",
+    dtype=torch.int64,
+) -> List[torch.Tensor]:
+    """Encode a list of prompt strings into a list of tensor token ids."""
+    encoded_list = []
+    for string in strings:
+        tokens = tokenizer.encode(string)
+        if bos:
+            tokens = [tokenizer.bos_id()] + tokens
+        encoded_list.append(torch.tensor(tokens, dtype=dtype, device=device))
+    return encoded_list
+
+
+def _create_padded_prompts(
+    input_ids_list: List[torch.Tensor],
+    tokenizer,
+    seqlen: int,
+    start_pos: int,
+    device: torch.device,
+    pad_token_id: Optional[int] = None,
+) -> Tuple[torch.Tensor, List[int]]:
+    """
+    Create a padded tensor for multiple encoded input prompts.
+
+    Returns:
+        Tuple[torch.Tensor, List[int]]: A tuple containing the padded tensor and a list of prompt lengths.
+    """
+    pad_token_id = pad_token_id if pad_token_id is not None else tokenizer.eos_id()
+
+    # Find the maximum prompt length
+    max_prompt_len = max(ids.size(0) for ids in input_ids_list)
+
+    # Calculate the buffer size
+    max_new_tokens = max(0, min(seqlen - start_pos, seqlen - max_prompt_len))
+    token_buffer_size = max_prompt_len + max_new_tokens
+
+    # Create the padded batch tensor
+    batch_size = len(input_ids_list)
+    batch_seq = torch.full(
+        (batch_size, token_buffer_size), pad_token_id, dtype=torch.int64, device=device
+    )
+
+    prompt_lengths = []
+    for i, input_ids in enumerate(input_ids_list):
+        prompt_len = input_ids.size(0)
+        batch_seq[i, :prompt_len] = input_ids
+        prompt_lengths.append(prompt_len)
+
+    return batch_seq, prompt_lengths
+
+
+def _batch_decode_next_tokens(
+    output: torch.Tensor,
+    prompt_lengths: List[int],
+    tokenizer,
+) -> List[Tuple[int, str]]:
+    """
+    Decode the next token for each prompt in the batch.
+
+    Returns:
+        List[Tuple[int, str]]: List of tuples containing the next token id and its
+        decoded string for each prompt in the batch.
+    """
+    batch_size = output.shape[0]
+    results = []
+
+    for i in range(batch_size):
+        next_token_logits = output[i, prompt_lengths[i] - 1, :]
+
+        # Argmax (deterministic) TODO: add temperature
+        next_token = torch.argmax(next_token_logits, dim=-1)
+
+        next_token_decoded = tokenizer.decode([next_token.item()])
+        results.append((next_token.item(), next_token_decoded))
+
+    return results
+
+
+def _update_padded_sequence(
+    padded_sequence: torch.Tensor,
+    x_recv: torch.Tensor,
+    res,
+    prompt_lengths: List[int],
+) -> None:
+    for i in range(len(prompt_lengths)):
+        prompt_lengths[i] += 1
+        padded_sequence[i, prompt_lengths[i] - 1] = x_recv
+
+
 def _cleanup():
     dist.barrier()
     dist.destroy_process_group()
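
Taken together, the new helpers tokenize each prompt, right-pad every prompt into one fixed-size batch buffer, and pick the next token by argmax at each prompt's last real position. Below is a minimal, self-contained sketch of the same scheme outside the pipeline; the ToyTokenizer class and the random fake_logits are illustrative stand-ins and are not part of this commit.

# Standalone sketch of the padding + greedy-decode scheme used by the helpers above.
# ToyTokenizer and fake_logits are hypothetical stand-ins for the real tokenizer/model.
from typing import List, Tuple

import torch


class ToyTokenizer:
    """Hypothetical whitespace tokenizer exposing the same small interface."""

    def bos_id(self) -> int:
        return 1

    def eos_id(self) -> int:
        return 2

    def encode(self, s: str) -> List[int]:
        # Map each word to an arbitrary id; real code uses the model tokenizer.
        return [10 + (hash(w) % 100) for w in s.split()]


def pad_prompts(ids_list: List[torch.Tensor], pad_id: int, seqlen: int) -> Tuple[torch.Tensor, List[int]]:
    # Right-pad every prompt into one (batch, seqlen) buffer, as _create_padded_prompts does.
    batch = torch.full((len(ids_list), seqlen), pad_id, dtype=torch.int64)
    lengths = []
    for i, ids in enumerate(ids_list):
        batch[i, : ids.size(0)] = ids
        lengths.append(ids.size(0))
    return batch, lengths


tok = ToyTokenizer()
prompts = ["What is snow?", "What is PyTorch?"]
encoded = [torch.tensor([tok.bos_id()] + tok.encode(p), dtype=torch.int64) for p in prompts]
padded, lengths = pad_prompts(encoded, tok.eos_id(), seqlen=16)

# Greedy selection: argmax of the logits at each prompt's last real token,
# mirroring _batch_decode_next_tokens (temperature sampling is still a TODO upstream).
fake_logits = torch.randn(len(prompts), 16, 32)  # (batch, seqlen, vocab)
next_tokens = [torch.argmax(fake_logits[i, lengths[i] - 1]).item() for i in range(len(prompts))]
print(padded.shape, lengths, next_tokens)
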
@@ -180,6 +269,17 @@ def main(args):
     pp_mesh = device_mesh["pp"]
     tp_rank = tp_mesh.get_local_rank()
     pp_rank = pp_mesh.get_local_rank()
+    tp_group = tp_mesh.get_group()
+    pp_group = pp_mesh.get_group()
+
+    logger.info(f"review: {pp_group=}, {tp_group= }")
+
+    logger.info(f"Created device mesh: {device_mesh}\n {tp_mesh=}, {pp_mesh=}\n")
+    # TODO - this assumes 1D mesh, need to update for 2D+ mesh
+    pp_group_size = pp_mesh.size()
+    tp_group_size = tp_mesh.size()
+
+    logger.info(f"pp_group_size: {pp_group_size}, tp_group_size: {tp_group_size}")
 
     # Assuming same number of GPUs per node
     device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
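
The tp/pp handles above come from a 2D device mesh created earlier in main(), which this hunk does not show. A hedged sketch of how such a mesh and its per-dimension ranks and groups can be built with torch.distributed follows; the 2 x 2 shape and the launch command are assumptions for illustration.

# Hedged sketch: build a 2D (pp, tp) device mesh and pull out per-dimension
# ranks and process groups, similar to the lookups added above.
# Assumed launch: torchrun --nproc-per-node 4 mesh_sketch.py  (file name illustrative)
from torch.distributed.device_mesh import init_device_mesh

pp_degree, tp_degree = 2, 2  # assumed 2 x 2 split of 4 ranks
device_mesh = init_device_mesh("cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp"))

tp_mesh = device_mesh["tp"]
pp_mesh = device_mesh["pp"]
tp_rank, pp_rank = tp_mesh.get_local_rank(), pp_mesh.get_local_rank()
tp_group, pp_group = tp_mesh.get_group(), pp_mesh.get_group()
print(f"{pp_rank=} {tp_rank=} {pp_mesh.size()=} {tp_mesh.size()=}")
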
@@ -198,9 +298,10 @@
     if rank == 0:
         logger.info(f"Model: {model}")
 
-    mbs = 2  # number of micro-batches
+    mbs = 1  # number of micro-batches
     mb_size = 1  # micro-batch size
     batch_size = mbs * mb_size  # total batch size
+
     seqlen = 4096  # sequence length
     dim = 4096  # embedding dimension
     assert seqlen % sp_degree == 0
@@ -213,8 +314,10 @@
 
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
-    with TrackTime("cuda") as timer:
-        _load_model_weights(model, distribution, device=device, model_config=config)
+
+    with CUDATrackTime() as timer:
+        _load_model_weights(model, hf_model_name, device=device, model_config=config)
+
     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
     )
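
CUDATrackTime (replacing the earlier TrackTime) lives in distributed.utils and its implementation is not part of this diff. As a hedged sketch only, a CUDA-event-based timer of roughly this shape could look like the following; the class name, unit, and rounding are assumptions, and it needs a CUDA device.

# Illustrative sketch of a CUDA-event timer context manager in the spirit of
# CUDATrackTime; the real implementation is in distributed/utils.py.
import torch


class EventTimer:
    unit = "ms"

    def __enter__(self):
        self.start = torch.cuda.Event(enable_timing=True)
        self.end = torch.cuda.Event(enable_timing=True)
        self.start.record()
        return self

    def __exit__(self, *exc):
        self.end.record()
        torch.cuda.synchronize()  # ensure both events have completed before timing
        self.elapsed = self.start.elapsed_time(self.end)  # milliseconds
        return False

    def get_time(self):
        return round(self.elapsed, 3)


# usage mirrors the diff: `with EventTimer() as timer: load_weights(...)`
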
@@ -226,9 +329,8 @@
     logger.info(
         f"Stage {rank} has {color.blue}{stage_num_params} params{color.reset}, Size: {color.blue}{stage_size_formatted}{color.reset}\n"
     )
-
-    # Setup input position
-    # input_pos for prefill: a list of increasing integers from 0 to seqlen
+
+    # Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen
     input_pos = torch.arange(seqlen, device=device)
     model.setup_input_pos(input_pos)
     model.eval()
@@ -249,41 +351,129 @@
     if len(cpu_tensors) > 0:
         raise ValueError("Found cpu tensors in stage")
 
-    # TODO: this can likely be removed after we prove out a few more models
-    # verify dtypes for model - expect all to be model_dtype except for bool causal_mask atm.
-    # dtype_count, dtype_locations, fp32_locations = record_module_dtypes(stage.submod)
-    # logger.info(
-    #     f"Stage Dtypes - Found {len(dtype_count)} dtypes: {dtype_count.items()}"
-    # )
-    # assert (
-    #     len(dtype_count) == 2
-    # ), f"Expected 2 dtypes in model after checkpoint loading: {model_dtype} and {torch.bool}"
+    prompt = [
+        "What is snow?",
+    ]
 
-    input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), device=device)
-    logger.info(f"Input: {input_ids.dtype=}, {input_ids.shape=}, {input_ids.device=}")
+    """
+    "What is the capital of France?",
+    "What is your name?",
+    "What is the capital of Japan?",
+    "When is Christmas?",
+    "Where does Santa Claus live?",
+    "What is the capital of the United States?",
+    "What is the capital of China?",
+    "What is the capital of Russia?",
+    "What is PyTorch?",
+    "What is the capital of India?",
+    "What is an LLM?",
+    "What is the capital of Brazil?",
+    "What is the capital of Mexico?",
+    "What is the capital of Argentina?",
+    "What is the capital of Canada?",
+    ]
+    """
 
-    schedule = ScheduleGPipe(stage, mbs)
-    logger.info(f"Created schedule: {schedule}")
 
-    with torch.no_grad():  # .inference_mode():
-        if pp_rank == 0:
-            output = schedule.step(input_ids)
-        else:
-            output = schedule.step()
+    start_pos = 0
 
-    if pp_rank == pp_degree - 1 and tp_rank == 0:
-        logger.info(f"Output: {output}")
+    # encode the prompt
+    input_ids = _encode_strings(
+        prompt, tokenizer, bos=True, device=device, dtype=torch.int64
+    )
+    logger.info(f"{input_ids[0:8]=}")
 
-    # show peak memory stats for this stage
-    res_mem_gib, res_mem_pct = gpu_memory_monitor.get_peak_stats()
-    logger.info(
-        f"{color.blue} Memory used: {color.green}{res_mem_pct:.3f} %, {color.magenta}{res_mem_gib:.3f} GB{color.reset}"
+    # create a padded tensor for the input prompt
+    padded_sequence, prompt_lengths = _create_padded_prompts(
+        input_ids, tokenizer, seqlen, start_pos, device
     )
+    logger.info(f"{prompt_lengths=}")
+    logger.info(f"first prompt {padded_sequence[0, :prompt_lengths[0]+1]=}")
+    if len(prompt_lengths) > 1:
+        logger.info(f"second prompt {padded_sequence[1, :prompt_lengths[1]+1]=}")
+
+    schedule = ScheduleGPipe(stage, mbs)
+    logger.info(f"Created schedule: {schedule}")
+
+    # with CUDATrackTime() as timer:
+    first_pp_group = 0
+    last_pp_group = pp_group_size - 1
+
+    x_recv = torch.zeros(1, device=device, dtype=torch.int64)
+    logger.info(f"{x_recv.shape=}")
+
+    last_global_rank = world_size - 1
+    res = []
+    dst = None
+    src = None
+
+    if pp_rank == last_pp_group:
+        dst = dist.get_global_rank(pp_group, 0)
+    elif pp_rank == 0:
+        src = dist.get_global_rank(pp_group, last_pp_group)
+
+    # Decoding
+    num_tokens = 40
+
+    with torch.no_grad():
+        for step in range(num_tokens):
+            # first
+            if pp_rank == 0:
+                schedule.step(padded_sequence)
+                # only receive if not last step
+                if step < num_tokens - 1:
+                    dist.recv(
+                        x_recv,
+                        src,
+                        group=pp_group,
+                    )
+                    _update_padded_sequence(
+                        padded_sequence, x_recv, res, prompt_lengths
+                    )
+
+            # last
+            elif pp_rank == last_pp_group:
+                output = schedule.step()
+                # need to decode the output
+                decode_results = _batch_decode_next_tokens(
+                    output=output, prompt_lengths=prompt_lengths, tokenizer=tokenizer
+                )
+                if tp_rank == 0:
+                    logger.info(
+                        f"\n\n{color.green} {'Prefill' if step == 0 else '* Decode *'} responses ====>>>> {color.blue} {decode_results=} \n{color.reset}"
+                    )
+
+                next_token = torch.tensor([decode_results[0][0]], device=device)
+                res.append(decode_results[0][1])
+
+                # increment prompt lengths for next token
+                for i in range(len(prompt_lengths)):
+                    prompt_lengths[i] += 1
+                    # logger.info(
+                    #     f"output review {prompt_lengths[i]=}, {padded_sequence[i, prompt_lengths[i]-1]=}"
+                    # )
+
+                # only send if not last step
+                if step < (num_tokens - 1):
+                    dist.send(
+                        next_token,
+                        dst,
+                        pp_group,
+                    )
+
+            # middle pp ranks
+            else:
+                schedule.step()
+
+    # output formatted response via last pp group and tp rank 0
+    if pp_rank == last_pp_group and tp_rank == 0:
+        logger.info(f"\nPrompt:{color.green} {prompt[0]} {color.reset}")
+        formatted_response = "".join(res)
+        logger.info(f"$$$$$$ {color.blue}{formatted_response}\n{color.reset} $$$$$")
 
     logger.info(
         f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}"
     )
-
     _cleanup()
 
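
The decode loop keeps every rank on the GPipe schedule and moves only the freshly sampled token between the ends of the pipeline: the last pp stage argmaxes its output and dist.send()s a single int64 tensor back to stage 0, which dist.recv()s it into x_recv and writes it into the padded buffer before the next step. Below is a standalone two-rank sketch of that point-to-point handoff pattern; the gloo backend, the dummy "sampling", and the file name are assumptions so it can run on CPU.

# Hedged sketch of the stage-0 <-> last-stage next-token handoff in the decode loop.
# Assumed launch (2 ranks, CPU): torchrun --nproc-per-node 2 handoff_sketch.py
import torch
import torch.distributed as dist

dist.init_process_group("gloo")  # torchrun provides MASTER_ADDR/PORT, rank, world size
rank, world = dist.get_rank(), dist.get_world_size()
first, last = 0, world - 1
assert world >= 2, "needs at least two ranks so first != last"

num_tokens = 4
seq = torch.zeros(16, dtype=torch.int64)  # stands in for padded_sequence on rank 0
pos = 3                                   # stands in for prompt_lengths[0]

for step in range(num_tokens):
    if rank == last:
        # "Sample" a next token; the real code argmaxes the model output instead.
        next_token = torch.tensor([100 + step], dtype=torch.int64)
        if step < num_tokens - 1:          # no send after the final token
            dist.send(next_token, dst=first)
    elif rank == first:
        if step < num_tokens - 1:
            x_recv = torch.zeros(1, dtype=torch.int64)
            dist.recv(x_recv, src=last)
            pos += 1
            seq[pos - 1] = x_recv          # append so the next step can see it

if rank == first:
    print("received tokens:", seq[3:pos].tolist())
dist.destroy_process_group()
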

distributed/force_download.py

Lines changed: 4 additions & 4 deletions
@@ -1,5 +1,5 @@
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
-tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
-model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
-print("Model weights downloaded")
+tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
+model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
+print("Model weights and tokenizer downloaded")

run_dist.sh

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-export CUDA_VISIBLE_DEVICES=4,5,6,7
+#export CUDA_VISIBLE_DEVICES=4,5,6,7
 PORT=${1:-29501}
 NGPU=${NGPU:-"4"}
 LOG_RANK=${LOG_RANK:-0,1,2,3}
