This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit ae70308

Merge branch 'pytorch:main' into patch-30
2 parents: dba437e + cc0ffce

File tree

17 files changed: +1061 −1027 lines


install/requirements.txt

Lines changed: 4 additions & 0 deletions
@@ -9,6 +9,10 @@ gguf
 # Tiktoken tokenizer for Llama 3 and other advanced models
 tiktoken
 
+# Tokenizers and jinja2 for other non-llama models that use HF tokenizers
+tokenizers
+jinja2
+
 # Miscellaneous
 snakeviz
 sentencepiece
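
The two new packages are used together by the HF tokenizer changes later in this diff: `tokenizers` loads the tokenizer itself, and `jinja2` renders its chat template. A minimal standalone sketch of that pairing (the template string and file path below are illustrative, not taken from the diff):

# Sketch of why both deps are needed; the template string is made up,
# not one shipped in any real tokenizer_config.json.
import jinja2
from tokenizers import Tokenizer

template = jinja2.Template(
    "{% for m in messages %}<{{ m.role }}>{{ m.content }}</{{ m.role }}>{% endfor %}"
)
prompt = template.render(messages=[{"role": "user", "content": "Hi!"}])
# Hypothetical path to a HF-format tokenizer file
tokenizer = Tokenizer.from_file("tokenizer.json")
ids = tokenizer.encode(prompt).ids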

tests/conftest.py

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+"""
+Global pytest config, fixtures, and helpers go here!
+"""
+
+# Standard
+import os
+import sys
+
+# Make sure tests can import torchchat
+sys.path.append(
+    os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))
+)
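
With this conftest in place, test modules can import torchchat without an editable install. A hypothetical test module relying on it might look like:

# tests/test_import.py -- hypothetical example; it works because conftest.py
# adds the repo root to sys.path before test modules are imported.
from torchchat.generate import Llama2ChatFormatter

def test_torchchat_importable():
    assert Llama2ChatFormatter is not None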

tests/test_chat_formatters.py

Lines changed: 216 additions & 0 deletions
@@ -0,0 +1,216 @@
+"""
+Unit tests for chat formatters
+"""
+
+# Third Party
+import pytest
+
+# Local
+from torchchat.generate import (
+    HFTokenizerChatFormatter,
+    Llama2ChatFormatter,
+    Llama3ChatFormatter,
+)
+
+## Helpers #####################################################################
+
+class DummyTokenizer:
+    """Dummy tokenizer that encodes as strings so it's easy to check formatting"""
+    def encode(self, text, *_, **__):
+        return text
+
+
+class DummySPTokenizer(DummyTokenizer):
+    """Emulated Sentencepiece tokenizer with bos/eos"""
+    bos = "<s>"
+    eos = "</s>"
+
+
+class DummyLlama3Tokenizer(DummyTokenizer):
+    class _IdentityDict:
+        def __getitem__(self, key):
+            return key
+    special_tokens = _IdentityDict()
+
+
+class DummyHFTokenizer(DummyTokenizer):
+    """Dummy made up chat template scheme"""
+    # Sequence
+    bos = "<bos>"
+    # Turn
+    bot = "<bot>"
+    eot = "<eot>"
+    # Role
+    bor = "<bor>"
+    eor = "<eor>"
+    def apply_chat_template(self, messages, add_generation_prompt):
+        out = [self.bos]
+        role = None
+        for msg in messages:
+            role = msg["role"]
+            content = msg["content"]
+            out.append(f"{self.bot}{self.bor}{role}{self.eor}{content}{self.eot}")
+        if add_generation_prompt and role != "assistant":
+            out.append(f"{self.bot}{self.bor}assistant{self.eor}")
+        return "\n".join(out)
+
+
+def check_rendering(fmt, messages, expected, add_generation_prompt):
+    """Render messages and compare to expected output"""
+    assert "".join(fmt.encode_dialog_prompt(messages, add_generation_prompt)) == expected
+
+
+def make_message(role, text):
+    return {"role": role, "content": text}
+
+
+SYSTEM_PROMPT = "You are a helpful assistant, feel free to ask me anything."
+USER1 = "Hello world!"
+ASSISTANT1 = "Greetings! How can I help you?"
+USER2 = "Why is the sky blue?"
+ASSISTANT2 = "The sky appears blue because of a phenomenon called Rayleigh scattering."
+
+
+# Stock sets of messages to test
+MSGS_NO_SYS = [
+    make_message("user", USER1),
+]
+MSGS_SYS_USR = [
+    make_message("system", SYSTEM_PROMPT),
+    make_message("user", USER1),
+]
+MSGS_SYS_USR_ASST = [
+    make_message("system", SYSTEM_PROMPT),
+    make_message("user", USER1),
+    make_message("assistant", ASSISTANT1),
+]
+MSGS_MULTI_TURN = [
+    make_message("system", SYSTEM_PROMPT),
+    make_message("user", USER1),
+    make_message("assistant", ASSISTANT1),
+    make_message("user", USER2),
+    make_message("assistant", ASSISTANT2),
+]
+
+## Llama2ChatFormatter #########################################################
+
+@pytest.mark.parametrize(
+    ["messages", "expected"],
+    [
+        # single user message (no system prompt)
+        (MSGS_NO_SYS, f"<s>[INST] {USER1} [/INST]"),
+        # sys, usr
+        (MSGS_SYS_USR, f"""<s>[INST] <<SYS>>
+{SYSTEM_PROMPT}
+<</SYS>>
+
+{USER1} [/INST]"""),
+        # sys, usr, asst
+        (MSGS_SYS_USR_ASST, f"""<s>[INST] <<SYS>>
+{SYSTEM_PROMPT}
+<</SYS>>
+
+{USER1} [/INST] {ASSISTANT1} </s>
+"""),
+        # sys, usr, asst, usr, asst
+        (MSGS_MULTI_TURN, f"""<s>[INST] <<SYS>>
+{SYSTEM_PROMPT}
+<</SYS>>
+
+{USER1} [/INST] {ASSISTANT1} </s>
+<s>[INST] {USER2} [/INST] {ASSISTANT2} </s>
+"""),
+    ]
+)
+def test_llama2_chat_formatter(messages, expected):
+    """Tests for Llama2 following the official guide
+    https://www.llama.com/docs/model-cards-and-prompt-formats/meta-llama-2/
+    """
+    tok = DummySPTokenizer()
+    fmt = Llama2ChatFormatter(tok)
+    # NOTE: add_generation_prompt not used by Llama2
+    check_rendering(fmt, messages, expected, True)
+
+## Llama3ChatFormatter #########################################################
+
+@pytest.mark.parametrize(
+    ["messages", "expected"],
+    [
+        # single user message (no system prompt)
+        (MSGS_NO_SYS, f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>
+
+{USER1}<|eot_id|>"""),
+        # sys, usr
+        (MSGS_SYS_USR, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+
+{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+{USER1}<|eot_id|>"""),
+        # sys, usr, asst
+        (MSGS_SYS_USR_ASST, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+
+{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+{USER1}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+{ASSISTANT1}<|eot_id|>"""),
+        # sys, usr, asst, usr, asst
+        (MSGS_MULTI_TURN, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+
+{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+{USER1}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+{ASSISTANT1}<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+{USER2}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+{ASSISTANT2}<|eot_id|>"""),
+    ]
+)
+@pytest.mark.parametrize("add_generation_prompt", [True, False])
+def test_llama3_chat_formatter(messages, expected, add_generation_prompt):
+    """Tests for Llama3 following the official guide
+    https://www.llama.com/docs/model-cards-and-prompt-formats/meta-llama-3/
+    """
+    tok = DummyLlama3Tokenizer()
+    fmt = Llama3ChatFormatter(tok)
+    # No assistant prompt added if the last message is from the assistant
+    if add_generation_prompt and messages[-1]["role"] != "assistant":
+        expected += "<|start_header_id|>assistant<|end_header_id|>\n\n"
+    check_rendering(fmt, messages, expected, add_generation_prompt)
+
+## HFTokenizerChatFormatter ####################################################
+
+@pytest.mark.parametrize(
+    ["messages", "expected"],
+    [
+        # single user message (no system prompt)
+        (MSGS_NO_SYS, f"""<bos>
+<bot><bor>user<eor>{USER1}<eot>"""),
+        # sys, usr
+        (MSGS_SYS_USR, f"""<bos>
+<bot><bor>system<eor>{SYSTEM_PROMPT}<eot>
+<bot><bor>user<eor>{USER1}<eot>"""),
+        # sys, usr, asst
+        (MSGS_SYS_USR_ASST, f"""<bos>
+<bot><bor>system<eor>{SYSTEM_PROMPT}<eot>
+<bot><bor>user<eor>{USER1}<eot>
+<bot><bor>assistant<eor>{ASSISTANT1}<eot>"""),
+        # sys, usr, asst, usr, asst
+        (MSGS_MULTI_TURN, f"""<bos>
+<bot><bor>system<eor>{SYSTEM_PROMPT}<eot>
+<bot><bor>user<eor>{USER1}<eot>
+<bot><bor>assistant<eor>{ASSISTANT1}<eot>
+<bot><bor>user<eor>{USER2}<eot>
+<bot><bor>assistant<eor>{ASSISTANT2}<eot>"""),
+    ]
+)
+@pytest.mark.parametrize("add_generation_prompt", [True, False])
+def test_hf_chat_formatter(messages, expected, add_generation_prompt):
+    tok = DummyHFTokenizer()
+    fmt = HFTokenizerChatFormatter(tok)
+    # No assistant prompt added if the last message is from the assistant
+    if add_generation_prompt and messages[-1]["role"] != "assistant":
+        expected += f"\n{tok.bot}{tok.bor}assistant{tok.eor}"
+    check_rendering(fmt, messages, expected, add_generation_prompt)
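
One detail worth noting for reviewers: the Llama3 and HF tests stack two @pytest.mark.parametrize decorators, so pytest runs the cross product of the message sets and the add_generation_prompt flag (4 × 2 = 8 generated cases per test). A minimal sketch of the same pattern, with made-up parameter values:

# Sketch of stacked parametrize decorators; parameter names/values are made up.
import pytest

@pytest.mark.parametrize("case", ["no_sys", "sys_usr", "sys_usr_asst", "multi_turn"])
@pytest.mark.parametrize("add_generation_prompt", [True, False])
def test_cross_product(case, add_generation_prompt):
    # pytest generates 4 x 2 = 8 test cases from the two decorators
    assert isinstance(case, str) and isinstance(add_generation_prompt, bool)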

tokenizer/hf_tokenizer.py

Lines changed: 27 additions & 1 deletion
@@ -5,11 +5,12 @@
 # LICENSE file in the root directory of this source tree.
 
 # Standard
-from typing import List, Optional
+from typing import Dict, List, Optional
 import json
 import os
 
 # Third Party
+import jinja2
 from tokenizers import Tokenizer
 
 # Local
@@ -37,6 +38,9 @@ def __init__(self, file_path: str):
         # Load the tokenizer itself
         self._tokenizer = Tokenizer.from_file(tokenizer_path)
 
+        # Load the chat template if we have a config path
+        self._chat_template: Optional[jinja2.Template] = None
+
         # If available, parse bos/eos tokens from the tokenizer config
         self._bos_id, self._eos_id = None, None
         if tokenizer_config_path is not None:
@@ -48,6 +52,8 @@ def __init__(self, file_path: str):
                 self._bos_id = self._tokenizer.token_to_id(bos_token)
             if eos_token is not None:
                 self._eos_id = self._tokenizer.token_to_id(eos_token)
+            if chat_template_str := tok_config.get("chat_template"):
+                self._chat_template = jinja2.Template(chat_template_str)
 
         # If no eos/bos tokens found, go looking for them!
         if None in [self._bos_id, self._eos_id]:
@@ -70,6 +76,8 @@ def _look_for_special_token(added_tokens: dict, search_strs: List[str]) -> Optio
         if len(candidate_toks) == 1:
             return candidate_toks[0]["id"]
 
+    ## Interface ##
+
     def encode(
         self,
         s: str,
@@ -90,3 +98,21 @@ def bos_id(self) -> int:
 
     def eos_id(self) -> int:
         return self._eos_id
+
+    ## Additional Public Methods ##
+
+    def has_chat_template(self) -> bool:
+        return bool(self._chat_template)
+
+    def apply_chat_template(
+        self,
+        dialog: List[Dict[str, str]],
+        add_generation_prompt: bool = False,
+    ) -> str:
+        """If configured with a chat template, apply it to the list of messages
+        """
+        if not self._chat_template:
+            raise ValueError("No chat template configured!")
+        return self._chat_template.render(
+            messages=dialog, add_generation_prompt=add_generation_prompt
+        )
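
A rough usage sketch of the new surface area (assuming the class in this file is named HFTokenizer; the path and prompt are hypothetical):

# Hypothetical usage of the new chat-template support.
from tokenizer.hf_tokenizer import HFTokenizer  # assumed class name

tok = HFTokenizer("/path/to/model/tokenizer")  # loads tokenizer.json, plus tokenizer_config.json if present
if tok.has_chat_template():
    prompt = tok.apply_chat_template(
        [{"role": "user", "content": "Why is the sky blue?"}],
        add_generation_prompt=True,
    )
else:
    prompt = "Why is the sky blue?"
tokens = tok.encode(prompt)

Calling apply_chat_template without a configured template raises ValueError, so callers are expected to check has_chat_template() first, as above.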
