|
10 | 10 | from vllm import attention_ops
|
11 | 11 | from vllm import cache_ops
|
12 | 12 | from vllm.model_executor.input_metadata import InputMetadata
|
13 |
| -from vllm.model_executor.layers.rotary_embedding import ( |
14 |
| - DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding, |
15 |
| - RotaryEmbedding, YaRNScalingRotaryEmbedding) |
| 13 | +from vllm.model_executor.layers.rotary_embedding import get_rope |
16 | 14 |
|
17 | 15 | _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
|
18 | 16 | # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
|
@@ -319,36 +317,8 @@ def __init__(
|
319 | 317 | scale,
|
320 | 318 | num_kv_heads,
|
321 | 319 | sliding_window=sliding_window)
|
322 |
| - if rope_scaling is None: |
323 |
| - self.rotary_emb = RotaryEmbedding(head_size, rotary_dim, |
324 |
| - max_position, base, |
325 |
| - is_neox_style) |
326 |
| - else: |
327 |
| - scaling_type = rope_scaling["type"] |
328 |
| - scaling_factor = rope_scaling["factor"] |
329 |
| - if scaling_type == "linear": |
330 |
| - self.rotary_emb = LinearScalingRotaryEmbedding( |
331 |
| - head_size, rotary_dim, max_position, base, is_neox_style, |
332 |
| - scaling_factor) |
333 |
| - elif scaling_type == "dynamic": |
334 |
| - self.rotary_emb = DynamicNTKScalingRotaryEmbedding( |
335 |
| - head_size, rotary_dim, max_position, base, is_neox_style, |
336 |
| - scaling_factor) |
337 |
| - elif scaling_type == "yarn": |
338 |
| - original_max_position = rope_scaling[ |
339 |
| - "original_max_position_embeddings"] |
340 |
| - assert max_position == original_max_position * scaling_factor |
341 |
| - extra_kwargs = { |
342 |
| - k: v |
343 |
| - for k, v in rope_scaling.items() |
344 |
| - if k in ("extrapolation_factor", "attn_factor", |
345 |
| - "beta_fast", "beta_slow") |
346 |
| - } |
347 |
| - self.rotary_emb = YaRNScalingRotaryEmbedding( |
348 |
| - head_size, rotary_dim, original_max_position, base, |
349 |
| - is_neox_style, scaling_factor, **extra_kwargs) |
350 |
| - else: |
351 |
| - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") |
| 320 | + self.rotary_emb = get_rope(head_size, rotary_dim, max_position, base, |
| 321 | + is_neox_style, rope_scaling) |
352 | 322 |
|
353 | 323 | def forward(
|
354 | 324 | self,
|
|
0 commit comments