
Commit 7bedab5

Add rope_scaling to Qwen (#1210)
1 parent 20f7cc4 commit 7bedab5


vllm/model_executor/models/qwen.py

Lines changed: 11 additions & 10 deletions
@@ -8,7 +8,7 @@
 The input of the model is flattened to a 1D tensor of tokens. The model uses
 InputMetadata to extract the original 2D shape of the input.
 """
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -76,13 +76,12 @@ def forward(self, x):
 
 class QWenAttention(nn.Module):
 
-    def __init__(
-        self,
-        hidden_size: int,
-        num_heads: int,
-        max_position_embeddings: int,
-        rope_theta: float = 10000,
-    ):
+    def __init__(self,
+                 hidden_size: int,
+                 num_heads: int,
+                 max_position_embeddings: int,
+                 rope_theta: float = 10000,
+                 rope_scaling: Optional[Dict[str, Any]] = None):
         super().__init__()
         self.hidden_size = hidden_size
         tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
@@ -116,7 +115,7 @@ def __init__(
             rotary_dim=self.head_dim,
             base=rope_theta,
             max_position=max_position_embeddings,
-        )
+            rope_scaling=rope_scaling)
 
     def forward(
         self,
@@ -144,10 +143,12 @@ def __init__(self, config: QWenConfig):
         self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
         rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
         self.attn = QWenAttention(config.hidden_size,
                                   config.num_attention_heads,
                                   config.max_position_embeddings,
-                                  rope_theta=rope_theta)
+                                  rope_theta=rope_theta,
+                                  rope_scaling=rope_scaling)
 
         self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
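For context, the new `rope_scaling` argument is simply read off the model config in `QWenBlock.__init__` and forwarded through `QWenAttention` to the rotary embedding. Below is a minimal sketch of that config lookup; the attribute values and the scaling dict are illustrative assumptions, not part of this commit, and the accepted keys depend on the rotary-embedding implementation behind `PagedAttentionWithRoPE`.

```python
# Sketch of the getattr-based lookup added in QWenBlock.__init__.
# The values below are assumed for illustration; in vLLM they come
# from the model's config.json.
from types import SimpleNamespace

config = SimpleNamespace(
    hidden_size=4096,
    num_attention_heads=32,
    max_position_embeddings=8192,
    rope_theta=10000,
    # Hypothetical scaling spec for illustration only.
    rope_scaling={"type": "dynamic", "factor": 2.0},
)

rope_theta = getattr(config, "rope_theta", 10000)
# Configs without a rope_scaling attribute fall back to None,
# which matches the new parameter's default and keeps the
# previous, unscaled behaviour for existing Qwen checkpoints.
rope_scaling = getattr(config, "rope_scaling", None)
print(rope_theta, rope_scaling)
```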