 
 logger = logging.getLogger(__name__)
 
+## Chat Formatters #############################################################
 
 class _ChatFormatter(ABC):
+
+    # Messages can arrive as a standard dict with "role" and "content" as
+    # strings, or where "content" is a list of objects with "text" fields.
+    MESSAGE_TYPE = Dict[str, Union[str, List[Dict[str, str]]]]
+
+    # A dialog is a sequence of messages
+    DIALOG_TYPE = List[MESSAGE_TYPE]
+
     def __init__(self, tokenizer):
         self.tokenizer = tokenizer
 
     @abstractmethod
-    def encode_dialog_prompt(self, dialog) -> List[int]:
-        raise NotImplementedError()
+    def encode_dialog_prompt(
+        self,
+        dialog: DIALOG_TYPE,
+        add_generation_prompt: bool,
+    ) -> List[int]:
+        """Encode a sequence of messages into a sequence of token IDs, including
+        the chat template.
+
+        Args:
+            dialog (DIALOG_TYPE): The sequence of dialog messages to encode.
+                These are the additional messages on top of those that have
+                already been processed.
+            add_generation_prompt (bool): Whether to include a generation prompt
+                at the end of the encoded sequence.
+
+        Returns:
+            List[int]: A list of token IDs representing the encoded prompt.
+        """
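
For reference, a minimal sketch of the two message shapes `MESSAGE_TYPE` admits (values are illustrative only; the `"type": "text"` key mirrors the check in `_encode_message` below):

```python
# Content as a plain string
msg_plain = {"role": "user", "content": "What is 1+1?"}

# Content as a list of text objects (OpenAI-style content parts)
msg_parts = {
    "role": "user",
    "content": [{"type": "text", "text": "What is 1+1?"}],
}
```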
 
 
 class Llama3ChatFormatter(_ChatFormatter):
@@ -64,16 +89,16 @@ class Llama3ChatFormatter(_ChatFormatter):
 
     """
 
-    def encode_header(self, role) -> List[int]:
+    def _encode_header(self, role) -> List[int]:
         tokens = []
         tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
         tokens.extend(self.tokenizer.encode(role, bos=False, eos=False))
         tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
         tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
         return tokens
 
-    def encode_message(self, message) -> List[int]:
-        tokens = self.encode_header(message["role"])
+    def _encode_message(self, message: _ChatFormatter.MESSAGE_TYPE) -> List[int]:
+        tokens = self._encode_header(message["role"])
         if isinstance(message["content"], str):
             tokens.extend(
                 self.tokenizer.encode(message["content"], bos=False, eos=False)
@@ -88,54 +113,79 @@ def encode_message(self, message) -> List[int]:
         tokens.append(self.tokenizer.special_tokens["<|eot_id|>"])
         return tokens
 
-    def encode_dialog_prompt(self, dialog) -> List[int]:
+    def encode_dialog_prompt(
+        self,
+        dialog: _ChatFormatter.DIALOG_TYPE,
+        add_generation_prompt: bool,
+    ) -> List[int]:
         tokens = []
         tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
         for message in dialog:
-            tokens.extend(self.encode_message(message))
+            tokens.extend(self._encode_message(message))
         # Add the start of an assistant message for the model to complete.
-        tokens.extend(self.encode_header("assistant"))  # Pass role directly as a string
+        if add_generation_prompt:
+            tokens.extend(self._encode_header("assistant"))  # Pass role directly as a string
         return tokens
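
A quick usage sketch of the Llama3 formatter, assuming a `tokenizer` like the one used above (same `special_tokens` mapping and `encode()` signature); the comment shows the text the returned tokens correspond to:

```python
# Hypothetical usage; `tokenizer` comes from the surrounding code.
formatter = Llama3ChatFormatter(tokenizer)
ids = formatter.encode_dialog_prompt(
    [{"role": "user", "content": "Hello!"}],
    add_generation_prompt=True,
)
# Token-for-token, `ids` corresponds to the text:
#   <|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n
#   Hello!<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n
```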
 
 
-B_INST, E_INST = "[INST]", "[/INST]"
-B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"
+class Llama2ChatFormatter(_ChatFormatter):
+    """
+    Chat formatting for Llama2
+    CITE: https://www.llama.com/docs/model-cards-and-prompt-formats/meta-llama-2/
+    """
+
+    B_INST, E_INST = "[INST] ", " [/INST]"
+    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
 
+    @staticmethod
+    def _get_content_str(message: _ChatFormatter.MESSAGE_TYPE) -> str:
+        if isinstance(message["content"], list):
+            return message["content"][0]["text"]
+        return message["content"]
 
-class Llama2ChatFormatter(_ChatFormatter):
-    def encode_dialog_prompt(self, dialog) -> List[int]:
-        tokens = self.tokenizer.encode(f"{B_INST} ")
-        first_message = True  # Bool to handle placing the B_INST token. Behavior is weird - the system prompt should have the B_INST, but not the first user message. All following user messages *should* have it. Also, if there is no system prompt, then the user message should have it.
+    def encode_dialog_prompt(
+        self,
+        dialog: _ChatFormatter.DIALOG_TYPE,
+        add_generation_prompt: bool,  # UNUSED
+    ) -> List[int]:
+        new_turn = True
+        tokens = []
         for message in dialog:
-            if isinstance(message["content"], list):
-                content = message["content"][0]["text"]
+            if new_turn:
+                tokens += self.tokenizer.encode(f"{self.tokenizer.bos}{self.B_INST}")
+            content = self._get_content_str(message).strip()
+            role = message["role"]
+            if role == "system":
+                tokens += self.tokenizer.encode(f"{self.B_SYS}{content}{self.E_SYS}")
+                new_turn = False
+            elif role == "user":
+                tokens += self.tokenizer.encode(f"{content}{self.E_INST}")
+                new_turn = False
+            elif role == "assistant":
+                tokens += self.tokenizer.encode(f" {content} {self.tokenizer.eos}\n")
+                new_turn = True
             else:
-                content = message["content"]
-            content = content.strip()
-            if message["role"] == "system":
-                encoded = self.tokenizer.encode(f"{B_SYS}\n{content}\n{E_SYS}")
-                first_message = False
-            elif message["role"] == "user":
-                encoded = [self.tokenizer.bos_id()] + self.tokenizer.encode(
-                    f"{B_INST if first_message else ''} {content} {E_INST}"
-                )
-                first_message = True
-            elif message["role"] == "assistant":
-                encoded = self.tokenizer.encode(f"{content}\n\n") + [
-                    self.tokenizer.eos_id()
-                ]
-            tokens += encoded
+                raise ValueError("Invalid role in dialog.")
         return tokens
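
Assuming `tokenizer.bos` and `tokenizer.eos` are the literal strings `<s>` and `</s>` (the f-strings above use them as strings), the new formatter renders a dialog like this sketch:

```python
# Hedged illustration: for the dialog [system "You are helpful.",
# user "Hello!", assistant "Hi there!", user "More?"], the encoded
# tokens correspond to this text:
expected_text = (
    "<s>[INST] <<SYS>>\n"
    "You are helpful.\n"
    "<</SYS>>\n"
    "\n"
    "Hello! [/INST] Hi there! </s>\n"
    "<s>[INST] More? [/INST]"
)
```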
 
 
+
 class HFTokenizerChatFormatter(_ChatFormatter):
     """Chat formatter that uses the built-in formatting capabilities of an HF
     tokenizer instance
     """
-    def encode_dialog_prompt(self, dialog) -> List[int]:
-        rendered = self.tokenizer.apply_chat_template(dialog, add_generation_prompt=True)
+    def encode_dialog_prompt(
+        self,
+        dialog: _ChatFormatter.DIALOG_TYPE,
+        add_generation_prompt: bool,
+    ) -> List[int]:
+        rendered = self.tokenizer.apply_chat_template(
+            dialog, add_generation_prompt=add_generation_prompt
+        )
+        logger.debug("Formatted chat prompt:\n%s", rendered)
        return self.tokenizer.encode(rendered)
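
Note that this path assumes the wrapped tokenizer's `apply_chat_template` returns the rendered string rather than token IDs (with a raw Hugging Face tokenizer you would pass `tokenize=False` to get that behavior). A self-contained sketch of the contract, using a hypothetical stub:

```python
from typing import List

class StubHFTokenizer:
    """Hypothetical stand-in for the HF tokenizer wrapper used above."""

    def apply_chat_template(self, dialog, add_generation_prompt: bool) -> str:
        # Render a trivial template: one "<|role|>content" segment per message.
        out = "".join(f"<|{m['role']}|>{m['content']}\n" for m in dialog)
        if add_generation_prompt:
            out += "<|assistant|>"
        return out

    def encode(self, text: str) -> List[int]:
        return list(text.encode("utf-8"))  # placeholder tokenization

fmt = HFTokenizerChatFormatter(StubHFTokenizer())
ids = fmt.encode_dialog_prompt(
    [{"role": "user", "content": "Hello!"}], add_generation_prompt=True
)
```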
 
+## Generation ##################################################################
 
 @dataclass
 class GeneratorArgs:
@@ -1040,38 +1090,21 @@ def chat(
             if prompt == "/bye":
                 print("Exiting Chat.\n")
                 break
-            if not self.is_llama3_model:
-                if self.system_prompt:
-                    prompt = f"{B_INST} {B_SYS}\n{self.system_prompt.strip()}\n{E_SYS}\n\n{prompt.strip()} {E_INST}"
-                    self.system_prompt = (
-                        None  # can only provide system prompt on first interaction
-                    )
-                else:
-                    prompt = f"{B_INST} {prompt.strip()} {E_INST}"
-                encoded = self.encode_tokens(
-                    prompt, bos=self.model.config.tokenizer_prepend_bos, device=self.builder_args.device
-                )
-            else:
-                if self.system_prompt:
-                    encoded = self.chat_formatter.encode_dialog_prompt(
-                        [
-                            {"role": "system", "content": self.system_prompt},
-                            {"role": "user", "content": prompt},
-                        ]
-                    )
-                    self.system_prompt = None
-                elif is_first_sample:
-                    encoded = self.chat_formatter.encode_dialog_prompt(
-                        [{"role": "user", "content": prompt}]
-                    )
-                else:
-                    encoded = self.chat_formatter.encode_message(
-                        {"role": "user", "content": prompt}
-                    )
-                    encoded.extend(self.chat_formatter.encode_header("assistant"))
-                encoded = torch.tensor(
-                    encoded, dtype=torch.int, device=self.builder_args.device
+
+            # Encode the additional messages added in this dialog turn. If
+            # this is the first turn, that includes any system prompt.
+            messages_to_encode = []
+            if is_first_sample and self.system_prompt:
+                messages_to_encode.append(
+                    {"role": "system", "content": self.system_prompt}
                 )
+            messages_to_encode.append({"role": "user", "content": prompt})
+            encoded = self.chat_formatter.encode_dialog_prompt(
+                messages_to_encode, add_generation_prompt=True,
+            )
+            encoded = torch.tensor(
+                encoded, dtype=torch.int, device=self.builder_args.device
+            )
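
The net effect: each turn encodes only the messages that are new since the last turn, with the system prompt riding along only on the first turn. A hypothetical helper (not in the codebase; names are illustrative) mirroring the inline logic:

```python
def build_turn_messages(prompt: str, system_prompt: str, is_first_sample: bool):
    # Mirrors the inline logic above: system prompt only on the first turn.
    messages = []
    if is_first_sample and system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})
    return messages

assert build_turn_messages("Hi", "Be helpful.", True) == [
    {"role": "system", "content": "Be helpful."},
    {"role": "user", "content": "Hi"},
]
assert build_turn_messages("More?", "Be helpful.", False) == [
    {"role": "user", "content": "More?"},
]
```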
             if encoded.size(0) + start_pos > max_seq_length:
                 print(
                     "This prompt would take us past the max_seq_length. Ending Conversation."