55# LICENSE file in the root directory of this source tree.
66
77# Standard
8- from typing import List , Optional
8+ from typing import Dict , List , Optional
99import json
1010import os
1111
1212# Third Party
13+ import jinja2
1314from tokenizers import Tokenizer
1415
1516# Local
@@ -37,6 +38,9 @@ def __init__(self, file_path: str):
3738 # Load the tokenizer itself
3839 self ._tokenizer = Tokenizer .from_file (tokenizer_path )
3940
41+ # Load the chat template if we have a config path
42+ self ._chat_template : Optional [jinja2 .Template ] = None
43+
4044 # If available, parse bos/eos tokens from the tokenizer config
4145 self ._bos_id , self ._eos_id = None , None
4246 if tokenizer_config_path is not None :
@@ -48,6 +52,8 @@ def __init__(self, file_path: str):
4852 self ._bos_id = self ._tokenizer .token_to_id (bos_token )
4953 if eos_token is not None :
5054 self ._eos_id = self ._tokenizer .token_to_id (eos_token )
55+ if chat_template_str := tok_config .get ("chat_template" ):
56+ self ._chat_template = jinja2 .Template (chat_template_str )
5157
5258 # If no eos/bos tokens found, go looking for them!
5359 if None in [self ._bos_id , self ._eos_id ]:
@@ -70,6 +76,8 @@ def _look_for_special_token(added_tokens: dict, search_strs: List[str]) -> Optio
7076 if len (candidate_toks ) == 1 :
7177 return candidate_toks [0 ]["id" ]
7278
79+ ## Interface ##
80+
7381 def encode (
7482 self ,
7583 s : str ,
@@ -90,3 +98,21 @@ def bos_id(self) -> int:
9098
9199 def eos_id (self ) -> int :
92100 return self ._eos_id
101+
102+ ## Additional Public Methods ##
103+
104+ def has_chat_template (self ) -> bool :
105+ return bool (self ._chat_template )
106+
107+ def apply_chat_template (
108+ self ,
109+ dialog : List [Dict [str , str ]],
110+ add_generation_prompt : bool = False ,
111+ ) -> str :
112+ """If configured with a chat template, apply it to the list of messages
113+ """
114+ if not self ._chat_template :
115+ raise ValueError ("No chat template configured!" )
116+ return self ._chat_template .render (
117+ messages = dialog , add_generation_prompt = add_generation_prompt
118+ )
0 commit comments