@@ -12,13 +12,13 @@
 import os
 from enum import auto, Enum
 from pathlib import Path
-from types import SimpleNamespace
+from types import SimpleNamespace, MethodType
 from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 import torch.distributed as dist
 from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
-from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs
+from torchchat.cli.builder import TokenizerArgs
 
 # TODO - these are not distributed specific, consider moving to new package
 from torchchat.distributed.checkpoint_utils import (
@@ -50,7 +50,6 @@
 
 
 logger = SingletonLogger.get_logger()
-_tokenizer_type = None  # global variable to store the tokenizer type
 
 # Using model name to identify the model to load, for example "llama2-7b-chat".
 # You can change it to other values listed below.
@@ -61,11 +60,6 @@
 }
 
 
-class TokenizerType(Enum):
-    Tiktoken = auto()
-    SentencePiece = auto()
-
-
 def _init_distributed():
     dist.init_process_group("nccl")
     rank = dist.get_rank()
@@ -82,14 +76,29 @@ def _create_device_mesh(mesh_dimensions):
 def dict_to_args(dictionary: Dict[str, Any]) -> SimpleNamespace:
     return SimpleNamespace(**dictionary)
 
+def _patch_tokenizer(tokenizer):
+    """Patch the tokenizer to support decoding of token ids."""
+    if isinstance(tokenizer, TiktokenTokenizer):
+        # Patch tiktokenizer to allow a list of sequences.
+        # TODO: Upstream to tokenizer modules
+        old_decode = tokenizer.decode
+
+        def decode(self, token_ids: List[int | List[int]], *args, **kwargs) -> str | List[str]:
+            if len(token_ids) < 1:
+                return ""
+            if isinstance(token_ids[0], list):
+                return [old_decode(t, *args, **kwargs) for t in token_ids]
+            else:
+                return old_decode(token_ids, *args, **kwargs)
+
+        tokenizer.decode = MethodType(decode, tokenizer)
+    return tokenizer
 
 def _build_chat_tokenizer(
     model_name: str,
     model_base_name: Optional[str] = None,
 ) -> SentencePieceProcessor | TiktokenTokenizer:
-    """Builds a tokenizer for the given model name, and sets the global tokenizer type variable"""
-
-    global _tokenizer_type
+    """Builds a tokenizer for the given model name"""
 
     # Try to infer the model base name from the model name:
     # e.g. "llama2-7b-chat" -> "llama2"
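
The patched decode keeps the original single-sequence behavior and adds batched decoding over a list of sequences. A minimal usage sketch, with hypothetical token ids, assuming tok is a TiktokenTokenizer that has gone through _patch_tokenizer:

    tok = _patch_tokenizer(tok)
    single = tok.decode([9906, 1917])                 # flat list of ints -> one str
    batch = tok.decode([[9906, 1917], [9906, 4435]])  # list of lists -> list of str
    empty = tok.decode([])                            # empty input -> ""
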
@@ -112,20 +121,14 @@ def _build_chat_tokenizer(
     }
     args = dict_to_args(tokenconfig)
     tokenizer_args = TokenizerArgs.from_args(args)
-    tokenizer = _initialize_tokenizer(tokenizer_args)
+    tokenizer = tokenizer_args.t
     assert tokenizer is not None, f"Failed to get tokenizer using {tokenconfig=}"
     logger.info(
         f"using tokenizer = {tokenizer.__class__.__module__}.{tokenizer.__class__.__name__}"
     )
-    # set global variable _tokenizer_type
-    if isinstance(tokenizer, TiktokenTokenizer):
-        _tokenizer_type = TokenizerType.Tiktoken
-    elif isinstance(tokenizer, SentencePieceProcessor):
-        _tokenizer_type = TokenizerType.SentencePiece
-    else:
-        raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__}")
 
-    logger.info(f"tokenizer type = {_tokenizer_type}")
+    tokenizer = _patch_tokenizer(tokenizer)
+
     return tokenizer
 
 
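
Two simplifications land here: the tokenizer is read straight off TokenizerArgs (its t attribute, which this change implies holds the built tokenizer) rather than through the removed _initialize_tokenizer helper, and the global type bookkeeping is replaced by the _patch_tokenizer call. A sketch of the resulting call site, with the model name taken from the examples in the comments above:

    tokenizer = _build_chat_tokenizer("llama2-7b-chat")  # base name inferred as "llama2"
    # Whether this returns a TiktokenTokenizer or a SentencePieceProcessor,
    # it comes back already patched, so callers can batch-decode without type checks.
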
@@ -568,15 +571,8 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     # token ids. Thus cat'ing along dim 1.
     res = torch.cat(res, dim=1)
     res_list = res.tolist()
-    if _tokenizer_type == TokenizerType.Tiktoken:
-        # For TiktokenTokenizer, we need to decode prompt by prompt.
-        # TODO: is there a better way to do this?
-        responses = [tokenizer.decode(sequence) for sequence in res_list]
-    elif _tokenizer_type == TokenizerType.SentencePiece:  # SentencePieceProcessor
-        # For SentencePieceProcessor, we can decode the entire 2D list at once.
-        responses = tokenizer.decode(res_list)
-    else:
-        raise ValueError(f"Unknown tokenizer type {_tokenizer_type}")
+
+    responses = tokenizer.decode(res_list)
 
     # Show prompts and responses
     for prompt_text, response_text in zip(prompt, responses):
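
The removed branch existed because stock tiktoken decoding takes a flat list of ints, while SentencePieceProcessor already accepts a 2D list; after _patch_tokenizer, both accept res_list directly. A short sketch with illustrative ids:

    res_list = [[9906, 1917], [9906, 4435]]  # one row of generated token ids per prompt
    responses = tokenizer.decode(res_list)   # one decoded string per prompt, either tokenizer type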