import os
import textwrap
import time
+
+from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple
from cli import add_arguments_for_verb, arg_init, check_args
from utils.device_info import get_device_info

-B_INST, E_INST = "[INST]", "[/INST]"
-B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"
-

-class ChatFormat:
+class _ChatFormatter(ABC):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

-    def encode_header(self, message) -> List[int]:
+    @abstractmethod
+    def encode_dialog_prompt(self, dialog) -> List[int]:
+        raise NotImplementedError()
+
+
+class Llama3ChatFormatter(_ChatFormatter):
+    """Format a chat prompt using special tokens to demarcate roles and messages.
+
+    Refer to the LLaMA3 documentation for more details:
+    https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3
+    """
+
+    def encode_header(self, role) -> List[int]:
        tokens = []
        tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
-        tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False))
+        tokens.extend(self.tokenizer.encode(role, bos=False, eos=False))
        tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
        tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
        return tokens

    def encode_message(self, message) -> List[int]:
-        tokens = self.encode_header(message)
+        tokens = self.encode_header(message["role"])
        tokens.extend(
            self.tokenizer.encode(message["content"].strip(), bos=False, eos=False)
        )
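
For orientation, this is roughly the text the formatter produces for a single message, following the Llama 3 prompt format linked in the docstring above (the trailing <|eot_id|> end-of-turn marker comes from that documentation; the exact token ids depend on the tokenizer):

    <|start_header_id|>user<|end_header_id|>\n\nHello!<|eot_id|>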
@@ -62,9 +73,37 @@ def encode_dialog_prompt(self, dialog) -> List[int]:
        return tokens


+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"
+
+
+class Llama2ChatFormatter(_ChatFormatter):
+    def encode_dialog_prompt(self, dialog) -> List[int]:
+        tokens = self.tokenizer.encode(f"{B_INST} ")
+        # Bool to handle placing the B_INST token. Behavior is weird: the system
+        # prompt should have the B_INST, but not the first user message. All
+        # following user messages *should* have it. Also, if there is no system
+        # prompt, then the user message should have it.
+        first_message = True
+        for message in dialog:
+            content = message["content"].strip()
+            if message["role"] == "system":
+                encoded = self.tokenizer.encode(f"{B_SYS}\n{content}\n{E_SYS}")
+                first_message = False
+            elif message["role"] == "user":
+                encoded = [self.tokenizer.bos_id()] + self.tokenizer.encode(
+                    f"{B_INST if first_message else ''} {content} {E_INST} "
+                )
+                first_message = True
+            elif message["role"] == "assistant":
+                encoded = self.tokenizer.encode(f"{content}\n\n") + [
+                    self.tokenizer.eos_id()
+                ]
+            tokens += encoded
+        return tokens
+
+
@dataclass
class GeneratorArgs:
-    prompt: str = "torchchat is pronounced torch-chat and is so cool because"
+    prompt: Optional[str] = (
+        None  # When passed into the Generator, this will be used as the system prompt
+    )
    encoded_prompt: Optional[torch.Tensor] = None
    chat_mode: bool = False
    gui_mode: bool = False
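
To make the B_INST/B_SYS placement logic above concrete, here is a minimal sketch; the StubTokenizer is hypothetical (the real tokenizer returns token ids, not characters) and serves only to make the resulting prompt readable as text:

    # Hypothetical stub: one "token" per character, so the prompt reads back as text.
    class StubTokenizer:
        def encode(self, s, bos=False, eos=False):
            return list(s)

        def bos_id(self):
            return "<s>"   # stands in for the real BOS token id

        def eos_id(self):
            return "</s>"  # stands in for the real EOS token id

    formatter = Llama2ChatFormatter(StubTokenizer())
    dialog = [
        {"role": "system", "content": "You are terse."},
        {"role": "user", "content": "Hi"},
    ]
    print("".join(formatter.encode_dialog_prompt(dialog)))
    # [INST] <<SYS>>
    # You are terse.
    # <</SYS>><s> Hi [/INST]

With no system message, first_message stays True, so the user turn carries its own [INST] marker, as the comment in the formatter describes.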
@@ -188,7 +227,7 @@ def __init__(
            ))
            # fmt: on
            # raise RuntimeError("You need to use --is-chat-model to indicate model has chat support.")
-
+        self.system_prompt = generator_args.prompt
        self.tokenizer = _initialize_tokenizer(self.tokenizer_args)

        # Right now the assumption is only llama3 uses tiktokenizer and it
@@ -200,6 +239,11 @@ def __init__(
            logging.debug(
                "Llama3 model detected in chat mode. Using updated sentence schemas"
            )
+        self.chat_formatter = (
+            Llama3ChatFormatter(self.tokenizer)
+            if self.is_llama3_model
+            else Llama2ChatFormatter(self.tokenizer)
+        )

        self.builder_args.setup_caches = False
        self.model = _initialize_model(self.builder_args, self.quantize, self.tokenizer)
@@ -641,8 +685,7 @@ def chat(
            )
            if get_system_prompt == "y" or get_system_prompt == "Y":
                self.system_prompt = input("What is your system prompt? \n")
-            if self.is_llama3_model:
-                self.chat_formatter = ChatFormat(self.tokenizer)
+
        else:
            max_seq_length = min(
                encoded.size(0) + generator_args.max_new_tokens,
@@ -685,7 +728,7 @@ def chat(
                    prompt, bos=True, device=self.builder_args.device
                )
            else:
-                if self.system_prompt is not None:
+                if self.system_prompt:
                    encoded = self.chat_formatter.encode_dialog_prompt(
                        [
                            {"role": "system", "content": self.system_prompt},