 import torch
 import torch.nn as nn
 
+from vllm.config import get_current_vllm_config
+from vllm.logger import init_logger
+
 from .common import rotate_neox
 
+logger = init_logger(__name__)
+
 
 class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
     """Phi3 family of models scaled rotary embedding.
@@ -43,6 +48,22 @@ def __init__(
         self.short_factor = short_factor
         self.long_factor = long_factor
 
+        # Force long factors if max_model_len (runtime max length) exceeds
+        # original_max_position_embeddings to prevent KV cache invalidation
+        # when sequences cross this threshold during generation.
+        max_model_len = get_current_vllm_config().model_config.max_model_len
+        self.use_long_rope = max_model_len > original_max_position_embeddings
+        if self.use_long_rope:
+            logger.warning_once(
+                "Using LongRoPE scaling factors. This enables longer "
+                "contexts (%d tokens vs original %d tokens) at the cost of "
+                "some performance degradation for shorter sequences. If "
+                "this is not desired, set `max_model_len` to be at most %d.",
+                max_position_embeddings,
+                original_max_position_embeddings,
+                original_max_position_embeddings,
+            )
+
         scale = self.max_position_embeddings / self.original_max_position_embeddings
         if scale <= 1.0:
             scaling_factor = 1.0
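For reference, a minimal standalone sketch of the gating logic added in __init__ (the numeric values are illustrative assumptions, not taken from any particular model config; the real code reads max_model_len from get_current_vllm_config().model_config):

# Sketch of the use_long_rope decision added above.
def decide_long_rope(max_model_len: int, original_max_position_embeddings: int) -> bool:
    # Long factors are selected whenever the runtime limit allows sequences to
    # grow past the original training context, so the rotary cache never has to
    # switch factor sets mid-generation (which would invalidate KV cache entries
    # already computed with the other factor set).
    return max_model_len > original_max_position_embeddings

assert decide_long_rope(131072, 4096) is True   # extended context -> long factors
assert decide_long_rope(4096, 4096) is False    # capped at original -> short factors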
@@ -112,15 +133,12 @@ def forward(
         query = query.view(*query.shape[:-1], -1, self.head_size)
         key = key.view(*key.shape[:-1], -1, self.head_size)
 
-        k = self.original_max_position_embeddings
-        long_prompt_offset = (
-            torch.any(positions > k).float() * torch.full_like(positions, k)
-        ).long()
-        idx = (
-            torch.add(positions, long_prompt_offset)
-            if long_prompt_offset is not None
-            else positions
-        )
+        if self.use_long_rope:
+            k = self.original_max_position_embeddings
+            long_prompt_offset = torch.full_like(positions, k).long()
+            idx = torch.add(positions, long_prompt_offset)
+        else:
+            idx = positions
         idx = torch.add(idx, offsets) if offsets is not None else idx
         cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)
 
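A minimal sketch of what the new branch in forward() does with the index, assuming (as the fixed offset of original_max_position_embeddings suggests) that long_short_cos_sin_cache stores the short-factor cos/sin rows first and the long-factor rows after them; the toy sizes and stand-in values below are assumptions for illustration only:

import torch

# Toy cache: 8 "short" rows followed by 16 "long" rows (real models would use
# sizes like 4096 and 131072). Values are stand-ins so the selected half is
# visible in the output.
orig_max, long_max = 8, 16
short_rows = torch.zeros(orig_max, 4)
long_rows = torch.ones(long_max, 4)
long_short_cos_sin_cache = torch.cat([short_rows, long_rows], dim=0)

positions = torch.tensor([0, 3, 7])
use_long_rope = True  # decided once in __init__ from max_model_len

if use_long_rope:
    # Shift every position into the long-factor half of the cache.
    idx = positions + torch.full_like(positions, orig_max)
else:
    idx = positions

cos_sin = torch.index_select(long_short_cos_sin_cache, 0, idx)
print(cos_sin[:, 0])  # tensor([1., 1., 1.]) -> rows came from the long half

Because the choice is now made once at construction time, every token in a request uses the same factor set. The removed code instead toggled the offset via torch.any(positions > k), which could flip between decode steps once a sequence crossed the original context length, leaving earlier KV cache entries computed with the other factor set.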