# LICENSE file in the root directory of this source tree.

# Standard
-from typing import List
+from typing import List, Optional
import json
+import os

# Third Party
from tokenizers import Tokenizer
@@ -21,26 +22,53 @@ class TokenizersTokenizer(TokenizerBase):
2122 """
2223
2324 def __init__ (self , file_path : str ):
24- self ._tokenizer = Tokenizer .from_file (file_path )
25- # The BOS and EOS tokens are not easily visible from the tokenizer
26- # object itself, so we extract them at construction with a sample call
27- self ._bos_token = self ._tokenizer .encode ("Test" , add_special_tokens = True ).ids [0 ]
28- # There is no explicit BOS token in many tokenizers, so we look for a
29- # single special token that most resembles the BOS token.
30- self ._eos_token = None
31- tok_content = json .loads (self ._tokenizer .to_str ())
32- end_toks = [
33- tok for tok in tok_content ['added_tokens' ]
34- if tok ["special" ] and "end" in tok ["content" ]
35- ]
36- assert end_toks , "Unable to find an EOS token in the added tokens"
37- if len (end_toks ) > 1 :
38- end_text_toks = [
39- tok for tok in end_toks if "text" in tok ["content" ]
+        # If the path is a directory, look for "tokenizer.json" which is
+        # standard for transformers checkpoints and also look for the
+        # "tokenizer_config.json" file to parse eos/bos tokens
+        if os.path.isdir(file_path):
+            tokenizer_path = os.path.join(file_path, "tokenizer.json")
+            tokenizer_config_path = os.path.join(file_path, "tokenizer_config.json")
+        else:
+            tokenizer_path = file_path
+            tokenizer_config_path = os.path.join(os.path.dirname(file_path), "tokenizer_config.json")
+        # Only try to read the config if it actually exists on disk
+        if not os.path.isfile(tokenizer_config_path):
+            tokenizer_config_path = None
+
+        # Load the tokenizer itself
+        self._tokenizer = Tokenizer.from_file(tokenizer_path)
+
+        # If available, parse bos/eos tokens from the tokenizer config
+        self._bos_id, self._eos_id = None, None
+        if tokenizer_config_path is not None:
+            with open(tokenizer_config_path, "r") as handle:
+                tok_config = json.load(handle)
+            bos_token = tok_config.get("bos_token")
+            eos_token = tok_config.get("eos_token")
+            # Some configs serialize special tokens as AddedToken dicts
+            # rather than plain strings; the literal lives under "content"
+            if isinstance(bos_token, dict):
+                bos_token = bos_token.get("content")
+            if isinstance(eos_token, dict):
+                eos_token = eos_token.get("content")
+            if bos_token is not None:
+                self._bos_id = self._tokenizer.token_to_id(bos_token)
+            if eos_token is not None:
+                self._eos_id = self._tokenizer.token_to_id(eos_token)
+
+        # If no eos/bos tokens found, go looking for them!
+        if None in [self._bos_id, self._eos_id]:
+            tok_content = json.loads(self._tokenizer.to_str())
+            if self._bos_id is None:
+                self._bos_id = self._look_for_special_token(tok_content, ["begin", "text"])
+            if self._eos_id is None:
+                self._eos_id = self._look_for_special_token(tok_content, ["end", "text"])
+
+        assert None not in [self._bos_id, self._eos_id], "Unable to find BOS/EOS tokens"
+
+    @staticmethod
+    def _look_for_special_token(added_tokens: dict, search_strs: List[str]) -> Optional[int]:
+        # Start from the serialized tokenizer's added_tokens list and narrow
+        # it down by each search string in turn; only a unique match is trusted
+        candidate_toks = added_tokens["added_tokens"]
+        for search_str in search_strs:
+            candidate_toks = [
+                tok for tok in candidate_toks
+                if tok["special"] and search_str in tok["content"]
            ]
-            if len(end_text_toks) == 1:
-                self._eos_token = end_text_toks[0]["id"]
-        assert self._eos_token is not None, "Unable to find an EOS token in the added tokens"
+        if len(candidate_toks) == 1:
+            return candidate_toks[0]["id"]

    def encode(
        self,
@@ -58,7 +86,7 @@ def decode(self, ids: List[int]) -> str:
        return self._tokenizer.decode(ids)

    def bos_id(self) -> int:
-        return self._bos_token
+        return self._bos_id

    def eos_id(self) -> int:
-        return self._eos_token
+        return self._eos_id
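
For context, a minimal usage sketch of the new constructor behavior (the checkpoint path below is hypothetical, and the full encode signature is elided in this diff, so only the accessors changed above are exercised):

    # Point at a checkpoint directory containing tokenizer.json and
    # tokenizer_config.json, or at a bare tokenizer.json file
    tok = TokenizersTokenizer("checkpoints/my-model")
    print(tok.bos_id(), tok.eos_id())  # ids resolved during __init__
    print(tok.decode([tok.bos_id(), tok.eos_id()]))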