
Commit 7146029

Merge branch 'main' into angelayi/aoti_metadata

2 parents: b2b93c5 + 9480258

File tree

5 files changed: +169, -7 lines

tokenizer/base.py

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Abstract base class for all tokenizer classes in Python, matching the C++ interface.
+"""
+
+# Standard
+from abc import ABC, abstractmethod
+from typing import List
+
+
+class TokenizerBase(ABC):
+    __doc__ = __doc__
+
+    @abstractmethod
+    def encode(self, s: str, *, bos: bool = False, eos: bool = False) -> List[int]:
+        """Encode the given string and optionally include bos/eos tokens"""
+
+    @abstractmethod
+    def decode(self, ids: List[int]) -> str:
+        """Decode the given token ids into a string"""
+
+    @abstractmethod
+    def bos_id(self) -> int:
+        """The id of the begin-of-string token"""
+
+    @abstractmethod
+    def eos_id(self) -> int:
+        """The id of the end-of-string token"""

tokenizer/hf_tokenizer.py

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Standard
+from typing import List, Optional
+import json
+import os
+
+# Third Party
+from tokenizers import Tokenizer
+
+# Local
+from .base import TokenizerBase
+
+
+class HFTokenizer(TokenizerBase):
+    """
+    Wrapper around the Huggingface `tokenizers` library for API compatibility
+    """
+
+    def __init__(self, file_path: str):
+        # If the path is a directory, look for "tokenizer.json" which is
+        # standard for transformers checkpoints and also look for the
+        # "tokenizer_config.json" file to parse eos/bos tokens
+        if os.path.isdir(file_path):
+            tokenizer_path = os.path.join(file_path, "tokenizer.json")
+            tokenizer_config_path = os.path.join(file_path, "tokenizer_config.json")
+        else:
+            tokenizer_path = file_path
+            tokenizer_config_path = os.path.join(os.path.dirname(file_path), "tokenizer_config.json")
+        if not os.path.isfile(tokenizer_config_path):
+            tokenizer_config_path = None
+
+        # Load the tokenizer itself
+        self._tokenizer = Tokenizer.from_file(tokenizer_path)
+
+        # If available, parse bos/eos tokens from the tokenizer config
+        self._bos_id, self._eos_id = None, None
+        if tokenizer_config_path is not None:
+            with open(tokenizer_config_path, "r") as handle:
+                tok_config = json.load(handle)
+            bos_token = tok_config.get("bos_token")
+            eos_token = tok_config.get("eos_token")
+            if bos_token is not None:
+                self._bos_id = self._tokenizer.token_to_id(bos_token)
+            if eos_token is not None:
+                self._eos_id = self._tokenizer.token_to_id(eos_token)
+
+        # If no eos/bos tokens found, go looking for them!
+        if None in [self._bos_id, self._eos_id]:
+            tok_content = json.loads(self._tokenizer.to_str())
+            if self._bos_id is None:
+                self._bos_id = self._look_for_special_token(tok_content["added_tokens"], ["begin", "text"])
+            if self._eos_id is None:
+                self._eos_id = self._look_for_special_token(tok_content["added_tokens"], ["end", "text"])
+
+        assert None not in [self._bos_id, self._eos_id], "Unable to find BOS/EOS tokens"
+
+    @staticmethod
+    def _look_for_special_token(added_tokens: List[dict], search_strs: List[str]) -> Optional[int]:
+        candidate_toks = added_tokens
+        for search_str in search_strs:
+            candidate_toks = [
+                tok for tok in candidate_toks
+                if tok["special"] and search_str in tok["content"]
+            ]
+        if len(candidate_toks) == 1:
+            return candidate_toks[0]["id"]
+
+    def encode(
+        self,
+        s: str,
+        *,
+        bos: bool = False,
+        eos: bool = False,
+    ) -> List[int]:
+        res = self._tokenizer.encode(s, add_special_tokens=bos).ids
+        if eos and (not res or res[-1] != self._eos_id):
+            res.append(self._eos_id)
+        return res
+
+    def decode(self, ids: List[int]) -> str:
+        return self._tokenizer.decode(ids)
+
+    def bos_id(self) -> int:
+        return self._bos_id
+
+    def eos_id(self) -> int:
+        return self._eos_id
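
A usage sketch for the new wrapper (the checkpoint path is hypothetical; any directory holding a transformers-style tokenizer.json, and optionally tokenizer_config.json, should work):

# Hypothetical usage; "my-checkpoint/" stands in for a real checkpoint directory.
from tokenizer.hf_tokenizer import HFTokenizer

tok = HFTokenizer("my-checkpoint/")  # a direct path to a tokenizer.json file also works
ids = tok.encode("hello world", bos=True, eos=True)
assert ids[-1] == tok.eos_id()  # encode appends EOS when asked and it is missing
print(tok.decode(ids))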

tokenizer/tiktoken.py

Lines changed: 3 additions & 1 deletion
@@ -23,6 +23,8 @@
 import tiktoken
 from tiktoken.load import load_tiktoken_bpe
 
+from .base import TokenizerBase
+
 
 logger = getLogger(__name__)
 
@@ -38,7 +40,7 @@ class Message(TypedDict):
 Dialog = Sequence[Message]
 
 
-class Tokenizer:
+class Tokenizer(TokenizerBase):
     """
     tokenizing and encoding/decoding text using the Tiktoken tokenizer.
     """

torchchat/cli/builder.py

Lines changed: 35 additions & 5 deletions
@@ -215,6 +215,7 @@ class TokenizerArgs:
     tokenizer_path: Optional[Union[Path, str]] = None
     is_sentencepiece: bool = False
     is_tiktoken: bool = False
+    is_hf_tokenizer: bool = False
     t: Optional[Any] = None
 
     def __post_init__(self):
@@ -224,6 +225,7 @@ def __post_init__(self):
             self.t = TiktokenTokenizer(model_path=str(self.tokenizer_path))
             self.is_tiktoken = True
             self.is_sentencepiece = False
+            self.is_hf_tokenizer = False
             return
         except:
             pass
@@ -234,12 +236,25 @@
             self.t = SentencePieceProcessor(model_file=str(self.tokenizer_path))
             self.is_tiktoken = False
             self.is_sentencepiece = True
+            self.is_hf_tokenizer = False
+            return
+        except:
+            pass
+
+        try:
+            from tokenizer.hf_tokenizer import HFTokenizer
+
+            self.t = HFTokenizer(str(self.tokenizer_path))
+            self.is_tiktoken = False
+            self.is_sentencepiece = False
+            self.is_hf_tokenizer = True
             return
         except:
             pass
 
         self.is_tiktoken = False
         self.is_sentencepiece = False
+        self.is_hf_tokenizer = False
         self.t = None
         return
 
@@ -251,16 +266,27 @@ def validate_model(
         if model is None:
             return
 
-        if self.is_tiktoken == self.is_sentencepiece:
+        if sum([self.is_tiktoken, self.is_hf_tokenizer, self.is_sentencepiece]) != 1:
             raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}")
 
         is_tiktoken = self.is_tiktoken
         is_sentencepiece = self.is_sentencepiece
+        is_hf_tokenizer = self.is_hf_tokenizer
         use_tiktoken = model.config.use_tiktoken
+        use_hf_tokenizer = model.config.use_hf_tokenizer
+        use_sentencepiece = not (use_tiktoken or use_hf_tokenizer)
 
-        if not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken):
+        if (
+            (is_tiktoken and not use_tiktoken) or
+            (is_hf_tokenizer and not use_hf_tokenizer) or
+            (is_sentencepiece and not use_sentencepiece)
+        ):
             raise RuntimeError(
-                f"model-specified tokenizer ({tokenizer_setting_to_name(use_tiktoken)}) does not match provided tokenizer ({tokenizer_setting_to_name(is_tiktoken)}) for {model_description}"
+                "model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format(
+                    tokenizer_setting_to_name(use_tiktoken, use_hf_tokenizer),
+                    tokenizer_setting_to_name(is_tiktoken, is_hf_tokenizer),
+                    model_description,
+                )
             )
 
         return
@@ -655,5 +681,9 @@ def _initialize_model(
     return model
 
 
-def tokenizer_setting_to_name(tiktoken: bool = False) -> str:
-    return "TikToken" if tiktoken else "SentencePiece"
+def tokenizer_setting_to_name(tiktoken: bool, tokenizers: bool) -> str:
+    if tiktoken:
+        return "TikToken"
+    if tokenizers:
+        return "Tokenizers"
+    return "SentencePiece"

torchchat/model.py

Lines changed: 7 additions & 1 deletion
@@ -270,7 +270,9 @@ class TransformerArgs:
     norm_eps: float = 1e-5
     multiple_of: int = 256
     ffn_dim_multiplier: Optional[int] = None
+    # Select the desired tokenizer. Defaults to sentencepiece
     use_tiktoken: bool = False
+    use_hf_tokenizer: bool = False
     max_seq_length: int = 8192
     rope_scaling: Optional[Dict[str, Any]] = None
     # For pipeline parallel
@@ -327,12 +329,14 @@ class ModelArgs:
     model_type: ModelType
     transformer_args: Dict[str, Dict[str, Any]]
     use_tiktoken: bool
+    use_hf_tokenizer: bool
 
     def __init__(
         self,
         transformer_args: Dict[str, Dict[str, Any]],
         model_type: ModelType = ModelType.TextOnly,
         use_tiktoken: bool = False,
+        use_hf_tokenizer: bool = False,
     ) -> None:
         self._sanity_check(transformer_args, model_type)
 
@@ -341,6 +345,7 @@ def __init__(
 
         # Model-level attributes
         self.use_tiktoken = use_tiktoken
+        self.use_hf_tokenizer = use_hf_tokenizer
 
     def _sanity_check(
         self,
@@ -367,7 +372,8 @@ def from_params(cls, params_path):
         }
 
         use_tiktoken = loaded_params.get("use_tiktoken", False)
-        return cls(transformer_args, model_type, use_tiktoken)
+        use_hf_tokenizer = loaded_params.get("use_hf_tokenizer", False)
+        return cls(transformer_args, model_type, use_tiktoken, use_hf_tokenizer)
 
     @classmethod
     def from_table(cls, name: str):
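
A model opts into the new backend via its params JSON, which from_params reads with loaded_params.get; a minimal hypothetical payload (only the two tokenizer flags are keys actually read above):

# Hypothetical params payload for ModelArgs.from_params.
import json

loaded_params = json.loads('{"use_tiktoken": false, "use_hf_tokenizer": true}')
assert loaded_params.get("use_hf_tokenizer", False) is True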
