 from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler,
                                                PoolingType)
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -38,19 +39,24 @@ def __init__(self, config: BertConfig):
         self.size = config.hidden_size
         self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
                                                       config.hidden_size)
-        self.position_embeddings = VocabParallelEmbedding(
-            config.max_position_embeddings, config.hidden_size)
+
         self.token_type_embeddings = VocabParallelEmbedding(
             config.type_vocab_size, config.hidden_size)
         self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                       eps=config.layer_norm_eps)
-        self.position_ids = nn.Parameter(
-            torch.empty((1, config.max_position_embeddings)), )
 
         self.position_embedding_type = config.position_embedding_type
-        if self.position_embedding_type != "absolute":
-            raise ValueError("Only 'absolute' position_embedding_type" +
-                             " is supported")
+        if self.position_embedding_type == "absolute":
+            self.position_embeddings = VocabParallelEmbedding(
+                config.max_position_embeddings, config.hidden_size)
+            self.position_ids = nn.Parameter(
+                torch.empty((1, config.max_position_embeddings)), )
+        elif self.position_embedding_type == "rotary":
+            self.position_embeddings = None
+            self.position_ids = None
+        else:
+            raise ValueError("Only 'absolute' and 'rotary' " +
+                             "position_embedding_type is supported")
 
     def forward(
         self,
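
For reference, the branch above dispatches on the Hugging Face BertConfig.position_embedding_type field. A minimal sketch of a config that would take the new "rotary" path (the values below are hypothetical, not taken from this diff):

from transformers import BertConfig

# Hypothetical rotary-position config: with this setting the embedding layer
# allocates no absolute position table, and position information is applied
# later inside attention via get_rope.
config = BertConfig(
    vocab_size=30522,
    hidden_size=768,
    num_attention_heads=12,
    position_embedding_type="rotary",
)
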
@@ -64,17 +70,19 @@ def forward(
         # Input embeddings.
         inputs_embeds = self.word_embeddings(input_ids)
 
-        # Position embeddings.
-        position_embeddings = self.position_embeddings(position_ids)
-
         if token_type_ids is None:
             token_type_ids = torch.zeros(input_shape,
                                          dtype=torch.long,
                                          device=inputs_embeds.device)
 
         token_type_embeddings = self.token_type_embeddings(token_type_ids)
 
-        embeddings = inputs_embeds + token_type_embeddings + position_embeddings
+        embeddings = inputs_embeds + token_type_embeddings
+
+        if self.position_embedding_type == "absolute":
+            position_embeddings = self.position_embeddings(position_ids)
+            embeddings += position_embeddings
+
         embeddings = self.LayerNorm(embeddings)
         return embeddings
 
@@ -98,7 +106,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 @support_torch_compile
 class BertEncoder(nn.Module):
 
-    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 rotary_kwargs: Optional[dict] = None,
+                 prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
@@ -107,16 +118,18 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
             BertLayer(config=config,
                       cache_config=cache_config,
                       quant_config=quant_config,
+                      rotary_kwargs=rotary_kwargs,
                       prefix=f"{prefix}.layer.{layer_idx}")
             for layer_idx in range(config.num_hidden_layers)
         ])
 
     def forward(
         self,
+        positions: torch.Tensor,
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:
         for layer in self.layer:
-            hidden_states = layer(hidden_states)
+            hidden_states = layer(positions, hidden_states)
         return hidden_states
 
 
@@ -126,6 +139,7 @@ def __init__(self,
                  config: BertConfig,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
+                 rotary_kwargs: Optional[dict] = None,
                  prefix: str = ""):
         super().__init__()
 
@@ -135,6 +149,7 @@ def __init__(self,
             layer_norm_eps=config.layer_norm_eps,
             cache_config=cache_config,
             quant_config=quant_config,
+            rotary_kwargs=rotary_kwargs,
             prefix=f"{prefix}.attention")
 
         self.intermediate = BertIntermediate(
@@ -150,8 +165,8 @@ def __init__(self,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.output")
 
-    def forward(self, hidden_states: torch.Tensor):
-        attn_output = self.attention(hidden_states)
+    def forward(self, positions: torch.Tensor, hidden_states: torch.Tensor):
+        attn_output = self.attention(positions, hidden_states)
         intermediate_output = self.intermediate(attn_output)
         output = self.output(intermediate_output, attn_output)
         return output
@@ -166,6 +181,7 @@ def __init__(
         layer_norm_eps: float,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        rotary_kwargs: Optional[dict] = None,
         prefix: str = "",
     ):
         super().__init__()
@@ -174,6 +190,7 @@ def __init__(
                                       num_attention_heads=num_attention_heads,
                                       cache_config=cache_config,
                                       quant_config=quant_config,
+                                      rotary_kwargs=rotary_kwargs,
                                       prefix=f"{prefix}.output")
 
         self.output = BertSelfOutput(hidden_size=hidden_size,
@@ -183,9 +200,10 @@ def __init__(
 
     def forward(
         self,
+        positions: torch.Tensor,
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:
-        self_output = self.self(hidden_states)
+        self_output = self.self(positions, hidden_states)
         return self.output(self_output, hidden_states)
 
 
@@ -197,6 +215,7 @@ def __init__(
         num_attention_heads: int,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        rotary_kwargs: Optional[dict] = None,
         prefix: str = "",
     ):
         super().__init__()
@@ -225,6 +244,11 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.qkv_proj")
 
+        if rotary_kwargs:
+            self.rotary_emb = get_rope(**rotary_kwargs)
+        else:
+            self.rotary_emb = None
+
         self.attn = Attention(num_heads=self.num_heads,
                               head_size=self.head_dim,
                               scale=self.scaling,
@@ -236,10 +260,15 @@ def __init__(
 
     def forward(
         self,
+        positions: torch.Tensor,
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+        if self.rotary_emb:
+            q, k = self.rotary_emb(positions, q, k)
+
         output = self.attn(q, k, v)
         return output
 
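
As a sanity check on the new attention path, a minimal standalone sketch of how get_rope consumes rotary_kwargs and rotates q/k (shapes and values below are illustrative assumptions, not taken from this diff):

import torch

from vllm.model_executor.layers.rotary_embedding import get_rope

# Illustrative sizes only; real values come from the model config.
num_tokens, num_heads, head_size = 8, 12, 64

rotary_emb = get_rope(head_size=head_size,
                      rotary_dim=head_size,
                      max_position=2048,
                      base=10000,
                      is_neox_style=True)

# q/k stay flat as (num_tokens, num_heads * head_size), matching the
# qkv_proj split in BertSelfAttention.forward above.
positions = torch.arange(num_tokens)
q = torch.randn(num_tokens, num_heads * head_size)
k = torch.randn(num_tokens, num_heads * head_size)
q, k = rotary_emb(positions, q, k)  # position-dependent rotation of q and k
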
@@ -321,11 +350,13 @@ def __init__(self,
                  vllm_config: VllmConfig,
                  prefix: str = "",
                  embedding_class: type = BertEmbedding,
+                 rotary_kwargs: Optional[dict] = None,
                  add_pooling_layer: bool = False):
         super().__init__()
         config = vllm_config.model_config.hf_config
         self.embeddings = embedding_class(config)
         self.encoder = BertEncoder(vllm_config=vllm_config,
+                                   rotary_kwargs=rotary_kwargs,
                                    prefix=f"{prefix}.encoder")
         self.pooler = BertPooler(config) if add_pooling_layer else None
 
@@ -347,7 +378,7 @@ def forward(
                                         seq_lens=attn_metadata.seq_lens_tensor,
                                         position_ids=position_ids,
                                         token_type_ids=token_type_ids)
-        return self.encoder(hidden_states)
+        return self.encoder(position_ids, hidden_states)
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
@@ -401,6 +432,7 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         pooler_config = vllm_config.model_config.pooler_config
+        self.config = vllm_config.model_config.hf_config
         self.model = self._build_model(vllm_config=vllm_config,
                                        prefix=maybe_prefix(prefix, "model"))
         self._pooler = self._build_pooler(pooler_config)
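
The hunks above only thread rotary_kwargs through the stack; the caller that actually builds the dict is outside this excerpt. A plausible sketch of how a BertEmbeddingModel subclass could assemble it (the subclass name and the getattr fallbacks are assumptions, not from this diff; the keys mirror get_rope's parameters):

class RotaryBertEmbeddingModel(BertEmbeddingModel):
    # Hypothetical subclass, for illustration only.

    def _build_model(self,
                     vllm_config: VllmConfig,
                     prefix: str = "") -> BertModel:
        config = vllm_config.model_config.hf_config
        head_dim = config.hidden_size // config.num_attention_heads
        # Assumed config attribute names; real checkpoints may spell these
        # differently.
        rotary_kwargs = {
            "head_size": head_dim,
            "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
            "max_position": config.max_position_embeddings,
            "base": getattr(config, "rope_theta", 10000),
            "is_neox_style": True,
        }
        return BertModel(vllm_config=vllm_config,
                         rotary_kwargs=rotary_kwargs,
                         prefix=prefix,
                         embedding_class=BertEmbedding)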