diff --git a/tokenizer/base.py b/tokenizer/base.py new file mode 100644 index 000000000..75998b32b --- /dev/null +++ b/tokenizer/base.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +""" +Abstract base class for all tokenizer classes in python matching c++ interface. +""" + +# Standard +from abc import ABC, abstractmethod +from typing import List + + +class TokenizerBase(ABC): + __doc__ = __doc__ + + @abstractmethod + def encode(self, s: str, *, bos: bool = False, eos: bool = False) -> List[int]: + """Encode the given string and optionally include bos/eos tokens""" + + @abstractmethod + def decode(self, ids: List[int]) -> str: + """Decode the given token ids into a string""" + + @abstractmethod + def bos_id(self) -> int: + """The id of the begin-of-string token""" + + @abstractmethod + def eos_id(self) -> int: + """The id of the end-of-string token""" diff --git a/tokenizer/hf_tokenizer.py b/tokenizer/hf_tokenizer.py new file mode 100644 index 000000000..7ad5807d1 --- /dev/null +++ b/tokenizer/hf_tokenizer.py @@ -0,0 +1,92 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+
+# Standard
+from typing import List, Optional
+import json
+import os
+
+# Third Party
+from tokenizers import Tokenizer
+
+# Local
+from .base import TokenizerBase
+
+
+class HFTokenizer(TokenizerBase):
+    """
+    Wrapper around the Huggingface `tokenizers` library for API compatibility
+    """
+
+    def __init__(self, file_path: str):
+        # If the path is a directory, look for "tokenizer.json" which is
+        # standard for transformers checkpoints and also look for the
+        # "tokenizer_config.json" file to parse eos/bos tokens
+        if os.path.isdir(file_path):
+            tokenizer_path = os.path.join(file_path, "tokenizer.json")
+            tokenizer_config_path = os.path.join(file_path, "tokenizer_config.json")
+        else:
+            tokenizer_path = file_path
+            tokenizer_config_path = os.path.join(os.path.dirname(file_path), "tokenizer_config.json")
+        if not os.path.isfile(tokenizer_config_path):
+            tokenizer_config_path = None
+
+        # Load the tokenizer itself
+        self._tokenizer = Tokenizer.from_file(tokenizer_path)
+
+        # If available, parse bos/eos tokens from the tokenizer config
+        self._bos_id, self._eos_id = None, None
+        if tokenizer_config_path is not None:
+            with open(tokenizer_config_path, "r") as handle:
+                tok_config = json.load(handle)
+            bos_token = tok_config.get("bos_token")
+            eos_token = tok_config.get("eos_token")
+            if bos_token is not None:
+                self._bos_id = self._tokenizer.token_to_id(bos_token)
+            if eos_token is not None:
+                self._eos_id = self._tokenizer.token_to_id(eos_token)
+
+        # If no eos/bos tokens found, go looking for them!
+        if None in [self._bos_id, self._eos_id]:
+            tok_content = json.loads(self._tokenizer.to_str())
+            if self._bos_id is None:
+                self._bos_id = self._look_for_special_token(tok_content.get("added_tokens", []), ["begin", "text"])
+            if self._eos_id is None:
+                self._eos_id = self._look_for_special_token(tok_content.get("added_tokens", []), ["end", "text"])
+
+        assert None not in [self._bos_id, self._eos_id], "Unable to find BOS/EOS tokens"
+
+    @staticmethod
+    def _look_for_special_token(added_tokens: List[dict], search_strs: List[str]) -> Optional[int]:
+        candidate_toks = added_tokens
+        for search_str in search_strs:
+            candidate_toks = [
+                tok for tok in candidate_toks
+                if tok["special"] and search_str in tok["content"]
+            ]
+        if len(candidate_toks) == 1:
+            return candidate_toks[0]["id"]
+
+    def encode(
+        self,
+        s: str,
+        *,
+        bos: bool = False,
+        eos: bool = False,
+    ) -> List[int]:
+        res = self._tokenizer.encode(s, add_special_tokens=bos).ids
+        if eos and (not res or res[-1] != self._eos_id):
+            res.append(self._eos_id)
+        return res
+
+    def decode(self, ids: List[int]) -> str:
+        return self._tokenizer.decode(ids)
+
+    def bos_id(self) -> int:
+        return self._bos_id
+
+    def eos_id(self) -> int:
+        return self._eos_id
diff --git a/tokenizer/tiktoken.py b/tokenizer/tiktoken.py
index 9e9fe2264..30eb98624 100644
--- a/tokenizer/tiktoken.py
+++ b/tokenizer/tiktoken.py
@@ -23,6 +23,8 @@
 import tiktoken
 from tiktoken.load import load_tiktoken_bpe
 
