 import os
 from pathlib import Path
 from types import SimpleNamespace
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 # Run command:
 # torchrun --nproc-per-node 4 dist_run.py
 import torch
 import torch.distributed as dist
 from torch.distributed.pipelining import PipelineStage, ScheduleGPipe

-
 from distributed.logging_utils import SingletonLogger
-
 # TODO - these are not distributed specific, consider moving to new package
-from distributed.safetensor_utils import (
-    get_hf_config_file,
-    get_hf_weight_map_and_path,
-    load_safetensor_weights,
-)
-
-from distributed.utils import (
-    Color as color,
-    GPUMemoryMonitor,
-    get_module_size,
-    get_num_params,
-    bytes_to_readable,
-    TrackTime,
-    CUDATrackTime,
-)
-
+from distributed.safetensor_utils import (get_hf_config_file,
+                                          get_hf_weight_map_and_path,
+                                          load_safetensor_weights)
+from distributed.utils import Color as color
+from distributed.utils import (GPUMemoryMonitor, TrackTime,
+                               bytes_to_readable, get_module_size,
+                               get_num_params)
 from distributed.verification_utils import find_cpu_tensors
 from torchchat.cli.builder import TokenizerArgs, _initialize_tokenizer
 from torchchat.model import ModelArgs, Transformer
@@ -123,28 +112,86 @@ def _load_model_weights(stage_module, hf_model_name, device, model_config):
         raise ValueError(f"Missing {num_missing_weights} weights")


-def _encode_string(
-    string: str, tokenizer, bos: bool = True, device: str = "cuda", dtype=torch.int64
-) -> torch.Tensor:
-    """Encode a prompt string into a tensor of token ids."""
-    tokens = tokenizer.encode(string)
-    if bos:
-        tokens = [tokenizer.bos_id()] + tokens
-    return torch.tensor(tokens, dtype=dtype, device=device)
-
-
-def _create_padded_prompt(
-    input_ids: torch.Tensor, tokenizer, seqlen: int, start_pos: int, device: str
-) -> Tuple[torch.Tensor, int]:
-    """Create a padded tensor for the encoded input prompt. Returns the padded tensor and the prompt length."""
-    prompt_len = input_ids.size(0)
-    max_new_tokens = min(seqlen, seqlen - start_pos - prompt_len)
-    token_buffer_size = prompt_len + max_new_tokens
-    seq = torch.full(
-        (1, token_buffer_size), tokenizer.eos_id(), dtype=torch.int64, device=device
+def _encode_strings(
+    strings: List[str],
+    tokenizer,
+    bos: bool = True,
+    device: str = "cuda",
+    dtype=torch.int64,
+) -> List[torch.Tensor]:
+    """Encode a list of prompt strings into a list of tensor token ids."""
+    encoded_list = []
+    for string in strings:
+        tokens = tokenizer.encode(string)
+        if bos:
+            tokens = [tokenizer.bos_id()] + tokens
+        encoded_list.append(torch.tensor(tokens, dtype=dtype, device=device))
+    return encoded_list
+
+
+def _create_padded_prompts(
+    input_ids_list: List[torch.Tensor],
+    tokenizer,
+    seqlen: int,
+    start_pos: int,
+    device: str,
+    pad_token_id: Optional[int] = None,
+) -> Tuple[torch.Tensor, List[int]]:
+    """
+    Create a padded tensor for multiple encoded input prompts.
+
+    Returns:
+        Tuple[torch.Tensor, List[int]]: A tuple containing the padded tensor and a list of prompt lengths.
+    """
+    pad_token_id = pad_token_id if pad_token_id is not None else tokenizer.eos_id()
+
+    # Find the maximum prompt length
+    max_prompt_len = max(ids.size(0) for ids in input_ids_list)
+
+    # Calculate the buffer size
+    max_new_tokens = max(0, min(seqlen - start_pos, seqlen - max_prompt_len))
+    token_buffer_size = max_prompt_len + max_new_tokens
+
+    # Create the padded batch tensor
+    batch_size = len(input_ids_list)
+    batch_seq = torch.full(
+        (batch_size, token_buffer_size), pad_token_id, dtype=torch.int64, device=device
     )
-    seq[0, :prompt_len] = input_ids
-    return seq, prompt_len
+
+    prompt_lengths = []
+    for i, input_ids in enumerate(input_ids_list):
+        prompt_len = input_ids.size(0)
+        batch_seq[i, :prompt_len] = input_ids
+        prompt_lengths.append(prompt_len)
+
+    return batch_seq, prompt_lengths
+
+
+def _batch_decode_next_tokens(
+    output: torch.Tensor,
+    prompt_lengths: List[int],
+    tokenizer,
+) -> List[Tuple[int, str]]:
+    """
+    Decode the next token for each prompt in the batch.
+
+    Returns:
+        List[Tuple[int, str]]: List of tuples containing the next token id and its
+        decoded string for each prompt in the batch.
+    """
+    batch_size = output.shape[0]
+    results = []
+
+    for i in range(batch_size):
+        next_token_logits = output[i, prompt_lengths[i] - 1, :]
+
+        # Argmax (deterministic) TODO: add temperature
+        next_token = torch.argmax(next_token_logits, dim=-1)
+
+        next_token_decoded = tokenizer.decode([next_token.item()])
+        results.append((next_token.item(), next_token_decoded))
+
+    return results


 def _cleanup():
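For readers skimming the diff, here is a small standalone CPU sketch of the new batched flow. The `_ToyTokenizer`, prompt strings, buffer length, and random logits are illustrative stand-ins (not the torchchat tokenizer or model output), but the encode-with-BOS, right-pad-with-EOS, and per-prompt argmax steps mirror `_encode_strings`, `_create_padded_prompts`, and `_batch_decode_next_tokens` above.

```python
import torch


class _ToyTokenizer:
    """Hypothetical stand-in tokenizer: one id per ASCII character, plus reserved bos/eos ids."""

    def bos_id(self) -> int:
        return 1

    def eos_id(self) -> int:
        return 2

    def encode(self, s: str) -> list:
        return [ord(c) + 3 for c in s]

    def decode(self, ids: list) -> str:
        return "".join(chr(i - 3) for i in ids if i > 2)


tok = _ToyTokenizer()
prompts = ["What is snow?", "Hi"]

# Encode each prompt with a leading BOS token (CPU instead of CUDA).
encoded = [torch.tensor([tok.bos_id()] + tok.encode(p), dtype=torch.int64) for p in prompts]

# Right-pad every prompt with eos_id into one (batch, buffer) tensor.
prompt_lengths = [t.size(0) for t in encoded]
buffer_len = max(prompt_lengths) + 4  # a few extra slots for new tokens
batch_seq = torch.full((len(encoded), buffer_len), tok.eos_id(), dtype=torch.int64)
for i, t in enumerate(encoded):
    batch_seq[i, : t.size(0)] = t

# Pretend the model returned logits of shape (batch, seq, vocab); take the argmax
# at each prompt's last real position, as the batched decode helper does.
vocab_size = 131
logits = torch.randn(batch_seq.size(0), buffer_len, vocab_size)
next_ids = [int(torch.argmax(logits[i, prompt_lengths[i] - 1])) for i in range(len(prompts))]
print(prompt_lengths, next_ids)
```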
@@ -209,7 +256,7 @@ def main():
     model.distribute(tp_mesh)
     # logger.info(f"Model: {model}")

-    mbs = 1  # number of micro-batches TODO: move to multibatch
+    mbs = 4  # number of micro-batches
     mb_size = 1  # micro-batch size
     batch_size = mbs * mb_size  # total batch size

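With `mb_size = 1`, the total batch of `mbs * mb_size = 4` rows maps one-to-one onto the four GPipe micro-batches and the four prompts added below. A minimal sketch of that split, using `torch.tensor_split` purely for illustration (the real chunking happens inside `ScheduleGPipe`):

```python
import torch

mbs = 4        # number of micro-batches
mb_size = 1    # micro-batch size
batch_size = mbs * mb_size  # 4 rows total

# Illustration only: split a (batch, seqlen) tensor along dim 0 into `mbs` chunks,
# roughly what the GPipe schedule does with its input batch.
padded = torch.zeros(batch_size, 16, dtype=torch.int64)
micro_batches = torch.tensor_split(padded, mbs, dim=0)
print([t.shape for t in micro_batches])  # 4 chunks of shape (1, 16)
```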
@@ -260,21 +307,28 @@ def main():
     if len(cpu_tensors) > 0:
         raise ValueError("Found cpu tensors in stage")

-    prompt = "What is the capital of France?"
+    prompt = [
+        "What is the capital of France?",
+        "What is snow?",
+        "What is your name?",
+        "What is the capital of Japan?",
+    ]
     start_pos = 0

     # encode the prompt
-    input_ids = _encode_string(
+    input_ids = _encode_strings(
         prompt, tokenizer, bos=True, device=device, dtype=torch.int64
     )
     logger.info(f"{input_ids[0:8]=}")

     # create a padded tensor for the input prompt
-    padded_sequence, prompt_len = _create_padded_prompt(
+    padded_sequence, prompt_lengths = _create_padded_prompts(
         input_ids, tokenizer, seqlen, start_pos, device
     )
-    logger.info(f"{prompt_len=}")
-    logger.info(f"{padded_sequence[0, :prompt_len + 1]=}")
+    logger.info(f"{prompt_lengths=}")
+    logger.info(f"first prompt {padded_sequence[0, :prompt_lengths[0] + 1]=}")
+    if len(prompt_lengths) > 1:
+        logger.info(f"second prompt {padded_sequence[1, :prompt_lengths[1] + 1]=}")

     schedule = ScheduleGPipe(stage, mbs)
     logger.info(f"Created schedule: {schedule}")
@@ -287,12 +341,12 @@ def main():

     # Decoding
     if pp_rank == pp_degree - 1 and tp_rank == 0:
-        next_token_logits = output[:, prompt_len - 1, :]
-        next_token = torch.argmax(next_token_logits, dim=-1)
-        next_token_decoded = tokenizer.decode((next_token.tolist()))
+        decode_results = _batch_decode_next_tokens(
+            output=output, prompt_lengths=prompt_lengths, tokenizer=tokenizer
+        )

         logger.info(
-            f"\n\n{color.green}Prefill response ====>>>> {color.blue}{next_token_decoded=}, {next_token}\n{color.reset}"
+            f"\n\n{color.green}Prefill responses ====>>>> {color.blue}{decode_results=}\n{color.reset}"
         )

     # show peak memory stats for this stage