Commit 44b5ce9

[Bugfix] In LongRoPE, decide short vs long based on max_model_len (#27431)
Signed-off-by: Matthew Bonanni <[email protected]>
1 parent: 7a865f2

3 files changed: +39 −11 lines

tests/entrypoints/openai/test_default_mm_loras.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -29,7 +29,7 @@ def multimodal_server():  # noqa: F811
         "--dtype",
         "half",
         "--max-model-len",
-        "12800",
+        "4096",
         "--enforce-eager",
         # lora config below
         "--enable-lora",
```

vllm/config/model.py

Lines changed: 11 additions & 1 deletion

```diff
@@ -2142,8 +2142,18 @@ def _get_and_verify_max_len(
     # If the user didn't specify `max_model_len`, then use that derived from
     # the model config as a default value.
     if max_model_len is None:
-        max_model_len = int(derived_max_model_len)
+        # For LongRoPE, default to original_max_position_embeddings to avoid
+        # performance degradation for shorter sequences
+        if rope_scaling is not None and rope_scaling["rope_type"] == "longrope":
+            max_model_len = int(
+                getattr(
+                    hf_config, "original_max_position_embeddings", derived_max_model_len
+                )
+            )
+        else:
+            max_model_len = int(derived_max_model_len)
         max_model_len = current_platform.check_max_model_len(max_model_len)
+
     # If the user specified a max length, make sure it is smaller than the
     # derived length from the HF model config.
     elif max_model_len > derived_max_model_len:
```
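
To make the new default concrete, here is a minimal standalone sketch (not vLLM's actual `_get_and_verify_max_len`; the helper name and the `SimpleNamespace` stand-in for the HF config are illustrative only): when the model uses LongRoPE scaling and the user does not set `max_model_len`, the default falls back to `original_max_position_embeddings` instead of the scaled maximum.

```python
from types import SimpleNamespace

def default_max_model_len(hf_config, rope_scaling, derived_max_model_len):
    """Sketch of the new default: prefer the original (short) context
    length when the model uses LongRoPE scaling."""
    if rope_scaling is not None and rope_scaling["rope_type"] == "longrope":
        return int(getattr(hf_config, "original_max_position_embeddings",
                           derived_max_model_len))
    return int(derived_max_model_len)

# Hypothetical Phi-3-style config: 4k original context scaled to 128k.
cfg = SimpleNamespace(original_max_position_embeddings=4096)
print(default_max_model_len(cfg, {"rope_type": "longrope"}, 131072))  # 4096
print(default_max_model_len(cfg, None, 131072))                       # 131072
```

With such a config, an unset `max_model_len` now defaults to the original 4096-token context, so the long scaling factors stay off unless the user explicitly requests a longer context.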

vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py

Lines changed: 27 additions & 9 deletions

```diff
@@ -5,8 +5,13 @@
 import torch
 import torch.nn as nn
 
+from vllm.config import get_current_vllm_config
+from vllm.logger import init_logger
+
 from .common import rotate_neox
 
+logger = init_logger(__name__)
+
 
 class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
     """Phi3 family of models scaled rotary embedding.
@@ -43,6 +48,22 @@ def __init__(
         self.short_factor = short_factor
         self.long_factor = long_factor
 
+        # Force long factors if max_model_len (runtime max length) exceeds
+        # original_max_position_embeddings to prevent KV cache invalidation when
+        # sequences cross this threshold during generation
+        max_model_len = get_current_vllm_config().model_config.max_model_len
+        self.use_long_rope = max_model_len > original_max_position_embeddings
+        if self.use_long_rope:
+            logger.warning_once(
+                "Using LongRoPE scaling factors. This enables longer "
+                "contexts (%d tokens vs original %d tokens) at the cost of "
+                "some performance degradation for shorter sequences. If "
+                "this is not desired, set `max_model_len` to be at most %d.",
+                max_position_embeddings,
+                original_max_position_embeddings,
+                original_max_position_embeddings,
+            )
+
         scale = self.max_position_embeddings / self.original_max_position_embeddings
         if scale <= 1.0:
             scaling_factor = 1.0
@@ -112,15 +133,12 @@ def forward(
         query = query.view(*query.shape[:-1], -1, self.head_size)
         key = key.view(*key.shape[:-1], -1, self.head_size)
 
-        k = self.original_max_position_embeddings
-        long_prompt_offset = (
-            torch.any(positions > k).float() * torch.full_like(positions, k)
-        ).long()
-        idx = (
-            torch.add(positions, long_prompt_offset)
-            if long_prompt_offset is not None
-            else positions
-        )
+        if self.use_long_rope:
+            k = self.original_max_position_embeddings
+            long_prompt_offset = torch.full_like(positions, k).long()
+            idx = torch.add(positions, long_prompt_offset)
+        else:
+            idx = positions
         idx = torch.add(idx, offsets) if offsets is not None else idx
         cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)
```
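
A simplified sketch of the cache indexing this change relies on (illustrative sizes and a fake cache, not the real `long_short_cos_sin_cache` contents): the cos/sin cache stores short-factor rows for positions [0, original_max_position_embeddings) followed by long-factor rows, so adding a constant offset of `original_max_position_embeddings` to every position selects the long-factor half. With this fix the choice is made once at init from `max_model_len` (`self.use_long_rope`) instead of per batch from `torch.any(positions > k)`, so cached keys cannot silently switch factor sets when a sequence crosses the threshold during generation.

```python
import torch

# Hypothetical sizes for illustration only.
original_max = 4                      # original_max_position_embeddings
max_pos = 8                           # scaled max_position_embeddings
# Rows [0, 4): short-factor entries; rows [4, 12): long-factor entries.
cache = torch.arange(original_max + max_pos).float().unsqueeze(-1)

def select_rows(positions: torch.Tensor, use_long_rope: bool) -> torch.Tensor:
    # Mirrors the new forward(): a constant offset flips the whole batch
    # to the long-factor half of the cache when long RoPE is enabled.
    idx = positions + original_max if use_long_rope else positions
    return torch.index_select(cache, 0, idx)

positions = torch.tensor([0, 1, 2, 3])
print(select_rows(positions, use_long_rope=False).squeeze())  # tensor([0., 1., 2., 3.])
print(select_rows(positions, use_long_rope=True).squeeze())   # tensor([4., 5., 6., 7.])
```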
