
Commit 0f28976

Merge branch 'main' into lessw2020/prefill
2 parents: 1ea7960 + 9c47edc

File tree: 7 files changed, +79 −45 lines

README.md

Lines changed: 15 additions & 16 deletions

````diff
@@ -10,7 +10,7 @@ torchchat is a small codebase showcasing the ability to run large language model
 - [Run chat in the Browser](#browser)
 - [Run models on desktop/server without python](#desktopserver-execution)
 - [Use AOT Inductor for faster execution](#aoti-aot-inductor)
-- [Running in c++ using the runner](#running-native-using-our-c-runner)
+- [Running in c++ using the runner](#run-using-our-c-runner)
 - [Run models on mobile](#mobile-execution)
 - [Deploy and run on iOS](#deploy-and-run-on-ios)
 - [Deploy and run on Android](#deploy-and-run-on-android)
@@ -33,7 +33,8 @@ torchchat is a small codebase showcasing the ability to run large language model
 ## Installation
 The following steps require that you have [Python 3.10](https://www.python.org/downloads/release/python-3100/) installed.

-*torchchat uses the latest changes from various PyTorch projects so it's highly recommended that you use a venv (by using the commands below) or CONDA.*
+> [!TIP]
+> torchchat uses the latest changes from various PyTorch projects so it's highly recommended that you use a venv (by using the commands below) or CONDA.

 [skip default]: begin
 ```bash
@@ -127,21 +128,21 @@ python3 torchchat.py download llama3.1
 <summary>Additional Model Inventory Management Commands</summary>

 ### List
-This subcommands shows the available models
+This subcommand shows the available models
 ```bash
 python3 torchchat.py list
 ```

 ### Where
-This subcommands shows location of a particular model.
+This subcommand shows location of a particular model.
 ```bash
 python3 torchchat.py where llama3.1
 ```
 This is useful in scripts when you do not want to hard-code paths


 ### Remove
-This subcommands removes the specified model
+This subcommand removes the specified model
 ```bash
 python3 torchchat.py remove llama3.1
 ```
@@ -181,18 +182,10 @@ python3 torchchat.py generate llama3.1 --prompt "write me a story about a boy an
 [skip default]: end

 ### Server
-**Note: This feature is still a work in progress and not all endpoints are working**
-
-
-<details>
-<summary>This mode gives a REST API that matches the OpenAI API spec for interacting with a model</summary>
-
+This mode exposes a REST API for interacting with a model.
 The server follows the [OpenAI API specification](https://platform.openai.com/docs/api-reference/chat) for chat completions.
-Since this feature is under active development, not every parameter is consumed. See api/api.py for details on
-which request parameters are implemented. If you encounter any issues, please comment on the [tracking Github issue](https://github.com/pytorch/torchchat/issues/973).

 To test out the REST API, **you'll need 2 terminals**: one to host the server, and one to send the request.
-
 In one terminal, start the server

 [skip default]: begin
@@ -204,8 +197,14 @@ python3 torchchat.py server llama3.1

 In another terminal, query the server using `curl`. Depending on the model configuration, this query might take a few minutes to respond.

-Setting `stream` to "true" in the request emits a response in chunks. If `stream` is unset or not "true", then the client will await the full response from the server.
+> [!NOTE]
+> Since this feature is under active development, not every parameter is consumed. See api/api.py for details on
+> which request parameters are implemented. If you encounter any issues, please comment on the [tracking Github issue](https://github.com/pytorch/torchchat/issues/973).

+<details>
+<summary>Example Query</summary>
+
+Setting `stream` to "true" in the request emits a response in chunks. If `stream` is unset or not "true", then the client will await the full response from the server.

 **Example Input + Output**

@@ -348,7 +347,7 @@ Specifically there are 2 ways of doing so: Pure Python and via a Runner

 ```
 # Execute
-python3 torchchat.py generate llama3.1 --device cpu --pte-path llama3.1.pte --prompt "Hello my name is"
+python3 torchchat.py generate llama3.1 --pte-path llama3.1.pte --prompt "Hello my name is"
 ```

 </details>
````
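The README's server walkthrough centers on `curl`, but the same request can be issued from any HTTP client. Below is a minimal sketch, not part of this commit: the port (5000 here) and the OpenAI-style `/v1/chat/completions` route are assumptions to verify against the README's own Example Input + Output section.

```python
# Minimal sketch, not from the commit: query the torchchat server with only
# the Python standard library. Port 5000 and the exact route are assumptions.
import json
import urllib.request

payload = {
    "model": "llama3.1",
    "messages": [{"role": "user", "content": "Hello my name is"}],
    # Setting "stream": "true" would emit the response in chunks, per the README.
}
req = urllib.request.Request(
    "http://127.0.0.1:5000/v1/chat/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read()))
```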

install/install_requirements.sh

Lines changed: 43 additions & 11 deletions

```diff
@@ -41,13 +41,23 @@ fi
 )

 # Since torchchat often uses main-branch features of pytorch, only the nightly
-# pip versions will have the required features. The NIGHTLY_VERSION value should
+# pip versions will have the required features. The PYTORCH_NIGHTLY_VERSION value should
 # agree with the third-party/pytorch pinned submodule commit.
 #
 # NOTE: If a newly-fetched version of the executorch repo changes the value of
-# NIGHTLY_VERSION, you should re-run this script to install the necessary
+# PYTORCH_NIGHTLY_VERSION, you should re-run this script to install the necessary
 # package versions.
-NIGHTLY_VERSION=dev20240814
+PYTORCH_NIGHTLY_VERSION=dev20240814
+
+# Nightly version for torchvision
+VISION_NIGHTLY_VERSION=dev20240814
+
+# Nightly version for torchao
+AO_NIGHTLY_VERSION=dev20240905
+
+# Nightly version for torchtune
+TUNE_NIGHTLY_VERSION=dev20240910
+

 # Uninstall triton, as nightly will depend on pytorch-triton, which is one and the same
 (
@@ -67,23 +77,45 @@ fi

 # pip packages needed by exir.
 REQUIREMENTS_TO_INSTALL=(
-  torch=="2.5.0.${NIGHTLY_VERSION}"
+  torch=="2.5.0.${PYTORCH_NIGHTLY_VERSION}"
+  torchvision=="0.20.0.${VISION_NIGHTLY_VERSION}"
+)
+
+LINUX_REQUIREMENTS_TO_INSTALL=(
+  torchao=="0.5.0.${AO_NIGHTLY_VERSION}"
+  torchtune=="0.3.0.${TUNE_NIGHTLY_VERSION}"
 )

-# Install the requirements. `--extra-index-url` tells pip to look for package
+# Install the requirements. --extra-index-url tells pip to look for package
 # versions on the provided URL if they aren't available on the default URL.
 (
   set -x
   $PIP_EXECUTABLE install --extra-index-url "${TORCH_NIGHTLY_URL}" \
     "${REQUIREMENTS_TO_INSTALL[@]}"
 )

-# For torchao need to install from github since nightly build doesn't have macos build.
-# TODO: Remove this and install nightly build, once it supports macos
-(
-  set -x
-  $PIP_EXECUTABLE install git+https://github.com/pytorch/ao.git@e11201a62669f582d81cdb33e031a07fb8dfc4f3
-)
+PLATFORM=$(uname -s)
+
+# Install torchtune and torchao requirements for Linux systems using nightly.
+# For non-Linux systems (e.g., macOS), install torchao from GitHub since nightly
+# build doesn't have macOS build.
+# TODO: Remove this and install nightly build, once it supports macOS
+if [ "$PLATFORM" == "Linux" ];
+then
+  (
+    set -x
+    $PIP_EXECUTABLE install --pre --extra-index-url "${TORCH_NIGHTLY_URL}" --no-cache-dir \
+      "${LINUX_REQUIREMENTS_TO_INSTALL[@]}"
+  )
+else
+  # For torchao need to install from github since nightly build doesn't have macos build.
+  # TODO: Remove this and install nightly build, once it supports macos
+  (
+    set -x
+    $PIP_EXECUTABLE install git+https://github.com/pytorch/ao.git@e11201a62669f582d81cdb33e031a07fb8dfc4f3
+  )
+fi

 if [[ -x "$(command -v nvidia-smi)" ]]; then
   (
     set -x
```
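Since the script now pins four separate nightly versions, a quick way to confirm an environment actually picked them up is to compare installed wheel versions against the pins. The helper below is hypothetical, not part of the repo; the expected version strings are copied from the variables introduced in this diff (on macOS, torchao comes from a GitHub commit instead, so its version will not match a nightly pin).

```python
# Hypothetical helper, not part of the repo: check installed packages against
# the nightly pins introduced in install_requirements.sh.
import importlib.metadata
import platform

expected = {
    "torch": "2.5.0.dev20240814",
    "torchvision": "0.20.0.dev20240814",
}
if platform.system() == "Linux":
    # torchao/torchtune nightlies are only installed on Linux in this script.
    expected["torchao"] = "0.5.0.dev20240905"
    expected["torchtune"] = "0.3.0.dev20240910"

for pkg, want in expected.items():
    try:
        have = importlib.metadata.version(pkg)
    except importlib.metadata.PackageNotFoundError:
        print(f"{pkg}: not installed (expected {want})")
        continue
    # Nightly wheels may carry a local suffix such as +cu121, so prefix-match.
    ok = have.startswith(want)
    print(f"{pkg}: {have} {'OK' if ok else f'(expected {want})'}")
```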

torchchat/model.py

Lines changed: 17 additions & 14 deletions

```diff
@@ -10,7 +10,8 @@
 from dataclasses import dataclass
 from enum import Enum
 from pathlib import Path
-from typing import Callable, Dict, Optional, Union
+
+from typing import Any, Callable, Dict, Optional, Union
 from abc import ABC, abstractmethod

 import torch
@@ -132,7 +133,7 @@ class TransformerArgs:
     ffn_dim_multiplier: Optional[int] = None
     use_tiktoken: bool = False
     max_seq_length: int = 8192
-    use_scaled_rope: bool = False
+    rope_scaling: Optional[Dict[str, Any]] = None
     # For pipeline parallel
     n_stages: int = 1
     stage_idx: int = 0
@@ -418,8 +419,6 @@ def __init__(self, config: TransformerArgs) -> None:
         self.norm = None
         self.output = None

-        # self.freqs_cis: Optional[Tensor] = None
-        # self.mask_cache: Optional[Tensor] = None
         self.max_batch_size = -1
         self.max_seq_length = -1
         # For supporting sequence parallel (default is off, thus value of 1)
@@ -444,7 +443,7 @@ def setup_caches(self, max_batch_size, max_seq_length):
             self.config.dim // self.config.n_heads,
             self.config.block_size * 2,
             self.config.rope_base,
-            use_scaled=self.config.use_scaled_rope,
+            rope_scaling=self.config.rope_scaling,
         )
         self.register_buffer("freqs_cis", freqs_cis, persistent=True)
         causal_mask = torch.tril(
@@ -681,12 +680,16 @@ def forward(self, x: Tensor) -> Tensor:
         return output * self.weight


-def apply_scaling(freqs: torch.Tensor):
-    # Values obtained from grid search
-    scale_factor = 8
-    low_freq_factor = 1
-    high_freq_factor = 4
-    old_context_len = 8192  # original llama3 length
+def apply_scaling(freqs: torch.Tensor, rope_scaling: Dict[str, Any]):
+    # Check for the presence of the required keys
+    required_keys = {"factor", "low_freq_factor", "high_freq_factor", "original_max_position_embeddings"}
+    if not required_keys.issubset(rope_scaling.keys()):
+        raise ValueError(f"Missing required keys in apply_scaling. Expected: {required_keys}")
+
+    scale_factor = rope_scaling["factor"]
+    low_freq_factor = rope_scaling["low_freq_factor"]
+    high_freq_factor = rope_scaling["high_freq_factor"]
+    old_context_len = rope_scaling["original_max_position_embeddings"]

     low_freq_wavelen = old_context_len / low_freq_factor
     high_freq_wavelen = old_context_len / high_freq_factor
@@ -707,16 +710,16 @@ def apply_scaling(freqs: torch.Tensor):


 def precompute_freqs_cis(
-    n_elem: int, seq_len: int, base: int = 10000, dtype=None, use_scaled: bool = False
+    n_elem: int, seq_len: int, base: int = 10000, dtype=None, rope_scaling: Optional[Dict[str, Any]] = None
 ) -> Tensor:
     if not dtype:
         dtype = get_precision()
     freqs = 1.0 / (
         base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
     )
     t = torch.arange(seq_len, device=freqs.device)
-    if use_scaled:
-        freqs = apply_scaling(freqs)
+    if rope_scaling is not None:
+        freqs = apply_scaling(freqs, rope_scaling)
     freqs = torch.outer(t, freqs)
     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
     cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
```
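The hunks above show `apply_scaling`'s new signature and parameter plumbing, but the diff context elides the interpolation body that consumes `low_freq_wavelen` and `high_freq_wavelen`. For readers who want the full picture, here is a self-contained sketch of the llama3.1-style scaling rule those values drive, under the assumption that the elided body follows the standard formulation; this is a reconstruction, not a copy of the file.

```python
# Sketch (assumed to match the elided body): llama3.1-style RoPE frequency
# scaling. Parameter names mirror the rope_scaling dict in the diff above.
import math
import torch

def apply_scaling_sketch(freqs: torch.Tensor, rope_scaling: dict) -> torch.Tensor:
    scale_factor = rope_scaling["factor"]                                # e.g. 8.0
    low_freq_factor = rope_scaling["low_freq_factor"]                    # e.g. 1.0
    high_freq_factor = rope_scaling["high_freq_factor"]                  # e.g. 4.0
    old_context_len = rope_scaling["original_max_position_embeddings"]   # e.g. 8192

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    new_freqs = []
    for freq in freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            # High-frequency band: leave untouched.
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            # Low-frequency band: scale down uniformly.
            new_freqs.append(freq / scale_factor)
        else:
            # In between: interpolate smoothly between the two regimes.
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
    return torch.stack(new_freqs).to(freqs.dtype)
```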
Model configuration JSON files — 1 addition & 1 deletion each

```diff
@@ -1 +1 @@
-{"dim": 8192, "ffn_dim_multiplier": 1.3, "multiple_of": 4096, "n_heads": 64, "n_local_heads": 8, "n_layers": 80, "rope_base": 500000.0, "vocab_size": 128256, "use_tiktoken": true}
+{"block_size": 8192, "dim": 8192, "ffn_dim_multiplier": 1.3, "multiple_of": 4096, "n_heads": 64, "n_local_heads": 8, "n_layers": 80, "rope_base": 500000.0, "vocab_size": 128256, "use_tiktoken": true}
```

```diff
@@ -1 +1 @@
-{"dim": 4096, "ffn_dim_multiplier": 1.3, "multiple_of": 1024, "n_heads": 32, "n_local_heads": 8, "n_layers": 32, "rope_base": 500000.0, "vocab_size": 128256, "use_tiktoken": true}
+{"block_size": 8192, "dim": 4096, "ffn_dim_multiplier": 1.3, "multiple_of": 1024, "n_heads": 32, "n_local_heads": 8, "n_layers": 32, "rope_base": 500000.0, "vocab_size": 128256, "use_tiktoken": true}
```

```diff
@@ -1 +1 @@
-{"dim": 8192, "ffn_dim_multiplier": 1.3, "multiple_of": 4096, "n_heads": 64, "n_local_heads": 8, "n_layers": 80, "rope_base": 500000.0, "vocab_size": 128256, "use_tiktoken": true, "norm_eps": 1e-05, "use_scaled_rope": true}
+{"block_size": 131072, "dim": 8192, "ffn_dim_multiplier": 1.3, "multiple_of": 4096, "n_heads": 64, "n_local_heads": 8, "n_layers": 80, "rope_base": 500000.0, "vocab_size": 128256, "use_tiktoken": true, "norm_eps": 1e-05, "rope_scaling": {"factor": 8.0, "low_freq_factor": 1.0, "high_freq_factor": 4.0, "original_max_position_embeddings": 8192}}
```

```diff
@@ -1 +1 @@
-{"dim": 4096, "ffn_dim_multiplier": 1.3, "multiple_of": 1024, "n_heads": 32, "n_local_heads": 8, "n_layers": 32, "rope_base": 500000.0, "vocab_size": 128256, "use_tiktoken": true, "norm_eps": 1e-05, "use_scaled_rope": true}
+{"block_size": 131072, "dim": 4096, "ffn_dim_multiplier": 1.3, "multiple_of": 1024, "n_heads": 32, "n_local_heads": 8, "n_layers": 32, "rope_base": 500000.0, "vocab_size": 128256, "use_tiktoken": true, "norm_eps": 1e-05, "rope_scaling": {"factor": 8.0, "low_freq_factor": 1.0, "high_freq_factor": 4.0, "original_max_position_embeddings": 8192}}
```
