51 changes: 29 additions & 22 deletions torchchat/cli/builder.py
@@ -8,6 +8,7 @@
import os
import sys
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union

@@ -237,23 +238,24 @@ def from_speculative_args(cls, args: argparse.Namespace) -> "BuilderArgs":
speculative_builder_args.pte_path = None
return speculative_builder_args

class TokenizerType(Enum):
NONE = 0
TIKTOKEN = 1
SENTENCEPIECE = 2
HF_TOKENIZER = 3

@dataclass
class TokenizerArgs:
tokenizer_path: Optional[Union[Path, str]] = None
is_sentencepiece: bool = False
is_tiktoken: bool = False
is_hf_tokenizer: bool = False
tokenizer_type: TokenizerType = TokenizerType.NONE
t: Optional[Any] = None

def __post_init__(self):
try:
from tokenizer.tiktoken import Tokenizer as TiktokenTokenizer

self.t = TiktokenTokenizer(model_path=str(self.tokenizer_path))
self.is_tiktoken = True
self.is_sentencepiece = False
self.is_hf_tokenizer = False
self.tokenizer_type = TokenizerType.TIKTOKEN
return
except:
pass
@@ -262,9 +264,7 @@ def __post_init__(self):
from sentencepiece import SentencePieceProcessor

self.t = SentencePieceProcessor(model_file=str(self.tokenizer_path))
self.is_tiktoken = False
self.is_sentencepiece = True
self.is_hf_tokenizer = False
self.tokenizer_type = TokenizerType.SENTENCEPIECE
return
except:
pass
@@ -273,19 +273,24 @@ def __post_init__(self):
from tokenizer.hf_tokenizer import HFTokenizer

self.t = HFTokenizer(str(self.tokenizer_path))
self.is_tiktoken = False
self.is_sentencepiece = False
self.is_hf_tokenizer = True
self.tokenizer_type = TokenizerType.HF_TOKENIZER
return
except:
pass

self.is_tiktoken = False
self.is_sentencepiece = False
self.is_hf_tokenizer = False
self.tokenizer_type = TokenizerType.NONE
self.t = None
Comment (Contributor Author): Do we really have to set these as None again, since we already set them at the very top?

Comment (Contributor): We can actually drop all the logic here after the HF tokenizer check; tokenizer_type and .t are already set to these values by default.

return
Comment (Contributor): nit: return is not needed
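If both suggestions above were applied, the tail of __post_init__ could be dropped entirely, since the dataclass defaults already leave tokenizer_type at TokenizerType.NONE and t at None. A minimal sketch of what the simplified class might look like (illustrative only, not the code in this PR; the bare except clauses are also narrowed to except Exception here):

```python
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any, Optional, Union


class TokenizerType(Enum):
    NONE = 0
    TIKTOKEN = 1
    SENTENCEPIECE = 2
    HF_TOKENIZER = 3


@dataclass
class TokenizerArgs:
    tokenizer_path: Optional[Union[Path, str]] = None
    tokenizer_type: TokenizerType = TokenizerType.NONE
    t: Optional[Any] = None

    def __post_init__(self):
        # Try each tokenizer in turn; the first one that loads wins.
        try:
            from tokenizer.tiktoken import Tokenizer as TiktokenTokenizer

            self.t = TiktokenTokenizer(model_path=str(self.tokenizer_path))
            self.tokenizer_type = TokenizerType.TIKTOKEN
            return
        except Exception:
            pass

        try:
            from sentencepiece import SentencePieceProcessor

            self.t = SentencePieceProcessor(model_file=str(self.tokenizer_path))
            self.tokenizer_type = TokenizerType.SENTENCEPIECE
            return
        except Exception:
            pass

        try:
            from tokenizer.hf_tokenizer import HFTokenizer

            self.t = HFTokenizer(str(self.tokenizer_path))
            self.tokenizer_type = TokenizerType.HF_TOKENIZER
        except Exception:
            # Nothing loaded: the dataclass defaults (TokenizerType.NONE, None)
            # already describe this state, so no fallback assignment or trailing
            # return is needed.
            pass
```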


def is_tiktoken(self) -> bool:
return self.tokenizer_type == TokenizerType.TIKTOKEN

def is_sentencepiece(self) -> bool:
return self.tokenizer_type == TokenizerType.SENTENCEPIECE

def is_hf_tokenizer(self) -> bool:
return self.tokenizer_type == TokenizerType.HF_TOKENIZER

def validate_model(
self,
model: Optional[Model],
@@ -294,12 +299,14 @@ def validate_model(
if model is None:
return

if sum([self.is_tiktoken, self.is_hf_tokenizer, self.is_sentencepiece]) != 1:

is_tiktoken = self.is_tiktoken()
is_sentencepiece = self.is_sentencepiece()
is_hf_tokenizer = self.is_hf_tokenizer()

if sum([is_tiktoken, is_hf_tokenizer, is_sentencepiece]) != 1:
Comment (Contributor): We can replace this by just checking if the tokenizer enum is NONE.

raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}")

is_tiktoken = self.is_tiktoken
is_sentencepiece = self.is_sentencepiece
is_hf_tokenizer = self.is_hf_tokenizer
use_tiktoken = model.config.use_tiktoken
use_hf_tokenizer = model.config.use_hf_tokenizer
use_sentencepiece = not (use_tiktoken or use_hf_tokenizer)
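A minimal sketch of the check suggested in the comment above, assuming __post_init__ only ever leaves tokenizer_type at TokenizerType.NONE when no tokenizer could be loaded (illustrative, not part of this PR):

```python
# Equivalent to requiring exactly one of the three predicates to be true,
# since __post_init__ sets tokenizer_type to exactly one non-NONE member
# when a tokenizer loads successfully.
if self.tokenizer_type == TokenizerType.NONE:
    raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}")
```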
@@ -651,13 +658,13 @@ def do_nothing(max_batch_size, max_seq_length):
model = torch.load(builder_args.snapshot_path, weights_only=False)
except Exception:
raise RuntimeError(f"Failed to load torchchat snapshot {builder_args.snapshot_path}")
# _active_backend() does not allow DSO & AOTI to be true.
# Choose either.
from torchchat.utils.build_utils import set_backend
set_backend (dso=True, pte=False, aoti_package=False)
if (model.config != config):
raise RuntimeError("loaded model architecture mismatch")
##
## import all libraries with custom kernels and custom operators
## that quantize may be pulling in
##
@@ -792,4 +799,4 @@ def tokenizer_setting_to_name(tiktoken: bool, tokenizers: bool) -> str:
return "TikToken"
if tokenizers:
return "Tokenizers"
return "SentencePiece"
return "SentencePiece"
2 changes: 1 addition & 1 deletion torchchat/export.py
@@ -482,7 +482,7 @@ def main(args):

if tokenizer_args is None:
tokenizer_type = "0"
elif tokenizer_args.is_sentencepiece:
elif tokenizer_args.is_sentencepiece():
tokenizer_type = "2" # Corresponding to llama2
else:
tokenizer_type = "3" # Corresponding to llama3
4 changes: 2 additions & 2 deletions torchchat/generate.py
@@ -365,14 +365,14 @@ def __init__(
# must use tiktokenizer.
# Piggy backing off of this flag then for now to identify llama3
# without prompting user.
self.is_llama3_model = self.tokenizer_args.is_tiktoken
self.is_llama3_model = self.tokenizer_args.is_tiktoken()
if self.is_llama3_model:
self.chat_formatter = Llama3ChatFormatter(self.tokenizer)
if generator_args.chat_mode:
logger.debug(
"Llama3 model detected in chat mode. Using updated sentence schemas"
)
elif self.tokenizer_args.is_hf_tokenizer:
elif self.tokenizer_args.is_hf_tokenizer():
if not self.tokenizer.has_chat_template():
raise ValueError("Tokenizer must have a chat template")
self.chat_formatter = HFTokenizerChatFormatter(self.tokenizer)
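One reason every call site needs the added parentheses: after this change, is_tiktoken and the other predicates are methods, so the bare attribute evaluates to a bound method object, which is always truthy. A small illustration, assuming a path at which no tokenizer can be loaded:

```python
from torchchat.cli.builder import TokenizerArgs

# Hypothetical path; nothing loads, so tokenizer_type stays NONE.
args = TokenizerArgs(tokenizer_path="does/not/exist")

print(bool(args.is_tiktoken))  # True  -- bound method object, always truthy
print(args.is_tiktoken())      # False -- the actual check
```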