Unified tokenizer type onboarding #1540

Open · wants to merge 7 commits into main

Changes from 1 commit
16 changes: 16 additions & 0 deletions tokenizer/tokenizer_type.py
@@ -0,0 +1,16 @@
+from enum import Enum
+
+class TokenizerType(Enum):
+    NONE = 0
+    TIKTOKEN = 1
+    SENTENCEPIECE = 2
+    HF_TOKENIZER = 3
+
+    def is_tiktoken(self):
+        return self == TokenizerType.TIKTOKEN
+    def is_sentencepiece(self):
+        return self == TokenizerType.SENTENCEPIECE
+    def is_hf_tokenizer(self):
+        return self == TokenizerType.HF_TOKENIZER
+    def is_none(self):
+        return self == TokenizerType.NONE
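
For orientation, a minimal usage sketch of the new enum (not part of the diff; it assumes the repo's tokenizer/ package is importable, and backend_name is a hypothetical helper):

    from tokenizer.tokenizer_type import TokenizerType

    def backend_name(t: TokenizerType) -> str:
        # Dispatch on the shared enum instead of ad-hoc string checks.
        if t.is_tiktoken():
            return "tiktoken"
        if t.is_sentencepiece():
            return "sentencepiece"
        if t.is_hf_tokenizer():
            return "huggingface"
        return "none"

    assert backend_name(TokenizerType.TIKTOKEN) == "tiktoken"
    assert backend_name(TokenizerType.NONE) == "none"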
23 changes: 5 additions & 18 deletions torchchat/cli/builder.py
@@ -238,11 +238,7 @@ def from_speculative_args(cls, args: argparse.Namespace) -> "BuilderArgs":
        speculative_builder_args.pte_path = None
        return speculative_builder_args

-class TokenizerType(Enum):
-    NONE = 0
-    TIKTOKEN = 1
-    SENTENCEPIECE = 2
-    HF_TOKENIZER = 3
+from tokenizer.tokenizer_type import TokenizerType

@dataclass
class TokenizerArgs:
@@ -278,15 +274,6 @@ def __post_init__(self):
        except:
            pass

-    def is_tiktoken(self) -> bool:
-        return self.tokenizer_type == TokenizerType.TIKTOKEN
-
-    def is_sentencepiece(self) -> bool:
-        return self.tokenizer_type == TokenizerType.SENTENCEPIECE
-
-    def is_hf_tokenizer(self) -> bool:
-        return self.tokenizer_type == TokenizerType.HF_TOKENIZER
-
    def validate_model(
        self,
        model: Optional[Model],
@@ -295,12 +282,12 @@ def validate_model(
        if model is None:
            return

-        if self.tokenizer_type == TokenizerType.NONE:
+        if self.tokenizer_type.is_none():
            raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}")

-        is_tiktoken = self.is_tiktoken()
-        is_sentencepiece = self.is_sentencepiece()
-        is_hf_tokenizer = self.is_hf_tokenizer()
+        is_tiktoken = self.tokenizer_type.is_tiktoken()
+        is_sentencepiece = self.tokenizer_type.is_sentencepiece()
+        is_hf_tokenizer = self.tokenizer_type.is_hf_tokenizer()

        use_tiktoken = model.config.use_tiktoken
        use_hf_tokenizer = model.config.use_hf_tokenizer
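
The hunk above only shows where the is_*/use_* flags are computed; the check that consumes them is collapsed in this view. A hypothetical sketch of that shape (only the variable names visible in the context lines come from the diff; the error messages and logic are assumptions, not the PR's code):

    # Hypothetical sketch of the tokenizer/model consistency check.
    def check_tokenizer_matches_model(
        is_tiktoken: bool, is_hf_tokenizer: bool,
        use_tiktoken: bool, use_hf_tokenizer: bool,
    ) -> None:
        if use_tiktoken and not is_tiktoken:
            raise RuntimeError("model config expects a tiktoken tokenizer")
        if use_hf_tokenizer and not is_hf_tokenizer:
            raise RuntimeError("model config expects an HF tokenizer")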
4 changes: 2 additions & 2 deletions torchchat/generate.py
@@ -365,14 +365,14 @@ def __init__(
        # must use tiktokenizer.
        # Piggybacking off of this flag for now to identify llama3
        # without prompting user.
-        self.is_llama3_model = self.tokenizer_args.is_tiktoken()
+        self.is_llama3_model = self.tokenizer_args.tokenizer_type.is_tiktoken()
        if self.is_llama3_model:
            self.chat_formatter = Llama3ChatFormatter(self.tokenizer)
            if generator_args.chat_mode:
                logger.debug(
                    "Llama3 model detected in chat mode. Using updated sentence schemas"
                )
-        elif self.tokenizer_args.is_hf_tokenizer():
+        elif self.tokenizer_args.tokenizer_type.is_hf_tokenizer():
            if not self.tokenizer.has_chat_template():
                raise ValueError("Tokenizer must have a chat template")
            self.chat_formatter = HFTokenizerChatFormatter(self.tokenizer)
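
The call-site change is mechanical: the predicates move from TokenizerArgs onto its tokenizer_type field. A minimal before/after sketch (FakeTokenizerArgs is a stand-in for torchchat's TokenizerArgs, which has more fields):

    from dataclasses import dataclass
    from tokenizer.tokenizer_type import TokenizerType

    @dataclass
    class FakeTokenizerArgs:  # stand-in for torchchat's TokenizerArgs
        tokenizer_type: TokenizerType = TokenizerType.NONE

    args = FakeTokenizerArgs(tokenizer_type=TokenizerType.TIKTOKEN)
    # Before: args.is_tiktoken()  (wrapper removed by this PR)
    # After:
    assert args.tokenizer_type.is_tiktoken()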