This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit b1237ba

fix(chat): Refactor chat template logic to encapsulate all formatting in classes
The formatted strings may not be perfectly 1:1 with the previous impl, but they should be in line with the official model guidelines:

* https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3
* https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-2

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <[email protected]>
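For context, a minimal usage sketch of the refactored interface (not part of the commit): `Llama3ChatFormatter` and `encode_dialog_prompt` come from the diff below, while `load_tokenizer` is a hypothetical stand-in for however torchchat actually constructs the model's tokenizer.

```python
# Rough usage sketch of the refactored chat formatter API; not code from this
# commit. `Llama3ChatFormatter` and `encode_dialog_prompt` are defined in the
# diff below; `load_tokenizer` is a hypothetical placeholder, not a real
# torchchat call.
from torchchat.generate import Llama3ChatFormatter

tokenizer = load_tokenizer("llama3")  # hypothetical helper
formatter = Llama3ChatFormatter(tokenizer)

dialog = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Write a haiku about tensors."},
]

# Encode the whole turn; add_generation_prompt=True appends the start of an
# assistant header so the model knows to complete the assistant message.
tokens = formatter.encode_dialog_prompt(dialog, add_generation_prompt=True)
```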
1 parent 2e694b0 commit b1237ba

File tree

1 file changed (+97, -64 lines)

torchchat/generate.py

Lines changed: 97 additions & 64 deletions
@@ -47,14 +47,39 @@
 
 logger = logging.getLogger(__name__)
 
+## Chat Formatters #############################################################
 
 class _ChatFormatter(ABC):
+
+    # Messages can arrive as a standard dict with "role" and "content" as
+    # strings, or where "content" is a list of objects with "text" fields.
+    MESSAGE_TYPE = Dict[str, Union[str, List[Dict[str, str]]]]
+
+    # A dialog is a sequence of messages
+    DIALOG_TYPE = List[MESSAGE_TYPE]
+
     def __init__(self, tokenizer):
         self.tokenizer = tokenizer
 
     @abstractmethod
-    def encode_dialog_prompt(self, dialog) -> List[int]:
-        raise NotImplementedError()
+    def encode_dialog_prompt(
+        self,
+        dialog: DIALOG_TYPE,
+        add_generation_prompt: bool,
+    ) -> List[int]:
+        """Encode a sequence of messages into a sequence of token IDs, including
+        the chat template
+
+        Args:
+            dialog (DIALOG_TYPE): The sequence of dialog messages to encode.
+                This will be the additional messages on top of those that have
+                already been processed.
+            add_generation_prompt (bool): Whether to include a generation prompt
+                at the end of the encoded sequence.
+
+        Returns:
+            List[int]: A list of token IDs representing the encoded prompt.
+        """
 
 
 class Llama3ChatFormatter(_ChatFormatter):
@@ -64,16 +89,16 @@ class Llama3ChatFormatter(_ChatFormatter):
 
     """
 
-    def encode_header(self, role) -> List[int]:
+    def _encode_header(self, role) -> List[int]:
         tokens = []
         tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
         tokens.extend(self.tokenizer.encode(role, bos=False, eos=False))
         tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
         tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
         return tokens
 
-    def encode_message(self, message) -> List[int]:
-        tokens = self.encode_header(message["role"])
+    def _encode_message(self, message: _ChatFormatter.MESSAGE_TYPE) -> List[int]:
+        tokens = self._encode_header(message["role"])
         if isinstance(message["content"], str):
             tokens.extend(
                 self.tokenizer.encode(message["content"], bos=False, eos=False)
@@ -88,54 +113,79 @@ def encode_message(self, message) -> List[int]:
         tokens.append(self.tokenizer.special_tokens["<|eot_id|>"])
         return tokens
 
-    def encode_dialog_prompt(self, dialog) -> List[int]:
+    def encode_dialog_prompt(
+        self,
+        dialog: _ChatFormatter.DIALOG_TYPE,
+        add_generation_prompt: bool,
+    ) -> List[int]:
         tokens = []
         tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
         for message in dialog:
-            tokens.extend(self.encode_message(message))
+            tokens.extend(self._encode_message(message))
         # Add the start of an assistant message for the model to complete.
-        tokens.extend(self.encode_header("assistant")) # Pass role directly as a string
+        if add_generation_prompt:
+            tokens.extend(self._encode_header("assistant")) # Pass role directly as a string
         return tokens
 
 
-B_INST, E_INST = "[INST]", "[/INST]"
-B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"
+class Llama2ChatFormatter(_ChatFormatter):
+    """
+    Chat formatting for Llama2
+    CITE: https://www.llama.com/docs/model-cards-and-prompt-formats/meta-llama-2/
+    """
+
+    B_INST, E_INST = "[INST] ", " [/INST]"
+    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
 
+    @staticmethod
+    def _get_content_str(message: _ChatFormatter.MESSAGE_TYPE) -> str:
+        if isinstance(message["content"], list):
+            return message["content"][0]["text"]
+        return message["content"]
 
-class Llama2ChatFormatter(_ChatFormatter):
-    def encode_dialog_prompt(self, dialog) -> List[int]:
-        tokens = self.tokenizer.encode(f"{B_INST} ")
-        first_message = True # Bool to handle placing the B_INST token. Behavior is weird - the system prompt should have the B_INST, but not the first user message. All following user messages *should* have it. Also, if there is no system prompt, then the user message should have it.
+    def encode_dialog_prompt(
+        self,
+        dialog: _ChatFormatter.DIALOG_TYPE,
+        add_generation_prompt: bool,  # UNUSED
+    ) -> List[int]:
+        new_turn = True
+        tokens = []
         for message in dialog:
-            if isinstance(message["content"], list):
-                content = message["content"][0]["text"]
+            if new_turn:
+                tokens += self.tokenizer.encode(f"{self.tokenizer.bos}{self.B_INST}")
+            content = self._get_content_str(message).strip()
+            role = message["role"]
+            if role == "system":
+                tokens += self.tokenizer.encode(f"{self.B_SYS}{content}{self.E_SYS}")
+                new_turn = False
+            elif role == "user":
+                tokens += self.tokenizer.encode(f"{content}{self.E_INST}")
+                new_turn = False
+            elif role == "assistant":
+                tokens += self.tokenizer.encode(f" {content} {self.tokenizer.eos}\n")
+                new_turn = True
             else:
-                content = message["content"]
-            content = content.strip()
-            if message["role"] == "system":
-                encoded = self.tokenizer.encode(f"{B_SYS}\n{content}\n{E_SYS}")
-                first_message = False
-            elif message["role"] == "user":
-                encoded = [self.tokenizer.bos_id()] + self.tokenizer.encode(
-                    f"{B_INST if first_message else ''} {content} {E_INST} "
-                )
-                first_message = True
-            elif message["role"] == "assistant":
-                encoded = self.tokenizer.encode(f"{content}\n\n") + [
-                    self.tokenizer.eos_id()
-                ]
-            tokens += encoded
+                raise ValueError("Invalid role in dialog.")
         return tokens
 
 
+
 class HFTokenizerChatFormatter(_ChatFormatter):
     """Chat formatter that uses the built-in formatting capabilities of an HF
    tokenizer instance
    """
-    def encode_dialog_prompt(self, dialog) -> List[int]:
-        rendered = self.tokenizer.apply_chat_template(dialog, add_generation_prompt=True)
+    def encode_dialog_prompt(
+        self,
+        dialog: _ChatFormatter.DIALOG_TYPE,
+        add_generation_prompt: bool,
+    ) -> List[int]:
+        rendered = self.tokenizer.apply_chat_template(
+            dialog, add_generation_prompt=add_generation_prompt
+        )
+        logger.debug("Formatted chat prompt:\n%s", rendered)
         return self.tokenizer.encode(rendered)
 
+## Generation ##################################################################
 
 @dataclass
 class GeneratorArgs:
@@ -1040,38 +1090,21 @@ def chat(
                 if prompt == "/bye":
                     print("Exiting Chat.\n")
                     break
-                if not self.is_llama3_model:
-                    if self.system_prompt:
-                        prompt = f"{B_INST} {B_SYS}\n{self.system_prompt.strip()}\n{E_SYS}\n\n{prompt.strip()} {E_INST}"
-                        self.system_prompt = (
-                            None  # can only provide system prompt on first interaction
-                        )
-                    else:
-                        prompt = f"{B_INST} {prompt.strip()} {E_INST}"
-                    encoded = self.encode_tokens(
-                        prompt, bos=self.model.config.tokenizer_prepend_bos, device=self.builder_args.device
-                    )
-                else:
-                    if self.system_prompt:
-                        encoded = self.chat_formatter.encode_dialog_prompt(
-                            [
-                                {"role": "system", "content": self.system_prompt},
-                                {"role": "user", "content": prompt},
-                            ]
-                        )
-                        self.system_prompt = None
-                    elif is_first_sample:
-                        encoded = self.chat_formatter.encode_dialog_prompt(
-                            [{"role": "user", "content": prompt}]
-                        )
-                    else:
-                        encoded = self.chat_formatter.encode_message(
-                            {"role": "user", "content": prompt}
-                        )
-                        encoded.extend(self.chat_formatter.encode_header("assistant"))
-                    encoded = torch.tensor(
-                        encoded, dtype=torch.int, device=self.builder_args.device
+
+                # Encode the additional messages added in this dialog turn. If
+                # this is the first turn, that includes any system prompt.
+                messages_to_encode = []
+                if is_first_sample and self.system_prompt:
+                    messages_to_encode.append(
+                        {"role": "system", "content": self.system_prompt}
                     )
+                messages_to_encode.append({"role": "user", "content": prompt})
+                encoded = self.chat_formatter.encode_dialog_prompt(
+                    messages_to_encode, add_generation_prompt=True,
+                )
+                encoded = torch.tensor(
+                    encoded, dtype=torch.int, device=self.builder_args.device
+                )
                 if encoded.size(0) + start_pos > max_seq_length:
                     print(
                         "This prompt would take us past the max_seq_length. Ending Conversation."

0 commit comments