+from .base import TokenizerBase
+
 logger = getLogger(__name__)
 
 
@@ -38,7 +40,7 @@ class Message(TypedDict):
 Dialog = Sequence[Message]
 
 
-class Tokenizer:
+class Tokenizer(TokenizerBase):
     """
     tokenizing and encoding/decoding text using the Tiktoken tokenizer.
""" diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 511cf1f35..0fd9c58b9 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -204,6 +204,7 @@ class TokenizerArgs: tokenizer_path: Optional[Union[Path, str]] = None is_sentencepiece: bool = False is_tiktoken: bool = False + is_hf_tokenizer: bool = False t: Optional[Any] = None def __post_init__(self): @@ -213,6 +214,7 @@ def __post_init__(self): self.t = TiktokenTokenizer(model_path=str(self.tokenizer_path)) self.is_tiktoken = True self.is_sentencepiece = False + self.is_hf_tokenizer = False return except: pass @@ -223,12 +225,25 @@ def __post_init__(self): self.t = SentencePieceProcessor(model_file=str(self.tokenizer_path)) self.is_tiktoken = False self.is_sentencepiece = True + self.is_hf_tokenizer = False + return + except: + pass + + try: + from tokenizer.hf_tokenizer import HFTokenizer + + self.t = HFTokenizer(str(self.tokenizer_path)) + self.is_tiktoken = False + self.is_sentencepiece = False + self.is_hf_tokenizer = True return except: pass self.is_tiktoken = False self.is_sentencepiece = False + self.is_hf_tokenizer = False self.t = None return @@ -240,16 +255,27 @@ def validate_model( if model is None: return - if self.is_tiktoken == self.is_sentencepiece: + if sum([self.is_tiktoken, self.is_hf_tokenizer, self.is_sentencepiece]) != 1: raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}") is_tiktoken = self.is_tiktoken is_sentencepiece = self.is_sentencepiece + is_hf_tokenizer = self.is_hf_tokenizer use_tiktoken = model.config.use_tiktoken + use_hf_tokenizer = model.config.use_hf_tokenizer + use_sentencepiece = not (use_tiktoken or use_hf_tokenizer) - if not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken): + if ( + (is_tiktoken and not use_tiktoken) or + (is_hf_tokenizer and not use_hf_tokenizer) or + (is_sentencepiece and not use_sentencepiece) + ): raise RuntimeError( - f"model-specified tokenizer 
({tokenizer_setting_to_name(use_tiktoken)}) does not match provided tokenizer ({tokenizer_setting_to_name(is_tiktoken)}) for {model_description}" + "model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format( + tokenizer_setting_to_name(use_tiktoken, use_hf_tokenizer), + tokenizer_setting_to_name(is_tiktoken, is_hf_tokenizer), + model_description, + ) ) return @@ -605,5 +631,9 @@ def _initialize_model( return model -def tokenizer_setting_to_name(tiktoken: bool = False) -> str: - return "TikToken" if tiktoken else "SentencePiece" +def tokenizer_setting_to_name(tiktoken: bool, tokenizers: bool) -> str: + if tiktoken: + return "TikToken" + if tokenizers: + return "Tokenizers" + return "SentencePiece" \ No newline at end of file diff --git a/torchchat/model.py b/torchchat/model.py index 7868b6593..11f3dc167 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -270,7 +270,9 @@ class TransformerArgs: norm_eps: float = 1e-5 multiple_of: int = 256 ffn_dim_multiplier: Optional[int] = None + # Select the desired tokenizer. 
Defaults to sentencepiece use_tiktoken: bool = False + use_hf_tokenizer: bool = False max_seq_length: int = 8192 rope_scaling: Optional[Dict[str, Any]] = None # For pipeline parallel @@ -327,12 +329,14 @@ class ModelArgs: model_type: ModelType transformer_args: Dict[str, Dict[str, Any]] use_tiktoken: bool + use_hf_tokenizer: bool def __init__( self, transformer_args: Dict[str, Dict[str, Any]], model_type: ModelType = ModelType.TextOnly, use_tiktoken: bool = False, + use_hf_tokenizer: bool = False, ) -> None: self._sanity_check(transformer_args, model_type) @@ -341,6 +345,7 @@ def __init__( # Model-level attributes self.use_tiktoken = use_tiktoken + self.use_hf_tokenizer = use_hf_tokenizer def _sanity_check( self, @@ -367,7 +372,8 @@ def from_params(cls, params_path): } use_tiktoken = loaded_params.get("use_tiktoken", False) - return cls(transformer_args, model_type, use_tiktoken) + use_hf_tokenizer = loaded_params.get("use_hf_tokenizer", False) + return cls(transformer_args, model_type, use_tiktoken, use_hf_tokenizer) @classmethod def from_table(cls, name: str):