@@ -45,26 +45,25 @@ class MistralTokenizer:
     def __init__(self, tokenizer: PublicMistralTokenizer) -> None:
         self.mistral = tokenizer
         self.instruct = tokenizer.instruct_tokenizer
-        self.tokenizer = tokenizer.instruct_tokenizer.tokenizer
 
-        self.vocab_size = len(self.tokenizer.vocab())
-
-        assert isinstance(self.tokenizer,
-                          (Tekkenizer, SentencePieceTokenizer)), type(
-                              self.tokenizer)
-
-        if (is_tekken := isinstance(self.tokenizer, Tekkenizer)):
+        tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
+        if isinstance(tokenizer_, Tekkenizer):
             # Make sure special tokens will not raise
-            self.tokenizer.special_token_policy = SpecialTokenPolicy.IGNORE
-
-        self._is_tekken = is_tekken
+            tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE
+
+            self._vocab = {
+                token: idx
+                for idx, token in enumerate(tokenizer_.vocab())
+            }
+        elif isinstance(tokenizer_, SentencePieceTokenizer):
+            self._vocab = {
+                token: idx
+                for idx, token in enumerate(tokenizer_.vocab())
+            }
+        else:
+            raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")
 
-        # the following attributes are set to fit VLLM's design
-        self.is_fast = True
-        self.chat_template = True
-        self.all_special_ids: List[Any] = []
-        self.all_special_tokens: List[Any] = []
-        self.all_special_tokens_extended: List[Any] = []
+        self.tokenizer = tokenizer_
 
     @classmethod
     def from_pretrained(cls,
@@ -102,6 +101,38 @@ def _download_mistral_tokenizer_from_hf(tokenizer_name: str,
                                             revision=revision)
         return tokenizer_file
 
+    # the following attributes are set to fit VLLM's design
+    @property
+    def all_special_tokens_extended(self) -> List[str]:
+        return []
+
+    @property
+    def all_special_tokens(self) -> List[str]:
+        return []
+
+    @property
+    def all_special_ids(self) -> List[int]:
+        return []
+
+    @property
+    def bos_token_id(self) -> int:
+        return self.tokenizer.bos_id
+
+    @property
+    def eos_token_id(self) -> int:
+        return self.tokenizer.eos_id
+
+    @property
+    def is_fast(self) -> bool:
+        return True
+
+    @property
+    def vocab_size(self) -> int:
+        return len(self._vocab)
+
+    def __len__(self) -> int:
+        return self.vocab_size
+
     def __call__(
         self,
         prompt: str,
@@ -117,9 +148,12 @@ def __call__(
 
         return Encoding(input_ids=input_ids)
 
-    def get_added_vocab(self) -> List[str]:
+    def get_vocab(self) -> Dict[str, int]:
+        return self._vocab
+
+    def get_added_vocab(self) -> Dict[str, int]:
         # Mistral tokenizers have no added vocabulary
-        return []
+        return {}
 
     def encode(self, prompt: str) -> List[int]:
         # `encode` should only be used for prompt completion
@@ -141,7 +175,7 @@ def apply_chat_template(self,
         return encoded.tokens
 
     def convert_tokens_to_string(self, tokens: List[str]) -> str:
-        if self._is_tekken:
+        if isinstance(self.tokenizer, Tekkenizer):
             return "".join(tokens)
         else:
             return self.tokenizer.decode(tokens)  # type: ignore[arg-type]
@@ -151,14 +185,11 @@ def decode(self, ids: Union[List[int], int]) -> str:
             ids = [ids]
         return self.tokenizer.decode(ids)
 
-    @property
-    def eos_token_id(self):
-        return self.tokenizer.eos_id
-
     def convert_ids_to_tokens(
-            self,
-            ids: List[int],
-            skip_special_tokens: Optional[bool] = True) -> List[str]:
+        self,
+        ids: List[int],
+        skip_special_tokens: bool = True,
+    ) -> List[str]:
         # TODO(Patrick) - potentially allow special tokens to not be skipped
         assert (
             skip_special_tokens
@@ -170,6 +201,3 @@ def convert_ids_to_tokens(
 
         tokens = [self.tokenizer.id_to_piece(id) for id in ids]
         return tokens
-
-    def __len__(self):
-        return self.vocab_size
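For quick orientation, a minimal usage sketch of the reworked class follows; it is not part of the commit. The import path and the repo id are illustrative assumptions, while the methods and properties it exercises are exactly the ones this diff adds or keeps.

    # Hedged sketch, not from the commit. Assumptions: the class lives at
    # vllm.transformers_utils.tokenizers.mistral, from_pretrained() accepts a
    # Hugging Face repo id, and the (hypothetical) repo below ships a Mistral
    # tokenizer file.
    from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer

    tok = MistralTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")

    # vocab_size is now a property backed by the private _vocab mapping built
    # in __init__, so len(tok) and tok.vocab_size agree by construction.
    assert len(tok) == tok.vocab_size
    assert tok.get_added_vocab() == {}  # Mistral tokenizers add no vocabulary

    # encode() is meant for plain prompt completion (per the comment in the
    # diff); round-trip the ids through the token-level helpers.
    ids = tok.encode("Hello, world!")
    tokens = tok.convert_ids_to_tokens(ids)
    print(tok.convert_tokens_to_string(tokens))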