
Commit 21877b0

LiuXiaoxuanPKU, winglian, and WoosukKwon authored
Support Longchat and RoPE scaling (#555)
Co-authored-by: Wing Lian <[email protected]>
Co-authored-by: Woosuk Kwon <[email protected]>
1 parent cf5cb1e commit 21877b0

File tree

  vllm/config.py
  vllm/model_executor/layers/attention.py
  vllm/model_executor/layers/rotary_embedding.py
  vllm/model_executor/models/llama.py

4 files changed: +211 -40 lines changed

vllm/config.py

Lines changed: 11 additions & 0 deletions
@@ -351,6 +351,17 @@ def _get_and_verify_max_len(
         if max_len_key is not None:
             derived_max_model_len = min(derived_max_model_len, max_len_key)
 
+    rope_scaling = getattr(hf_config, "rope_scaling", None)
+    if rope_scaling is not None:
+        if derived_max_model_len == float("inf"):
+            raise ValueError(
+                "When using rope_scaling, the model's config.json must "
+                "contain one of the following keys to determine the original "
+                f"maximum length of the model: {possible_keys}")
+        assert "factor" in rope_scaling
+        scaling_factor = rope_scaling["factor"]
+        derived_max_model_len *= scaling_factor
+
     if max_model_len is None:
         max_model_len = derived_max_model_len
     elif max_model_len > derived_max_model_len:
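
In plain terms, the new block multiplies the maximum length derived from config.json by the rope_scaling factor, and raises an error if no base length could be derived at all. A minimal standalone sketch of that arithmetic, not vLLM code; the SimpleNamespace stand-in and the 2048 / 4.0 values are made up for illustration:

# Hypothetical config: a 2048-token base model served with linear RoPE
# scaling of 4x. Mirrors the logic added to _get_and_verify_max_len above.
from types import SimpleNamespace

hf_config = SimpleNamespace(
    max_position_embeddings=2048,
    rope_scaling={"type": "linear", "factor": 4.0},
)

derived_max_model_len = float(hf_config.max_position_embeddings)
rope_scaling = getattr(hf_config, "rope_scaling", None)
if rope_scaling is not None:
    assert "factor" in rope_scaling
    derived_max_model_len *= rope_scaling["factor"]

print(derived_max_model_len)  # 8192.0, i.e. the effective context window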

vllm/model_executor/layers/attention.py

Lines changed: 25 additions & 38 deletions
@@ -1,5 +1,5 @@
 """Multi-head attention."""
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
 
 import torch
 import torch.nn as nn
@@ -9,8 +9,10 @@
 
 from vllm import attention_ops
 from vllm import cache_ops
-from vllm import pos_encoding_ops
 from vllm.model_executor.input_metadata import InputMetadata
+from vllm.model_executor.layers.rotary_embedding import (
+    DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding,
+    RotaryEmbedding)
 
 _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 
@@ -247,7 +249,7 @@ def forward(
 
 
 class PagedAttentionWithRoPE(PagedAttention):
-    """PagedAttention with rotary embedding."""
+    """PagedAttention with rotary positional embedding."""
 
     def __init__(
         self,
@@ -259,34 +261,26 @@ def __init__(
         base: int = 10000,
         num_kv_heads: Optional[int] = None,
         is_neox_style: bool = True,
+        rope_scaling: Optional[Dict[str, Any]] = None,
     ) -> None:
         super().__init__(num_heads, head_size, scale, num_kv_heads)
-        self.is_neox_style = is_neox_style
-
-        # Create the cos and sin cache.
-        # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
-        # However, we use `torch.arange(..., dtype=torch.float)` instead to
-        # avoid numerical issues with large base values (e.g., 10000000).
-        # This may cause a slight numerical difference between the HF
-        # implementation and ours.
-        # NOTE(woosuk): To exactly match the HF implementation, we need to
-        # use CPU to compute the cache and then move it to GPU. However, we
-        # create the cache on GPU for faster initialization. This may cause
-        # a slight numerical difference between the HF implementation and ours.
-        inv_freq = 1.0 / (base**(torch.arange(
-            0, rotary_dim, 2, dtype=torch.float, device="cuda") / rotary_dim))
-        t = torch.arange(max_position, dtype=torch.float, device="cuda")
-        freqs = torch.einsum("i,j -> ij", t, inv_freq)
-        cos = freqs.cos()
-        sin = freqs.sin()
-        cache = torch.cat((cos, sin), dim=-1)
-
-        # FIXME(woosuk): This assumes that we configure the default dtype when
-        # initializing the model.
-        torch_dtype = torch.get_default_dtype()
-        cache = cache.to(torch_dtype)
-        # Embedding size: [max_position, rotary_dim]
-        self.register_buffer("cos_sin_cache", cache, persistent=False)
+        if rope_scaling is None:
+            self.rotary_emb = RotaryEmbedding(head_size, rotary_dim,
+                                              max_position, base,
+                                              is_neox_style)
+        else:
+            scaling_type = rope_scaling["type"]
+            scaling_factor = rope_scaling["factor"]
+            if scaling_type == "linear":
+                self.rotary_emb = LinearScalingRotaryEmbedding(
+                    head_size, rotary_dim, max_position, base, is_neox_style,
+                    scaling_factor)
+            elif scaling_type == "dynamic":
+                self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
+                    head_size, rotary_dim, max_position, base, is_neox_style,
+                    scaling_factor)
+            else:
+                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
 
     def forward(
         self,
@@ -303,7 +297,7 @@ def forward(
 
         Args:
             positions: shape = [num_tokens]
-            query: shape = [num_tokens, num_heads * head_size]
+            query: shape = [num_tokens, num_heads * head_size]
             key: shape = [num_tokens, num_kv_heads * head_size]
             value: shape = [num_tokens, num_kv_heads * head_size]
             key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
@@ -319,14 +313,7 @@ def forward(
 
         # Apply rotary embedding to the query and key before passing them
         # to the attention op.
-        pos_encoding_ops.rotary_embedding(
-            positions,
-            query,
-            key,
-            self.head_size,
-            self.cos_sin_cache,
-            self.is_neox_style,
-        )
+        query, key = self.rotary_emb(positions, query, key)
         return super().forward(
             query,
             key,
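
The net effect is that PagedAttentionWithRoPE no longer builds its own cos/sin cache; it delegates to one of three rotary embedding implementations selected from the rope_scaling dict, and configs without a rope_scaling entry keep the original behavior. A small sketch of that dispatch, written as a standalone function for illustration only (it returns the class name rather than constructing the module, since the real constructors allocate CUDA buffers):

# Mirrors the branch added to PagedAttentionWithRoPE.__init__ above; the
# class names come from the new vllm.model_executor.layers.rotary_embedding
# module, but this helper itself is hypothetical.
from typing import Any, Dict, Optional


def pick_rotary_emb_class(rope_scaling: Optional[Dict[str, Any]]) -> str:
    if rope_scaling is None:
        return "RotaryEmbedding"
    scaling_type = rope_scaling["type"]
    if scaling_type == "linear":
        return "LinearScalingRotaryEmbedding"
    if scaling_type == "dynamic":
        return "DynamicNTKScalingRotaryEmbedding"
    raise ValueError(f"Unknown RoPE scaling type {scaling_type}")


print(pick_rotary_emb_class(None))                                # RotaryEmbedding
print(pick_rotary_emb_class({"type": "linear", "factor": 8.0}))   # LinearScalingRotaryEmbedding
print(pick_rotary_emb_class({"type": "dynamic", "factor": 2.0}))  # DynamicNTKScalingRotaryEmbedding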
vllm/model_executor/layers/rotary_embedding.py (new file)

Lines changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rotary Positional Embeddings."""
from typing import Tuple, Union

import torch
import torch.nn as nn

from vllm import pos_encoding_ops


class RotaryEmbedding(nn.Module):
    """Original rotary positional embedding."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style

        cache = self._compute_cos_sin_cache()
        cache = cache.to(torch.get_default_dtype())
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        """Compute the inverse frequency."""
        # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
        # However, we use `torch.arange(..., dtype=torch.float)` instead to
        # avoid numerical issues with large base values (e.g., 10000000).
        # This may cause a slight numerical difference between the HF
        # implementation and ours.
        # NOTE(woosuk): To exactly match the HF implementation, we need to
        # use CPU to compute the cache and then move it to GPU. However, we
        # create the cache on GPU for faster initialization. This may cause
        # a slight numerical difference between the HF implementation and ours.
        inv_freq = 1.0 / (base**(torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
                                 self.rotary_dim))
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos and sin cache."""
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings,
                         dtype=torch.float,
                         device="cuda")

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # pos_encoding_ops.rotary_embedding() is an in-place operation that
        # updates the query and key tensors.
        pos_encoding_ops.rotary_embedding(positions, query, key,
                                          self.head_size, self.cos_sin_cache,
                                          self.is_neox_style)
        return query, key


class LinearScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with linear scaling.

    Credits to the Reddit user /u/kaiokendev
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
    ) -> None:
        self.scaling_factor = scaling_factor
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style)

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        inv_freq = self._compute_inv_freq(self.base)
        # NOTE(woosuk): self.max_position_embeddings is the original
        # maximum length before applying the rope scaling.
        # Thus, the maximum length after applying the rope scaling is
        # self.max_position_embeddings * self.scaling_factor.
        max_len = self.max_position_embeddings * self.scaling_factor
        t = torch.arange(max_len, dtype=torch.float, device="cuda")
        t = t / self.scaling_factor

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache


class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with Dynamic NTK scaling.

    Credits to the Reddit users /u/bloc97 and /u/emozilla
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
    ) -> None:
        self.scaling_factor = scaling_factor
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style)

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        # NOTE(woosuk): self.max_position_embeddings is the original
        # maximum length before applying the rope scaling.
        # Thus, the maximum length after applying the rope scaling is
        # self.max_position_embeddings * self.scaling_factor.
        max_len = self.max_position_embeddings * self.scaling_factor
        base = self.base * (
            (self.scaling_factor * max_len / self.max_position_embeddings) -
            (self.scaling_factor - 1))**(self.rotary_dim /
                                         (self.rotary_dim - 2))
        inv_freq = self._compute_inv_freq(base)
        t = torch.arange(max_len, dtype=torch.float, device="cuda")

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache
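
The two scaling variants differ only in how the cos/sin cache is built: linear scaling stretches the position indices by the factor while keeping the base, whereas dynamic NTK scaling keeps the integer positions and instead enlarges the base. A CPU-only sketch of the two formulas with illustrative numbers (rotary_dim=128, base=10000, original length 2048, factor 4.0; none of these values come from the commit):

# Standalone reimplementation of the two _compute_cos_sin_cache variants,
# run on CPU so it works without a GPU.
import torch

rotary_dim, base, max_pos, factor = 128, 10000, 2048, 4.0


def inv_freq(b: float) -> torch.Tensor:
    # Same inverse-frequency formula as RotaryEmbedding._compute_inv_freq.
    return 1.0 / (b**(torch.arange(0, rotary_dim, 2, dtype=torch.float) /
                      rotary_dim))


# Linear scaling: stretch the positions, keep the base.
t_linear = torch.arange(max_pos * factor, dtype=torch.float) / factor
freqs_linear = torch.einsum("i,j -> ij", t_linear, inv_freq(base))

# Dynamic NTK scaling: keep the positions, grow the base (evaluated at the
# full scaled length, as in DynamicNTKScalingRotaryEmbedding above).
max_len = max_pos * factor
ntk_base = base * ((factor * max_len / max_pos) - (factor - 1))**(
    rotary_dim / (rotary_dim - 2))
t_ntk = torch.arange(max_len, dtype=torch.float)
freqs_ntk = torch.einsum("i,j -> ij", t_ntk, inv_freq(ntk_base))

print(freqs_linear.shape, freqs_ntk.shape)  # both torch.Size([8192, 64])
print(round(ntk_base))  # the base grows from 10000 to roughly 1.35e5

Both caches cover 8192 positions; linear scaling slows every frequency component equally, while the enlarged NTK base slows the low-frequency components more than the high-frequency ones.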

vllm/model_executor/models/llama.py

Lines changed: 6 additions & 2 deletions
@@ -25,7 +25,7 @@
 The input of the model is flattened to a 1D tensor of tokens. The model uses
 InputMetadata to extract the original 2D shape of the input.
 """
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -92,6 +92,7 @@ def __init__(
         num_heads: int,
         num_kv_heads: int,
         rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
@@ -135,7 +136,8 @@ def __init__(
             base=self.rope_theta,
             max_position=self.max_position_embeddings,
             rotary_dim=self.head_dim,
-            num_kv_heads=self.num_kv_heads)
+            num_kv_heads=self.num_kv_heads,
+            rope_scaling=rope_scaling)
 
     def forward(
         self,
@@ -165,13 +167,15 @@ def __init__(
         self.hidden_size = config.hidden_size
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
         self.self_attn = LlamaAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
         )
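
Both new parameters are read from the Hugging Face config with getattr, so configs that predate rope_scaling (or that lack rope_theta) fall back to the previous defaults. A sketch of how a LongChat-style config would flow into the attention layer; the SimpleNamespace stands in for the transformers config object, and the rope_scaling values are illustrative rather than copied from any particular checkpoint:

# Hypothetical config object with LongChat-style linear scaling.
from types import SimpleNamespace

config = SimpleNamespace(
    rope_theta=10000,
    rope_scaling={"type": "linear", "factor": 8.0},
    max_position_embeddings=2048,
)

# Same getattr-based reads as in the decoder layer's __init__ above.
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)

# These are the keyword arguments LlamaAttention now forwards to
# PagedAttentionWithRoPE (alongside the head counts and dimensions).
print(rope_theta, rope_scaling, max_position_embeddings)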
