@@ -7,7 +7,7 @@
 import os
 from pathlib import Path
 from types import SimpleNamespace
-from typing import Any, Dict
+from typing import Any, Dict, Tuple
 
 # Run command:
 # torchrun --nproc-per-node 4 dist_run.py
@@ -124,19 +124,19 @@ def _load_model_weights(stage_module, hf_model_name, device, model_config): |
 
 def _encode_string(string, tokenizer, bos=True, device="cuda", dtype=torch.int64)-> torch.Tensor:
     """Encode a prompt string into a tensor of token ids."""
-    tokens = tokenizer.encode(string)
-    if bos:
-        tokens = [tokenizer.bos_id()] + tokens
-    return torch.tensor(tokens, dtype=dtype, device=device)
-
-def _create_padded_prompt(input_ids, seqlen, start_pos, device) -> torch.Tensor:
-    """Create a padded tensor for the encoded input prompt."""
-    prompt_len = input_ids.size(0)
-    max_new_tokens = min(seqlen, seqlen - start_pos - prompt_len)
-    token_buffer_size = prompt_len + max_new_tokens
-    seq = torch.full((1, token_buffer_size), tokenizer.eos_id(), dtype=torch.int64, device=device)
-    seq[0, :prompt_len] = input_ids
-    return seq
+    tokens = tokenizer.encode(string)
+    if bos:
+        tokens = [tokenizer.bos_id()] + tokens
+    return torch.tensor(tokens, dtype=dtype, device=device)
+
+def _create_padded_prompt(input_ids, tokenizer, seqlen, start_pos, device) -> Tuple[torch.Tensor, int]:
+    """Create a padded tensor for the encoded input prompt. Returns the padded tensor and the prompt length."""
+    prompt_len = input_ids.size(0)
+    max_new_tokens = min(seqlen, seqlen - start_pos - prompt_len)
+    token_buffer_size = prompt_len + max_new_tokens
+    seq = torch.full((1, token_buffer_size), tokenizer.eos_id(), dtype=torch.int64, device=device)
+    seq[0, :prompt_len] = input_ids
+    return seq, prompt_len
 
 def _cleanup():
     dist.barrier()
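As a sanity check on the new helper contract, here is a minimal, self-contained sketch of how `_encode_string` and `_create_padded_prompt` compose. The `ToyTokenizer` stub and the `seqlen=16` value are illustrative assumptions, not part of this change; the real script uses the model's tokenizer:

```python
# Standalone sketch (no torchrun needed): encode a prompt, then pad it out
# to the full sequence length with EOS tokens, mirroring the helpers above.
import torch

class ToyTokenizer:
    """Hypothetical stand-in exposing only the methods the helpers call."""
    def encode(self, s):
        return list(s.encode("utf-8"))  # toy byte-level ids
    def bos_id(self):
        return 1
    def eos_id(self):
        return 2

tokenizer = ToyTokenizer()
seqlen, start_pos, device = 16, 0, "cpu"

# _encode_string: prepend BOS, return a 1-D tensor of token ids
tokens = [tokenizer.bos_id()] + tokenizer.encode("Hi")
input_ids = torch.tensor(tokens, dtype=torch.int64, device=device)

# _create_padded_prompt: with start_pos=0 the buffer is exactly seqlen long,
# prompt ids up front, EOS padding after
prompt_len = input_ids.size(0)
max_new_tokens = min(seqlen, seqlen - start_pos - prompt_len)
seq = torch.full((1, prompt_len + max_new_tokens), tokenizer.eos_id(),
                 dtype=torch.int64, device=device)
seq[0, :prompt_len] = input_ids
assert seq.shape == (1, seqlen) and prompt_len == 3
```

Returning `prompt_len` alongside the padded buffer is what lets the last pipeline stage index the prefill logits later, since the caller otherwise cannot tell where the prompt ends and the EOS padding begins.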
@@ -251,32 +251,30 @@ def main(): |
     if len(cpu_tensors) > 0:
         raise ValueError("Found cpu tensors in stage")
 
-
-    prompt = "What is the capital of France?"
+    prompt = "What is snow?"
     start_pos = 0
 
     # encode the prompt
     input_ids = _encode_string(prompt, tokenizer, bos=True, device=device, dtype=torch.int64)
-
-    # create a padded tensor for the input prompt
-    seq = _create_padded_prompt(input_ids, seqlen, start_pos, device)
+    logger.info(f"{input_ids[0:8]=}")
 
+    # create a padded tensor for the input prompt
+    padded_sequence, prompt_len = _create_padded_prompt(input_ids, tokenizer, seqlen, start_pos, device)
+    logger.info(f"{prompt_len=}")
 
     schedule = ScheduleGPipe(stage, mbs)
     logger.info(f"Created schedule: {schedule}")
 
     with torch.no_grad():  # .inference_mode():
         if pp_rank == 0:
-            schedule.step(seq)
+            schedule.step(padded_sequence)
         else:
             output = schedule.step()
 
     # Decoding
     if pp_rank == pp_degree - 1 and tp_rank == 0:
-
         next_token_logits = output[:,prompt_len-1, :]
         next_token = torch.argmax(next_token_logits, dim=-1)
-
         next_token_decoded = tokenizer.decode((next_token.tolist()))
 
         logger.info(f"\n\n{color.green} Prefill response ====>>>> {color.blue} {next_token_decoded=}, {next_token}\n{color.reset}")
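The decode step relies on the `prompt_len` returned above: in the stage output of shape `[batch, token_buffer, vocab]`, the logits at index `prompt_len - 1` predict the first token after the prompt, since every later position only attended to EOS padding. A minimal sketch with made-up shapes (the random `output` tensor and `vocab_size` are stand-ins for the real stage output, not values from this change):

```python
# Greedy pick of the first generated token from prefill logits (toy shapes).
import torch

batch, token_buffer, vocab_size = 1, 16, 128   # illustrative sizes only
prompt_len = 3
output = torch.randn(batch, token_buffer, vocab_size)  # stand-in for schedule.step() output

next_token_logits = output[:, prompt_len - 1, :]      # logits for the position after the prompt
next_token = torch.argmax(next_token_logits, dim=-1)  # greedy decode, shape [batch]
print(next_token.tolist())
```

Note the rank asymmetry in the GPipe schedule above: only pipeline rank 0 feeds the padded sequence into `schedule.step(...)`, while later stages call `schedule.step()` with no argument and receive their inputs from the previous stage, so only the last stage holds logits worth decoding.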