The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
+import math
from typing import Dict, List, Optional, Tuple

import torch
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttentionWithRoPE, PagedAttentionWithALiBi
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
                                              load_tensor_parallel_weights)

KVCache = Tuple[torch.Tensor, torch.Tensor]


+def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
+    # ALiBi gives each head a fixed slope from a geometric sequence: for a
+    # power-of-two head count n, the slopes are 2^(-8/n), 2^(-16/n), ..., 2^(-8).
+    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
+    base = torch.tensor(
+        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
+        dtype=torch.float32,
+    )
+    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
+    slopes = torch.pow(base, powers)
+
+    if closest_power_of_2 != total_num_heads:
+        # For non-power-of-two head counts, the remaining heads take every
+        # other slope from the sequence of the next power of two.
+        extra_base = torch.tensor(
+            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
+            dtype=torch.float32,
+        )
+        num_remaining_heads = min(closest_power_of_2,
+                                  total_num_heads - closest_power_of_2)
+        extra_powers = torch.arange(start=1,
+                                    end=1 + 2 * num_remaining_heads,
+                                    step=2,
+                                    dtype=torch.int32)
+        slopes = torch.cat(
+            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
+    return slopes
+
+
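A hedged worked example of the helper above (values rounded): with a power-of-two head count the slopes form a geometric sequence ending at 2**-8, and a non-power-of-two count reuses interleaved slopes from the next power of two.

    _get_alibi_slopes(8)
    # tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])
    _get_alibi_slopes(12)
    # the eight values above followed by 2**-0.5, 2**-1.5, 2**-2.5, 2**-3.5
    # (about 0.7071, 0.3536, 0.1768, 0.0884) for the four extra heads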
class BaiChuanMLP(nn.Module):

    def __init__(
@@ -82,6 +108,7 @@ def __init__(
        self,
        hidden_size: int,
        num_heads: int,
+        position_embedding: str,
    ):
        super().__init__()
        self.hidden_size = hidden_size
@@ -92,7 +119,7 @@ def __init__(
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads
-        self.scaling = self.head_dim**-0.5
+        self.position_embedding = position_embedding

        # pylint: disable=invalid-name
        self.W_pack = ColumnParallelLinear(
@@ -109,11 +136,23 @@ def __init__(
            input_is_parallel=True,
            perform_initialization=False,
        )
-
-        self.attn = PagedAttentionWithRoPE(self.num_heads,
-                                           self.head_dim,
-                                           self.scaling,
-                                           rotary_dim=self.head_dim)
+        if self.position_embedding == "ALIBI":
+            # Create the ALiBi slopes for all heads and keep only the slice
+            # owned by this tensor-parallel rank.
+            tp_rank = get_tensor_model_parallel_rank()
+            head_start = tp_rank * self.num_heads
+            head_end = (tp_rank + 1) * self.num_heads
+            alibi_slopes = _get_alibi_slopes(self.total_num_heads)
+            alibi_slopes = alibi_slopes[head_start:head_end].tolist()
+
+            scaling = self.head_dim**-0.5
+            self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
+                                                scaling, alibi_slopes)
+        else:
+            self.scaling = self.head_dim**-0.5
+            self.attn = PagedAttentionWithRoPE(self.num_heads,
+                                               self.head_dim,
+                                               self.scaling,
+                                               rotary_dim=self.head_dim)
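PagedAttentionWithALiBi's kernel is not shown in this diff; the sketch below is only the standard ALiBi formulation it corresponds to conceptually: each head adds its slope times the key-query distance to the attention logits, which is why the ALiBi path above needs per-head slopes but no rotary embedding of q and k.

    # Within this module (torch and _get_alibi_slopes are defined above):
    slopes = _get_alibi_slopes(8)                 # one slope per head
    pos = torch.arange(4)
    distance = pos[None, :] - pos[:, None]        # key index minus query index
    bias = slopes[:, None, None] * distance       # (heads, 4, 4); <= 0 on the
                                                  # causal lower triangle
    # logits = q @ k.transpose(-1, -2) * scaling + bias, then mask and softmax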

    def forward(
        self,
@@ -126,20 +165,26 @@ def forward(
        qkv, _ = self.W_pack(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        k_cache, v_cache = kv_cache
-        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
-                                input_metadata, cache_event)
+        if self.position_embedding == "ALIBI":
+            # ALiBi biases the attention scores directly, so the kernel does
+            # not take the position ids.
+            attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                    cache_event)
+        else:
+            attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
+                                    input_metadata, cache_event)
+
        output, _ = self.o_proj(attn_output)
        return output


class BaiChuanDecoderLayer(nn.Module):

-    def __init__(self, config: BaiChuanConfig):
+    def __init__(self, config: BaiChuanConfig, position_embedding: str):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = BaiChuanAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
+            position_embedding=position_embedding,
        )
        self.mlp = BaiChuanMLP(
            hidden_size=self.hidden_size,
@@ -181,7 +226,7 @@ def forward(

class BaiChuanModel(nn.Module):

-    def __init__(self, config: BaiChuanConfig):
+    def __init__(self, config: BaiChuanConfig, position_embedding: str):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
@@ -192,7 +237,7 @@ def __init__(self, config: BaiChuanConfig):
            config.hidden_size,
            perform_initialization=False)
        self.layers = nn.ModuleList([
-            BaiChuanDecoderLayer(config)
+            BaiChuanDecoderLayer(config, position_embedding)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -223,12 +268,12 @@ def forward(
        return hidden_states


-class BaiChuanForCausalLM(nn.Module):
+class BaiChuanBaseForCausalLM(nn.Module):

-    def __init__(self, config):
+    def __init__(self, config, position_embedding: str):
        super().__init__()
        self.config = config
-        self.model = BaiChuanModel(config)
+        self.model = BaiChuanModel(config, position_embedding)
        self.lm_head = ColumnParallelLinear(config.hidden_size,
                                            config.vocab_size,
                                            bias=False,
@@ -318,3 +363,15 @@ def load_weights(self,
                self._row_parallel_weights,
                tp_rank,
            )
+
+
+class BaichuanForCausalLM(BaiChuanBaseForCausalLM):  # baichuan 13b
+
+    def __init__(self, config):
+        super().__init__(config, "ALIBI")
+
+
+class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):  # baichuan 7b
+
+    def __init__(self, config):
+        super().__init__(config, "ROPE")
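The two subclasses mirror the differently cased `architectures` strings that the Baichuan 13B and 7B Hugging Face configs report, which is what routes 13B checkpoints to the ALiBi path and 7B checkpoints to RoPE. A hedged way to check this (the repo ids are assumptions, not part of this diff):

    from transformers import AutoConfig
    cfg = AutoConfig.from_pretrained("baichuan-inc/Baichuan-13B-Base",
                                     trust_remote_code=True)
    print(cfg.architectures)  # expected: ['BaichuanForCausalLM'] -> "ALIBI"
    cfg = AutoConfig.from_pretrained("baichuan-inc/baichuan-7B",
                                     trust_remote_code=True)
    print(cfg.architectures)  # expected: ['BaiChuanForCausalLM'] -> "ROPE"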