This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 28598e7

add encode string and create_padded_input functions
1 parent 5986ed2 commit 28598e7

File tree

2 files changed: +410 −17 lines


dist_back.py

Lines changed: 371 additions & 0 deletions
@@ -0,0 +1,371 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict

# Run command:
# torchrun --nproc-per-node 4 dist_run.py
import torch
import torch.distributed as dist
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe

from distributed.logging_utils import setup_logging
# TODO - these are not distributed specific, consider moving to new package
from distributed.safetensor_utils import (get_hf_config_file,
                                          get_hf_weight_map_and_path,
                                          load_safetensor_weights)
from distributed.utils import Color as color, get_stage_size, build_gpu_memory_monitor, TrackTime, get_num_params
from distributed.verification_utils import find_cpu_tensors
from torchchat.cli.builder import TokenizerArgs, _initialize_tokenizer
from torchchat.model import ModelArgs, Transformer
from torchchat.utils.build_utils import set_precision

try:
    from tokenizer.tiktoken import Tokenizer as TiktokenTokenizer
except ImportError:
    TiktokenTokenizer = None
try:
    from sentencepiece import SentencePieceProcessor
except ImportError:
    SentencePieceProcessor = None


# logger = setup_logging(__name__)
from distributed.logging_utils import SingletonLogger
logger = SingletonLogger.get_logger(__name__)


NAME_TO_HF_MODEL_ID_AND_DTYPE = {
    "Transformer-2-7b-chat-hf": ("meta-llama/Llama-2-7b-chat-hf", torch.float16),
    "Meta-Llama-3-8B": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16),
}
CACHE_PRECISION = torch.bfloat16


def _init_distributed():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # Assuming same number of GPUs per node
    torch.cuda.set_device(rank % torch.cuda.device_count())
    return rank, world_size


def _create_device_mesh(mesh_dimensions):
    return dist.init_device_mesh("cuda", mesh_dimensions, mesh_dim_names=("pp", "tp"))


def dict_to_args(dictionary: Dict[str, Any]) -> SimpleNamespace:
    return SimpleNamespace(**dictionary)


def _build_chat_tokenizer(
    model_base_name: str = "llama3",
) -> SentencePieceProcessor | TiktokenTokenizer:
    # Create base args for tokenizer
    default_model_dir = Path(
        os.getenv("TORCHCHAT_MODELDIR", "~/.torchchat/model-cache")
    ).expanduser()

    tokenconfig = {
        "model_directory": default_model_dir,
        "model": model_base_name,
        "tokenizer_path": None,
    }
    args = dict_to_args(tokenconfig)
    tokenizer_args = TokenizerArgs.from_args(args)
    tokenizer = _initialize_tokenizer(tokenizer_args)
    assert tokenizer is not None, f"Failed to get tokenizer using {tokenconfig=}"
    logger.info(
        f"using tokenizer = {tokenizer.__class__.__module__}.{tokenizer.__class__.__name__}"
    )
    return tokenizer


def _encode_string(string, tokenizer, bos=True, device="cuda", dtype=torch.int64):
    tokens = tokenizer.encode(string)
    if bos:
        tokens = [tokenizer.bos_id()] + tokens
    logger.info(f"***** encoding: {tokens=}, {string=}")
    return torch.tensor(tokens, dtype=dtype, device=device)
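
# _encode_string prepends the BOS id so the prompt is framed the way Llama-style
# checkpoints expect, and returns a 1-D token tensor (no batch dim); callers that
# need a (1, prompt_len) batch add the leading dimension themselves.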

def _logits_to_probs(
    logits,
    temperature=1.0,
):
    logits = logits / max(
        temperature, 1e-5 if logits.dtype != torch.float16 else 1e-3
    )
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs
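
# Temperature scaling with a small floor so the division never blows up as
# temperature approaches 0; the floor is larger for float16 logits, whose narrower
# range overflows more easily, then softmax turns the scaled logits into probabilities.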

def _load_model_weights(stage_module, hf_model_name, device, model_config):
    """Load the weights from the safetensor file(s) into the model stage.
    Model config is needed b/c we permute wq and wk weights based on attn heads.
    """

    weight_map, weight_path, key_map = get_hf_weight_map_and_path(hf_model_name)

    num_loaded_weights, num_missing_weights = load_safetensor_weights(
        stage_module,
        weight_map,
        weight_path,
        key_map,
        device,
        model_config=model_config,
    )
    logger.info(
        f"Success - Loaded {num_loaded_weights} weights, {num_missing_weights} missing weights"
    )
    if num_missing_weights > 0:
        raise ValueError(f"Missing {num_missing_weights} weights")


def _multinomial_sample_one_no_sync(
    probs_sort,
):  # Does multinomial sampling without a cuda synchronization
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
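
# Exponential-race sampling: with q_i ~ Exp(1), argmax(p_i / q_i) is distributed
# exactly like a categorical draw from p, so this samples entirely on-device and
# avoids the host/device synchronization that torch.multinomial would trigger.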

def _cleanup():
    dist.barrier()
    dist.destroy_process_group()


def _get_hf_tokenizer(hf_model_name):
    """Load tokenizer from HF model id. note - use torchchat tokenizer as default"""
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(hf_model_name)
    assert tokenizer is not None, f"Failed to load tokenizer for {hf_model_name}"
    logger.info(f"Loaded tokenizer for {hf_model_name}")
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def main():
    rank, world_size = _init_distributed()
    gpu_memory_monitor, device_info = build_gpu_memory_monitor()
    logger.info(f"{color.yellow} {device_info}{color.reset}")

    MODEL_NAME = "Meta-Llama-3-8B"  # "Transformer-2-7b-chat-hf"

    config = ModelArgs.from_name(MODEL_NAME).text_transformer_args
    logger.info(f"Chat Model Config: {config}")

    tokenizer = _build_chat_tokenizer()
    logger.info(f"built tokenizer {tokenizer=}")

    hf_model_name, model_dtype = NAME_TO_HF_MODEL_ID_AND_DTYPE[MODEL_NAME]
    logger.info(f"Using HF model weights from {hf_model_name} and dtype {model_dtype}")

    hf_tokenizer = _get_hf_tokenizer(hf_model_name)

    set_precision(CACHE_PRECISION)
    logger.info(f"Using cache precision {CACHE_PRECISION}")

    hf_config = get_hf_config_file(hf_model_name)
    if hf_config is None:
        raise ValueError(f"Config file not found for model id {hf_model_name}")
    logger.info(f"Using HF model weights from {hf_model_name}")

    # Assuming 4 pipeline stages; feel free to change this as long as the
    # asserts are satisfied
    pp_degree = 4
    assert world_size % pp_degree == 0
    assert config.n_layers % pp_degree == 0

    # Sequence parallel is enabled in this program
    # Sequence parallel = Tensor parallel + dividing sequence by tp_degree at layer boundary
    sp_degree = world_size // pp_degree

    # Create device mesh
    mesh_dimensions = (pp_degree, sp_degree)
    device_mesh = _create_device_mesh(mesh_dimensions)
    tp_mesh = device_mesh["tp"]
    pp_mesh = device_mesh["pp"]
    tp_rank = tp_mesh.get_local_rank()
    pp_rank = pp_mesh.get_local_rank()
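
    # The mesh is 2-D: the outer "pp" axis splits the model into pipeline stages and
    # the inner "tp" axis groups the ranks that shard each stage with tensor/sequence
    # parallelism; with the 4-GPU run command above, pp_degree=4 leaves sp_degree=1.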

    # Assuming same number of GPUs per node
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")

    # Fill in PP configs
    config.stage_idx = pp_rank
    config.n_stages = pp_degree

    with device:
        model = Transformer(config)

    model.setup_caches(1, 4096)

    # Distribute model on TP mesh
    model.distribute(tp_mesh)
    logger.info(f"Model: {model}")

    mbs = 1  # number of micro-batches
    mb_size = 1  # micro-batch size
    batch_size = mbs * mb_size  # total batch size
    seqlen = 4096  # sequence length
    dim = 4096  # embedding dimension
    assert seqlen % sp_degree == 0

    mb_ids = torch.randint(0, config.vocab_size, (mb_size, seqlen), device=device)
    activation = torch.rand(
        mb_size, seqlen // sp_degree, dim, device=device, dtype=model_dtype
    )
    example_args = mb_ids if pp_rank == 0 else activation
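
    # Stage 0 consumes token ids of shape (mb_size, seqlen); every later stage consumes
    # the previous stage's activations of shape (mb_size, seqlen // sp_degree, dim), so
    # each rank passes PipelineStage an example input with the shape it will actually see.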

    # Load weights
    with TrackTime() as timer:
        logger.info(f"Loading weights for {pp_rank=} on {device=}")
        _load_model_weights(model, hf_model_name, device=device, model_config=config)
    logger.info(f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for {rank}{color.reset}")

    # info on stage size and params
    stage_size, stage_size_formatted = get_stage_size(model)
    stage_num_params = get_num_params(model)
    logger.info(f"Stage rank {rank} has {color.blue}{stage_num_params} params{color.reset}, Size: {color.blue}{stage_size_formatted}{color.reset}\n")

    # Setup input position
    # input_pos for prefill: a list of increasing integers from 0 to seqlen
    input_pos = torch.arange(seqlen, device=device)
    model.setup_input_pos(input_pos)
    model.eval()

    logger.info(f"Creating pipeline stage {pp_rank=}, {pp_degree=}")
    stage = PipelineStage(
        model,
        pp_rank,
        pp_degree,
        device,
        input_args=(example_args,),
        group=pp_mesh.get_group(),
    )
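
    # Each rank wraps only its own shard of the model in a PipelineStage; pp_rank and
    # pp_degree tell the stage where it sits in the pipeline, input_args gives it an
    # example with the shapes it should expect from its neighbors, and group restricts
    # the stage's point-to-point sends/recvs to the "pp" sub-mesh.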

    # this check confirms that there are no cpu tensors in the model; we expect this to be true.
    cpu_tensors = find_cpu_tensors(stage.submod)
    # logger.info(f"Found {len(cpu_tensors)} cpu tensors: {cpu_tensors}")
    if len(cpu_tensors) > 0:
        raise ValueError("Found cpu tensors in stage")

    # TODO: this can likely be removed after we prove out a few more models
    # verify dtypes for model - expect all to be model_dtype except for bool causal_mask atm.
    # dtype_count, dtype_locations, fp32_locations = record_module_dtypes(stage.submod)
    # logger.info(
    #     f"Stage Dtypes - Found {len(dtype_count)} dtypes: {dtype_count.items()}"
    # )
    # assert (
    #     len(dtype_count) == 2
    # ), f"Expected 2 dtypes in model after checkpoint loading: {model_dtype} and {torch.bool}"

    # input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), device=device)
    # logger.info(f"Input: {input_ids.dtype=}, {input_ids.shape=}, {input_ids.device=}")

    prompt = "what is snow?"
    input_ids = _encode_string(prompt, tokenizer, bos=True, device=device, dtype=torch.int64)

    prompt_len = input_ids.size(0)
    start_pos = 0

    # create a padded tensor for the input prompt
    max_new_tokens = min(seqlen, seqlen - start_pos - prompt_len)
    token_buffer_size = prompt_len + max_new_tokens

    seq = torch.full((1, token_buffer_size), tokenizer.eos_id(), dtype=torch.int64, device=device)
    seq[0, :prompt_len] = input_ids
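
    # The prompt fills the first prompt_len slots of seq; the rest of the buffer is
    # eos padding that decoding overwrites in place, so the full (1, token_buffer_size)
    # tensor can be fed to stage 0 on every step.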


    schedule = ScheduleGPipe(stage, mbs)
    logger.info(f"Created schedule: {schedule}")

    with torch.no_grad():  # .inference_mode():
        if pp_rank == 0:
            schedule.step(seq)
        else:
            output = schedule.step()
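
    # With GPipe scheduling only the first stage feeds real inputs into step();
    # intermediate stages receive activations from their predecessor, and only the
    # last stage's step() returns logits (one row per position in the token buffer).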

    # Decoding
    if pp_rank == pp_degree - 1 and tp_rank == 0:

        next_token_logits = output[:, prompt_len - 1, :]

        logger.info(f"{next_token_logits=}")
        logger.info(f"{next_token_logits.shape=}")

        next_token = torch.argmax(next_token_logits, dim=-1)

        # self.tokenizer.decode([period_id] + x.tolist())[1:]
        next_token_decoded = tokenizer.decode((next_token.tolist()))

        logger.info(f"\n\n{color.green}====>>>> {color.blue} {next_token_decoded=}, {next_token}\n{color.reset}")
        res_mem_gib, res_mem_pct = gpu_memory_monitor.get_peak_stats()
        logger.info(f"{color.blue} Memory used: {color.green}{res_mem_pct:.3f} %, {color.magenta}{res_mem_gib:.3f} GB{color.reset}")


    if pp_rank == pp_degree - 1 and tp_rank == 0:
        response = []
        response.append(next_token_decoded)

    token_array = [19435, 374, 16054]
    for i in range(len(token_array)):
        prompt_len += 1
        newest_token = token_array[i]
        next_token_tensor = torch.tensor([newest_token], dtype=torch.int64, device=device)
        seq[0, prompt_len] = next_token_tensor
        if pp_rank == pp_degree - 1 and tp_rank == 0:
            next_token_decoded = tokenizer.decode(next_token_tensor.tolist())
            response.append(next_token_decoded)

    pretend_next_token = token_array[0]
    if pretend_next_token != tokenizer.eos_id():

        for i in range(1):
            logger.info(f"running loop decoding, iter {i}")
            prompt_len += 1
            newest_token = token_array[i]
            next_token_tensor = torch.tensor(newest_token, dtype=torch.int64, device=device)
            seq[0, prompt_len] = next_token_tensor

            with torch.no_grad():  # .inference_mode():
                if pp_rank == 0:
                    schedule.step(seq)
                else:
                    output = schedule.step()

            if pp_rank == pp_degree - 1 and tp_rank == 0:
                next_token_logits = output[:, prompt_len - 1, :]

                logger.info(f"{next_token_logits=}")
                logger.info(f"{next_token_logits.shape=}")
                next_token = torch.argmax(next_token_logits, dim=-1)

                # self.tokenizer.decode([period_id] + x.tolist())[1:]
                next_token_decoded = tokenizer.decode((next_token.tolist()))
                logger.info(f"\n\n{color.green}====>>>> {color.blue} {next_token_decoded=}, {next_token}\n{color.reset}")

                seq[0, prompt_len] = next_token
                logger.info(f"{seq= }")
                response.append(next_token_decoded)

        logger.info(f"\n\n{color.green} After {i=} iters ====>>>> {color.blue} {response}\n{color.reset}")


    logger.info(
        f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}"
    )

    _cleanup()


if __name__ == "__main__":
    main()
