
Commit 054072b

[Minor] Move RoPE selection logic to get_rope (#1633)
1 parent eb825c1 commit 054072b

File tree

2 files changed: +47 -34 lines

vllm/model_executor/layers/attention.py

Lines changed: 3 additions & 33 deletions
@@ -10,9 +10,7 @@
 from vllm import attention_ops
 from vllm import cache_ops
 from vllm.model_executor.input_metadata import InputMetadata
-from vllm.model_executor.layers.rotary_embedding import (
-    DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding,
-    RotaryEmbedding, YaRNScalingRotaryEmbedding)
+from vllm.model_executor.layers.rotary_embedding import get_rope

 _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
@@ -319,36 +317,8 @@ def __init__(
                          scale,
                          num_kv_heads,
                          sliding_window=sliding_window)
-        if rope_scaling is None:
-            self.rotary_emb = RotaryEmbedding(head_size, rotary_dim,
-                                              max_position, base,
-                                              is_neox_style)
-        else:
-            scaling_type = rope_scaling["type"]
-            scaling_factor = rope_scaling["factor"]
-            if scaling_type == "linear":
-                self.rotary_emb = LinearScalingRotaryEmbedding(
-                    head_size, rotary_dim, max_position, base, is_neox_style,
-                    scaling_factor)
-            elif scaling_type == "dynamic":
-                self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
-                    head_size, rotary_dim, max_position, base, is_neox_style,
-                    scaling_factor)
-            elif scaling_type == "yarn":
-                original_max_position = rope_scaling[
-                    "original_max_position_embeddings"]
-                assert max_position == original_max_position * scaling_factor
-                extra_kwargs = {
-                    k: v
-                    for k, v in rope_scaling.items()
-                    if k in ("extrapolation_factor", "attn_factor",
-                             "beta_fast", "beta_slow")
-                }
-                self.rotary_emb = YaRNScalingRotaryEmbedding(
-                    head_size, rotary_dim, original_max_position, base,
-                    is_neox_style, scaling_factor, **extra_kwargs)
-            else:
-                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+        self.rotary_emb = get_rope(head_size, rotary_dim, max_position, base,
+                                   is_neox_style, rope_scaling)

     def forward(
         self,
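
After this change, the attention layer builds its rotary embedding with a single get_rope call instead of branching on the scaling type. A minimal usage sketch follows; the concrete head_size, max_position, and scaling values are illustrative assumptions, not taken from the commit:

# Hypothetical call sites for the new factory; all numeric values are
# illustrative assumptions.
from vllm.model_executor.layers.rotary_embedding import get_rope

# rope_scaling=None falls back to a plain RotaryEmbedding.
rope = get_rope(head_size=128, rotary_dim=128, max_position=4096,
                base=10000, is_neox_style=True, rope_scaling=None)

# An HF-style scaling dict selects the scaled variant, here linear scaling.
rope_linear = get_rope(head_size=128, rotary_dim=128, max_position=4096,
                       base=10000, is_neox_style=True,
                       rope_scaling={"type": "linear", "factor": 2.0})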

vllm/model_executor/layers/rotary_embedding.py

Lines changed: 44 additions & 1 deletion
@@ -22,7 +22,7 @@
 # limitations under the License.
 """Rotary Positional Embeddings."""
 import math
-from typing import Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -271,3 +271,46 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
         sin = (freqs.sin() * self.mscale)
         cache = torch.cat((cos, sin), dim=-1)
         return cache
+
+
+def get_rope(
+    head_size: int,
+    rotary_dim: int,
+    max_position: int,
+    base: int,
+    is_neox_style: bool,
+    rope_scaling: Optional[Dict[str, Any]],
+) -> RotaryEmbedding:
+    if rope_scaling is None:
+        rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
+                                     is_neox_style)
+    else:
+        scaling_type = rope_scaling["type"]
+        scaling_factor = rope_scaling["factor"]
+        if scaling_type == "linear":
+            rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
+                                                      max_position, base,
+                                                      is_neox_style,
+                                                      scaling_factor)
+        elif scaling_type == "dynamic":
+            rotary_emb = DynamicNTKScalingRotaryEmbedding(
+                head_size, rotary_dim, max_position, base, is_neox_style,
+                scaling_factor)
+        elif scaling_type == "yarn":
+            original_max_position = rope_scaling[
+                "original_max_position_embeddings"]
+            assert max_position == original_max_position * scaling_factor
+            extra_kwargs = {
+                k: v
+                for k, v in rope_scaling.items()
+                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
+                         "beta_slow")
+            }
+            rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
+                                                    original_max_position,
+                                                    base, is_neox_style,
+                                                    scaling_factor,
+                                                    **extra_kwargs)
+        else:
+            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+    return rotary_emb
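
For the "yarn" branch above, the config must satisfy the assert max_position == original_max_position_embeddings * factor, and only the four whitelisted keys are forwarded as extra kwargs. A hedged sketch of that contract, with assumed values chosen so the assert holds (8192 == 2048 * 4.0):

# Illustrative "yarn" config; all numbers are assumptions.
rope_scaling = {
    "type": "yarn",
    "factor": 4.0,
    "original_max_position_embeddings": 2048,
    "attn_factor": 1.0,  # whitelisted, forwarded via **extra_kwargs
}

rope = get_rope(head_size=128, rotary_dim=128, max_position=8192,
                base=10000, is_neox_style=True, rope_scaling=rope_scaling)
# Keys outside the whitelist (including "type" and "factor" themselves)
# are filtered out by the dict comprehension before
# YaRNScalingRotaryEmbedding is constructed.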
