This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 069d393

Commit message: merge
Parents: 5e0b073 + 019f76f

27 files changed: +1272 -1081 lines

docs/quantization.md

Lines changed: 5 additions & 5 deletions
@@ -142,22 +142,22 @@ To use linear:a8wxdq and embedding:wx, you must set up the torchao experimental
 
 From the torchchat root directory, run
 ```
-sh torchchat/utils/scripts/build_torchao_ops.sh
+bash torchchat/utils/scripts/build_torchao_ops.sh
 ```
 
 This should take about 10 seconds to complete.
 
 Note: if you want to use the new kernels in the AOTI and C++ runners, you must pass the flag link_torchao_ops when running the scripts the build the runners.
 
 ```
-sh torchchat/utils/scripts/build_native.sh aoti link_torchao_ops
+bash torchchat/utils/scripts/build_native.sh aoti link_torchao_ops
 ```
 
 ```
-sh torchchat/utils/scripts/build_native.sh et link_torchao_ops
+bash torchchat/utils/scripts/build_native.sh et link_torchao_ops
 ```
 
-Note before running `sh torchchat/utils/scripts/build_native.sh et link_torchao_ops`, you must first install executorch with `sh torchchat/utils/scripts/install_et.sh` if you have not done so already.
+Note before running `bash torchchat/utils/scripts/build_native.sh et link_torchao_ops`, you must first install executorch with `bash torchchat/utils/scripts/install_et.sh` if you have not done so already.
 
 ### Examples
 
@@ -212,7 +212,7 @@ Currently, torchchat can only run them on Eager mode.
 
 From the torchchat root directory, run
 ```
-sh torchchat/utils/scripts/build_torchao_ops.sh mps
+bash torchchat/utils/scripts/build_torchao_ops.sh mps
 ```
 
 ### Examples
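For reference, these build steps feed the quantized-generation examples later in quantization.md. The invocation below is only an illustrative sketch, not text from this commit: the model alias and the exact keys in the `--quantize` JSON config (`bitwidth`, `groupsize`, `has_weight_zeros`) are assumptions about the torchao scheme parameters.

```
# Hypothetical usage sketch (not from this diff): generate with the torchao
# experimental kernels after running build_torchao_ops.sh. The model alias and
# the quantize JSON keys below are assumed values, shown only for illustration.
python3 torchchat.py generate llama3.1 \
  --quantize '{"embedding:wx": {"bitwidth": 4, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' \
  --prompt "Once upon a time,"
```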

install/.pins/torchao-pin.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-7d7c14e898eca3fe66138d2a9445755a9270b800
+2e032c6b0de960dee554dcb08126ace718b14c6d

install/install_requirements.sh

Lines changed: 14 additions & 17 deletions
@@ -44,37 +44,25 @@ fi
 
 echo "Using pip executable: $PIP_EXECUTABLE"
 
-#
-# First install requirements in install/requirements.txt. Older torch may be
-# installed from the dependency of other models. It will be overridden by
-# newer version of torch nightly installed later in this script.
-#
-
-(
-  set -x
-  $PIP_EXECUTABLE install -r install/requirements.txt --extra-index-url https://download.pytorch.org/whl/nightly/cu121
-)
-
 # Since torchchat often uses main-branch features of pytorch, only the nightly
 # pip versions will have the required features. The PYTORCH_NIGHTLY_VERSION value should
 # agree with the third-party/pytorch pinned submodule commit.
 #
 # NOTE: If a newly-fetched version of the executorch repo changes the value of
 # PYTORCH_NIGHTLY_VERSION, you should re-run this script to install the necessary
 # package versions.
-<<<<<<< HEAD
 if [[ -x "$(command -v xpu-smi)" ]];
 then
-  PYTORCH_NIGHTLY_VERSION=dev20241212
+  PYTORCH_NIGHTLY_VERSION=dev20241217
 else
-  PYTORCH_NIGHTLY_VERSION=dev20241213
+  PYTORCH_NIGHTLY_VERSION=dev20241218
 fi
 
 # Nightly version for torchvision
-VISION_NIGHTLY_VERSION=dev20241213
+VISION_NIGHTLY_VERSION=dev20241218
 
 # Nightly version for torchtune
-TUNE_NIGHTLY_VERSION=dev20241126
+TUNE_NIGHTLY_VERSION=dev20241218
 
 # Uninstall triton, as nightly will depend on pytorch-triton, which is one and the same
 (

@@ -99,7 +87,6 @@ else
 fi
 
 # pip packages needed by exir.
-<<<<<<< HEAD
 if [[ -x "$(command -v xpu-smi)" ]];
 then
   REQUIREMENTS_TO_INSTALL=(

@@ -115,6 +102,16 @@ else
   )
 fi
 
+#
+# First install requirements in install/requirements.txt. Older torch may be
+# installed from the dependency of other models. It will be overridden by
+# newer version of torch nightly installed later in this script.
+#
+(
+  set -x
+  $PIP_EXECUTABLE install -r install/requirements.txt --extra-index-url "${TORCH_NIGHTLY_URL}"
+)
+
 # Install the requirements. --extra-index-url tells pip to look for package
 # versions on the provided URL if they aren't available on the default URL.
 (
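Net effect of this change: the plain install of install/requirements.txt now happens after the nightly pins are defined, and it uses ${TORCH_NIGHTLY_URL} rather than a hard-coded cu121 index. As a rough, hedged sketch of how the dev-date variables above are typically consumed further down the script (the *_BASE_VERSION names are hypothetical placeholders, not values from this commit):

```
# Hedged sketch only: assumed shape of the later nightly install step.
# TORCH_BASE_VERSION, VISION_BASE_VERSION and TUNE_BASE_VERSION are made-up
# placeholders standing in for whatever base versions the real script pins.
REQUIREMENTS_TO_INSTALL=(
  "torch==${TORCH_BASE_VERSION}.${PYTORCH_NIGHTLY_VERSION}"
  "torchvision==${VISION_BASE_VERSION}.${VISION_NIGHTLY_VERSION}"
  "torchtune==${TUNE_BASE_VERSION}.${TUNE_NIGHTLY_VERSION}"
)
(
  set -x
  $PIP_EXECUTABLE install --extra-index-url "${TORCH_NIGHTLY_URL}" "${REQUIREMENTS_TO_INSTALL[@]}"
)
```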

install/requirements.txt

Lines changed: 4 additions & 0 deletions
@@ -9,6 +9,10 @@ gguf
 # Tiktoken tokenizer for Llama 3 and other advanced models
 tiktoken
 
+# Tokenizers and jinja2 for other non-llama models that use HF tokenizers
+tokenizers
+jinja2
+
 # Miscellaneous
 snakeviz
 sentencepiece

tests/conftest.py

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+"""
+Global pytest config, fixtures, and helpers go here!
+"""
+
+# Standard
+import os
+import sys
+
+# Make sure tests can import torchchat
+sys.path.append(
+    os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))
+)

tests/test_chat_formatters.py

Lines changed: 216 additions & 0 deletions
@@ -0,0 +1,216 @@
+"""
+Unit tests for chat formatters
+"""
+
+# Third Party
+import pytest
+
+# Local
+from torchchat.generate import (
+    HFTokenizerChatFormatter,
+    Llama2ChatFormatter,
+    Llama3ChatFormatter,
+)
+
+## Helpers #####################################################################
+
+class DummyTokenizer:
+    """Dummy tokenizer that encodes as strings so it's easy to check formatting"""
+    def encode(self, text, *_, **__):
+        return text
+
+
+class DummySPTokenizer(DummyTokenizer):
+    """Emulated Sentencepiece tokenizer with bos/eos"""
+    bos = "<s>"
+    eos = "</s>"
+
+
+class DummyLlama3Tokenizer(DummyTokenizer):
+    class _IdentityDict:
+        def __getitem__(self, key):
+            return key
+    special_tokens = _IdentityDict()
+
+
+class DummyHFTokenizer(DummyTokenizer):
+    """Dummy made up chat template scheme"""
+    # Sequence
+    bos = "<bos>"
+    # Turn
+    bot = "<bot>"
+    eot = "<eot>"
+    # Role
+    bor = "<bor>"
+    eor = "<eor>"
+    def apply_chat_template(self, messages, add_generation_prompt):
+        out = [self.bos]
+        role = None
+        for msg in messages:
+            role = msg["role"]
+            content = msg["content"]
+            out.append(f"{self.bot}{self.bor}{role}{self.eor}{content}{self.eot}")
+        if add_generation_prompt and role != "assistant":
+            out.append(f"{self.bot}{self.bor}assistant{self.eor}")
+        return "\n".join(out)
+
+
+def check_rendering(fmt, messages, expected, add_generation_prompt):
+    """Render messages and compare to expected output"""
+    assert "".join(fmt.encode_dialog_prompt(messages, add_generation_prompt)) == expected
+
+
+def make_message(role, text):
+    return {"role": role, "content": text}
+
+
+SYSTEM_PROMPT = "You are a helpful assistant, feel free to ask me anything."
+USER1 = "Hello world!"
+ASSISTANT1 = "Greetings! How can I help you?"
+USER2 = "Why is the sky blue?"
+ASSISTANT2 = "The sky appears blue because of a phenomenon called Rayleigh scattering."
+
+
+# Stock sets of messages to test
+MSGS_NO_SYS= [
+    make_message("user", USER1),
+]
+MSGS_SYS_USR = [
+    make_message("system", SYSTEM_PROMPT),
+    make_message("user", USER1),
+]
+MSGS_SYS_USR_ASST = [
+    make_message("system", SYSTEM_PROMPT),
+    make_message("user", USER1),
+    make_message("assistant", ASSISTANT1),
+]
+MSGS_MULTI_TURN = [
+    make_message("system", SYSTEM_PROMPT),
+    make_message("user", USER1),
+    make_message("assistant", ASSISTANT1),
+    make_message("user", USER2),
+    make_message("assistant", ASSISTANT2),
+]
+
+## Llama2ChatFormatter #########################################################
+
+@pytest.mark.parametrize(
+    ["messages", "expected"],
+    [
+        # single user message (no system prompt)
+        (MSGS_NO_SYS, f"<s>[INST] {USER1} [/INST]"),
+        # sys, usr
+        (MSGS_SYS_USR, f"""<s>[INST] <<SYS>>
+{SYSTEM_PROMPT}
+<</SYS>>
+
+{USER1} [/INST]"""),
+        # sys, usr, asst
+        (MSGS_SYS_USR_ASST, f"""<s>[INST] <<SYS>>
+{SYSTEM_PROMPT}
+<</SYS>>
+
+{USER1} [/INST] {ASSISTANT1} </s>
+"""),
+        # sys, usr, asst, usr, asst
+        (MSGS_MULTI_TURN, f"""<s>[INST] <<SYS>>
+{SYSTEM_PROMPT}
+<</SYS>>
+
+{USER1} [/INST] {ASSISTANT1} </s>
+<s>[INST] {USER2} [/INST] {ASSISTANT2} </s>
+"""),
+    ]
+)
+def test_llama2_chat_formatter(messages, expected):
+    """Tests for Llama2 following the official guide
+    https://www.llama.com/docs/model-cards-and-prompt-formats/meta-llama-2/
+    """
+    tok = DummySPTokenizer()
+    fmt = Llama2ChatFormatter(tok)
+    # NOTE: add_generation_prompt not used by Llama2
+    check_rendering(fmt, messages, expected, True)
+
+## Llama3ChatFormatter #########################################################
+
+@pytest.mark.parametrize(
+    ["messages", "expected"],
+    [
+        # single user message (no system prompt)
+        (MSGS_NO_SYS, f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>
+
+{USER1}<|eot_id|>"""),
+        # sys, usr
+        (MSGS_SYS_USR, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+
+{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+{USER1}<|eot_id|>"""),
+        # sys, usr, asst
+        (MSGS_SYS_USR_ASST, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+
+{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+{USER1}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+{ASSISTANT1}<|eot_id|>"""),
+        # sys, usr, asst, usr, asst
+        (MSGS_MULTI_TURN, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+
+{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+{USER1}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+{ASSISTANT1}<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+{USER2}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+{ASSISTANT2}<|eot_id|>"""),
+    ]
+)
+@pytest.mark.parametrize("add_generation_prompt", [True, False])
+def test_llama3_chat_formatter(messages, expected, add_generation_prompt):
+    """Tests for Llama3 following the official guide
+    https://www.llama.com/docs/model-cards-and-prompt-formats/meta-llama-3/
+    """
+    tok = DummyLlama3Tokenizer()
+    fmt = Llama3ChatFormatter(tok)
+    # No assistant prompt added if the last message is from the assistant
+    if add_generation_prompt and messages[-1]["role"] != "assistant":
+        expected += "<|start_header_id|>assistant<|end_header_id|>\n\n"
+    check_rendering(fmt, messages, expected, add_generation_prompt)
+
+## HFTokenizerChatFormatter ####################################################
+
+@pytest.mark.parametrize(
+    ["messages", "expected"],
+    [
+        # single user message (no system prompt)
+        (MSGS_NO_SYS, f"""<bos>
+<bot><bor>user<eor>{USER1}<eot>"""),
+        # sys, usr
+        (MSGS_SYS_USR, f"""<bos>
+<bot><bor>system<eor>{SYSTEM_PROMPT}<eot>
+<bot><bor>user<eor>{USER1}<eot>"""),
+        # sys, usr, asst
+        (MSGS_SYS_USR_ASST, f"""<bos>
+<bot><bor>system<eor>{SYSTEM_PROMPT}<eot>
+<bot><bor>user<eor>{USER1}<eot>
+<bot><bor>assistant<eor>{ASSISTANT1}<eot>"""),
+        # sys, usr, asst, usr, asst
+        (MSGS_MULTI_TURN, f"""<bos>
+<bot><bor>system<eor>{SYSTEM_PROMPT}<eot>
+<bot><bor>user<eor>{USER1}<eot>
+<bot><bor>assistant<eor>{ASSISTANT1}<eot>
+<bot><bor>user<eor>{USER2}<eot>
+<bot><bor>assistant<eor>{ASSISTANT2}<eot>"""),
+    ]
+)
+@pytest.mark.parametrize("add_generation_prompt", [True, False])
+def test_hf_chat_formatter(messages, expected, add_generation_prompt):
+    tok = DummyHFTokenizer()
+    fmt = HFTokenizerChatFormatter(tok)
+    # No assistant prompt added if the last message is from the assistant
+    if add_generation_prompt and messages[-1]["role"] != "assistant":
+        expected += f"\n{tok.bot}{tok.bor}assistant{tok.eor}"
+    check_rendering(fmt, messages, expected, add_generation_prompt)
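The new tests/conftest.py puts the repository root on sys.path, so these tests can be run directly from a torchchat checkout with a standard pytest invocation, for example:

```
# Run the new chat-formatter tests from the torchchat root directory
python -m pytest tests/test_chat_formatters.py -v
```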
