
Commit ee35cd4

Merge branch 'main' into newtemplate

2 parents 23f3c49 + 083fdaf

File tree: 6 files changed, +105 −21 lines

.github/workflows/more-tests.yml

Lines changed: 67 additions & 2 deletions
@@ -19,6 +19,7 @@ jobs:
       gpu-arch-version: "12.4"
       timeout: 60
       script: |
+        set -xeou pipefail
         echo "::group::Print machine info"
         uname -a
         echo "::endgroup::"
@@ -39,9 +40,10 @@ jobs:
         echo "::endgroup::"

         echo "::group::Run inference"
-        export MODEL_PATH=checkpoints/stories15M/stories15M.pt
+        export MODEL_DIR=checkpoints/stories15M/
+        export MODEL_PATH=${MODEL_DIR}/stories15M.pt
         export MODEL_NAME=stories15M
-        export MODEL_DIR=/tmp
+

         for DTYPE in bfloat16 float16 float32; do
           ###################################################################
@@ -83,3 +85,66 @@ jobs:
         echo "tests complete"
         echo "******************************************"
         echo "::endgroup::"
+
+
+  test-sdpa-backends-export:
+    permissions:
+      id-token: write
+      contents: read
+    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    with:
+      runner: linux.g5.4xlarge.nvidia.gpu
+      gpu-arch-type: cuda
+      gpu-arch-version: "12.4"
+      timeout: 60
+      script: |
+        set -xeou pipefail
+        echo "::group::Print machine info"
+        uname -a
+        echo "::endgroup::"
+
+        echo "::group::Install requirements"
+        # Install requirements
+        ./install/install_requirements.sh cuda
+        pip3 list
+        python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
+        echo "::endgroup::"
+
+        echo "::group::Download checkpoints"
+        mkdir -p checkpoints/stories15M
+        pushd checkpoints/stories15M
+        wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt
+        wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
+        popd
+        echo "::endgroup::"
+
+        echo "::group::Run inference"
+        export MODEL_DIR=checkpoints/stories15M/
+        export MODEL_PATH=${MODEL_DIR}/stories15M.pt
+        export MODEL_NAME=stories15M
+
+        ./torchchat/utils/scripts/build_native.sh aoti
+
+        for DEVICE in cpu cuda; do
+          # depending on how the parameter passing works, we may only be able to use bfloat16 for aoti_run, similar to runner-cuda-dtype.yml
+          # (although the runner environment should not have an opinion about what we use in the artifact, and we might suitably abstract that)
+          for DTYPE in bfloat16 float16 float32; do
+            for SDPA in 'math' 'flash_attention' 'efficient_attention' 'cudnn_attention'; do
+              echo "***************************************************************"
+              echo "*** $DEVICE $DTYPE $SDPA"
+              ###################################################################
+              # Export DSO and run with Python
+              python torchchat.py export --output-dso dso.so --checkpoint-path ${MODEL_PATH} --attention-backend ${SDPA} --device ${DEVICE} --dtype ${DTYPE}
+              python torchchat.py generate --dso-path dso.so --checkpoint-path ${MODEL_PATH} --attention-backend ${SDPA} --device ${DEVICE} --dtype ${DTYPE} --temperature 0 --prompt "Once upon a time"
+              ###################################################################
+              # Export AOTI and run with aoti_run
+              python torchchat.py export --output-aoti /tmp/model.pt2 --checkpoint-path ${MODEL_PATH} --attention-backend ${SDPA} --device ${DEVICE} --dtype ${DTYPE}
+              ./cmake-out/aoti_run /tmp/model.pt2 -z ${MODEL_DIR}/tokenizer.model -i "Once upon a time"
+              ###################################################################
+            done
+          done
+        done
+
+        echo "tests complete"
+        echo "******************************************"
+        echo "::endgroup::"
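
For reference, the new test-sdpa-backends-export job sweeps every SDPA backend across two devices and three dtypes. Below is a minimal sketch (not part of this commit) of what that sweep exercises at the PyTorch level, assuming a CUDA machine like the linux.g5.4xlarge.nvidia.gpu runner; the string-to-backend mapping mirrors the workflow's loop values:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Mirrors the workflow's 'math'/'flash_attention'/'efficient_attention'/
# 'cudnn_attention' loop values.
BACKENDS = {
    "math": SDPBackend.MATH,
    "flash_attention": SDPBackend.FLASH_ATTENTION,
    "efficient_attention": SDPBackend.EFFICIENT_ATTENTION,
    "cudnn_attention": SDPBackend.CUDNN_ATTENTION,
}

q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
for name, backend in BACKENDS.items():
    # Inside the context only the listed backend may be dispatched;
    # if it cannot handle the inputs, SDPA raises instead of falling back.
    with sdpa_kernel([backend]):
        out = F.scaled_dot_product_attention(q, k, v)
    print(name, tuple(out.shape))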

README.md

Lines changed: 6 additions & 1 deletion
@@ -3,7 +3,11 @@
 torchchat is a small codebase showcasing the ability to run large language models (LLMs) seamlessly. With torchchat, you can run LLMs using Python, within your own (C/C++) application (desktop or server) and on iOS and Android.

 > [!IMPORTANT]
-> Update September 25, 2024: torchchat has multimodal support for **Llama3.2 11B**!!
+> Update
+>
+> **February 3, 2025**: torchchat has support for [**DeepSeek R1 Distill: 8B**](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-8B)!
+>
+> **September 25, 2024**: torchchat has multimodal support for **Llama3.2 11B**!
 >
 > To try it out, finish the [Installation](#Installation) section below, then hop
 > over to our [multimodal guide](docs/multimodal.md) to learn more.
@@ -75,6 +79,7 @@ aliases.
 | [ibm-granite/granite-3.0-8b-instruct](https://huggingface.co/ibm-granite/granite-3.0-8b-instruct) || Alias to `granite3-8b`.|
 | [ibm-granite/granite-3.1-2b-instruct](https://huggingface.co/ibm-granite/granite-3.1-2b-instruct) || Alias to `granite3.1-2b` and `granite3.1`.|
 | [ibm-granite/granite-3.1-8b-instruct](https://huggingface.co/ibm-granite/granite-3.1-8b-instruct) || Alias to `granite3.1-8b`.|
+| [deepseek-ai/DeepSeek-R1-Distill-Llama-8B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-8B) || Alias to `deepseek-r1:8b`.|


 ## Installation
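
The new table row pairs with the models.json entry added further down in this commit. A hypothetical lookup helper (torchchat's real resolution logic lives in its model-config code, not shown here) illustrating how an alias such as deepseek-r1:8b resolves to a HuggingFace distribution path:

import json

with open("torchchat/model_config/models.json") as f:
    configs = json.load(f)

def resolve(name: str) -> dict:
    # Exact key match first, then search each entry's alias list.
    if name in configs:
        return configs[name]
    for cfg in configs.values():
        if name in cfg.get("aliases", []):
            return cfg
    raise KeyError(f"unknown model or alias: {name}")

print(resolve("deepseek-r1:8b")["distribution_path"])
# deepseek-ai/DeepSeek-R1-Distill-Llama-8B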

tokenizer/hf_tokenizer.py

Lines changed: 8 additions & 2 deletions
@@ -46,8 +46,14 @@ def __init__(self, file_path: str):
         if tokenizer_config_path is not None:
             with open(tokenizer_config_path, "r") as handle:
                 tok_config = json.load(handle)
-            bos_token = tok_config.get("bos_token")
-            eos_token = tok_config.get("eos_token")
+
+            def _extract_token(identifier: str) -> Optional[str]:
+                entry: Optional[Union[str, dict]] = tok_config.get(identifier)
+                return entry.get("content") if isinstance(entry, dict) else entry
+
+            bos_token = _extract_token("bos_token")
+            eos_token = _extract_token("eos_token")
+
             if bos_token is not None:
                 self._bos_id = self._tokenizer.token_to_id(bos_token)
             if eos_token is not None:
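
Context for the change above: in newer Hugging Face tokenizer_config.json files, bos_token and eos_token may be AddedToken-style dicts rather than plain strings. A standalone sketch of the normalization the new closure performs (here taking tok_config as a parameter instead of closing over it):

import json
from typing import Optional, Union

def _extract_token(tok_config: dict, identifier: str) -> Optional[str]:
    # Accept both the legacy plain-string form and the dict form,
    # returning the token text or None if the key is absent.
    entry: Optional[Union[str, dict]] = tok_config.get(identifier)
    return entry.get("content") if isinstance(entry, dict) else entry

legacy = json.loads('{"bos_token": "<s>", "eos_token": "</s>"}')
modern = json.loads('{"bos_token": {"content": "<|begin_of_text|>", "special": true}}')

assert _extract_token(legacy, "bos_token") == "<s>"
assert _extract_token(modern, "bos_token") == "<|begin_of_text|>"
assert _extract_token(modern, "eos_token") is None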

torchchat/export.py

Lines changed: 17 additions & 16 deletions
@@ -490,13 +490,14 @@ def main(args):
         print(
             "WARNING!! The path of compiling a dso is deprecated. Please use --output-aoti-package-path to create a .pt2 artifact instead."
         )
-        export_for_server(
-            model_to_dso,
-            builder_args.device,
-            output_dso_path,
-            builder_args.dynamic_shapes,
-            package=False,
-        )
+        with torch.nn.attention.sdpa_kernel([builder_args.attention_backend]):
+            export_for_server(
+                model_to_dso,
+                builder_args.device,
+                output_dso_path,
+                builder_args.dynamic_shapes,
+                package=False,
+            )

     if output_aoti_package_path:
         output_aoti_package_path = str(os.path.abspath(output_aoti_package_path))
@@ -512,14 +513,15 @@ def main(args):
         print(
             "Exporting model using AOT Inductor to " f"{output_aoti_package_path}."
         )
-        export_for_server(
-            model_to_aoti_package,
-            builder_args.device,
-            output_aoti_package_path,
-            builder_args.dynamic_shapes,
-            package=True,
-            metadata=metadata,
-        )
+        with torch.nn.attention.sdpa_kernel([builder_args.attention_backend]):
+            export_for_server(
+                model_to_aoti_package,
+                builder_args.device,
+                output_aoti_package_path,
+                builder_args.dynamic_shapes,
+                package=True,
+                metadata=metadata,
+            )

     if output_snapshot_path:
         output_snapshot_path = str(os.path.abspath(output_snapshot_path))
@@ -529,4 +531,3 @@ def main(args):
             builder_args.device,
             output_snapshot_path,
         )
-
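
A minimal repro (again not from this commit) of the pattern the change introduces: running export inside torch.nn.attention.sdpa_kernel so the backend restriction from --attention-backend is in effect while the model is traced and compiled. TinyAttn and the shapes are made up for illustration:

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

class TinyAttn(torch.nn.Module):
    def forward(self, q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)

q = k = v = torch.randn(1, 4, 16, 8)
# Analogous to builder_args.attention_backend wrapping export_for_server.
with sdpa_kernel([SDPBackend.MATH]):
    ep = torch.export.export(TinyAttn(), (q, k, v))
print(ep.graph_module.code)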

torchchat/model_config/models.json

Lines changed: 6 additions & 0 deletions
@@ -51,6 +51,12 @@
     "distribution_path": "meta-llama/Meta-Llama-3.1-8B-Instruct",
     "transformer_params_key": "Meta-Llama-3.1-8B"
   },
+  "deepseek-ai/DeepSeek-R1-Distill-Llama-8B": {
+    "aliases": ["deepseek-r1:8b"],
+    "distribution_channel": "HuggingFaceSnapshot",
+    "distribution_path": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
+    "tokenizer_file": "tokenizer.json"
+  },
   "meta-llama/Meta-Llama-3.1-70B-Instruct": {
     "aliases": ["llama3.1-70b"],
     "distribution_channel": "HuggingFaceSnapshot",
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+{"block_size": 131072, "dim": 4096, "ffn_dim_multiplier": 1.3, "multiple_of": 1024, "n_heads": 32, "n_local_heads": 8, "n_layers": 32, "rope_base": 500000.0, "vocab_size": 128256, "use_tiktoken": true, "use_hf_tokenizer": true, "norm_eps": 1e-05, "rope_scaling": {"factor": 8.0, "low_freq_factor": 1.0, "high_freq_factor": 4.0, "original_max_position_embeddings": 8192}}
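
The new parameter file carries Llama 3.1-style rope_scaling values (factor 8.0, low/high frequency factors 1.0/4.0, original context 8192), as expected for a Llama-3.1-8B distill. For reference, a sketch of the standard Llama 3.1 frequency-scaling rule these four numbers drive, assuming torchchat follows the reference implementation:

import math
import torch

def apply_rope_scaling(
    freqs: torch.Tensor,
    factor: float = 8.0,
    low_freq_factor: float = 1.0,
    high_freq_factor: float = 4.0,
    original_max_position_embeddings: int = 8192,
) -> torch.Tensor:
    # Long wavelengths are divided by `factor`, short ones are kept as-is,
    # and the band in between is smoothly interpolated.
    low_freq_wavelen = original_max_position_embeddings / low_freq_factor
    high_freq_wavelen = original_max_position_embeddings / high_freq_factor
    out = []
    for freq in freqs.tolist():
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            out.append(freq)
        elif wavelen > low_freq_wavelen:
            out.append(freq / factor)
        else:
            smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            out.append((1 - smooth) * freq / factor + smooth * freq)
    return torch.tensor(out, dtype=freqs.dtype)

# Base frequencies from the file's values: rope_base 500000.0,
# head dim = dim / n_heads = 4096 / 32 = 128.
head_dim = 4096 // 32
freqs = 1.0 / (500000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
print(apply_rope_scaling(freqs)[:4])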
