 import os
 from enum import auto, Enum
 from pathlib import Path
-from types import SimpleNamespace, MethodType
+from types import MethodType, SimpleNamespace
 from typing import Any, Dict, List, Optional, Tuple
 
 import torch
@@ -71,21 +71,26 @@ def _init_distributed():
 
 
 def _create_device_mesh(pp_degree, tp_degree):
-    return dist.init_device_mesh("cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp"))
+    return dist.init_device_mesh(
+        "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp")
+    )
 
 
 def dict_to_args(dictionary: Dict[str, Any]) -> SimpleNamespace:
     return SimpleNamespace(**dictionary)
 
+
 def _patch_tokenizer(tokenizer):
     """Patch the tokenizer to support decoding of token ids."""
     if isinstance(tokenizer, TiktokenTokenizer):
         # Patch tiktokenizer to allow a list of sequences.
-        #TODO: Upstream to tokenizer modules
+        # TODO: Upstream to tokenizer modules
         old_decode = tokenizer.decode
 
-        def decode(self, token_ids: List[int | List[int]], *args, **kwargs) -> str | List[str]:
-            if len(token_ids)< 1:
+        def decode(
+            self, token_ids: List[int | List[int]], *args, **kwargs
+        ) -> str | List[str]:
+            if len(token_ids) < 1:
                 return ""
             if isinstance(token_ids[0], list):
                 return [old_decode(t, *args, **kwargs) for t in token_ids]
@@ -95,6 +100,7 @@ def decode(self, token_ids: List[int|List[int]], *args, **kwargs) -> str | List[
         tokenizer.decode = MethodType(decode, tokenizer)
     return tokenizer
 
+
 def _build_chat_tokenizer(
     model_name: str,
     model_base_name: Optional[str] = None,
@@ -221,7 +227,7 @@ def _create_padded_prompts(
 
 def _batch_decode_next_tokens(
     output: torch.Tensor,
-    pos: List[int]=None,
+    pos: List[int] = None,
     temperature: float = 1.0,
     topk: int = 10,
 ) -> torch.Tensor:
@@ -388,7 +394,7 @@ def main(args, pipe):
     # Batch size. Since we push batches dynamically through the pipeline rather
     # than chunking them, this is effectively micro-batch size in pipeline
     # sense. Thus it is interchangeable with micro-batch size below.
-    batch_size = 1 # len(prompt)
+    batch_size = 1  # len(prompt)
     seqlen_prefill = 1024  # sequence length
     dim = 4096  # embedding dimension
 
@@ -463,7 +469,9 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
                 raise ValueError(f"Unknown command: {command}")
             else:
                 prompt = command
-            assert len(prompt) == batch_size, f"Expecting {batch_size=} prompts but got {len(prompt)=}"
+            assert (
+                len(prompt) == batch_size
+            ), f"Expecting {batch_size=} prompts but got {len(prompt)=}"
             logger.info(f"{color.green}Prompt: {prompt}{color.reset}")
 
             start_pos = 0
@@ -508,7 +516,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
                 logger.info(f"{color.green}Decoding...{prompt_lengths=}{color.reset}")
                 new_token = _batch_decode_next_tokens(output, prompt_lengths)
                 res.append(new_token)
-                #TODO: Move to a separate decoding thread
+                # TODO: Move to a separate decoding thread
                 resp = _decode_in_flight(new_token, tokenizer, tp_rank)
                 pipe.send((resp, new_token.tolist()))
             else:
@@ -539,7 +547,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
             command = pipe.recv()
             assert isinstance(command, str)
             if command == "stop":
-                 break
+                break
             elif command == "step":
                 pass
             else:
@@ -573,7 +581,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         if pp_rank == last_pp_rank:
             new_token = _batch_decode_next_tokens(output)
             res.append(new_token)
-            #TODO: Move to a separate decoding thread
+            # TODO: Move to a separate decoding thread
            resp = _decode_in_flight(new_token, tokenizer, tp_rank)
             pipe.send((resp, new_token))
         else:
@@ -602,7 +610,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         for prompt_text, response_text in zip(prompt, responses):
             logger.info(f"Prompt: {color.green}{prompt_text} {color.reset}")
             logger.info(f"Response: {color.red}{response_text} {color.reset}")
-
+
     # Cleanup
     _cleanup()
     logger.info(