Commit 64f23c2

fix baichuan for different position embedding for 7b and 13b models (#643)

1 parent d4c7755 · commit 64f23c2

3 files changed: +76 additions, -17 deletions

vllm/model_executor/model_loader.py

Lines changed: 2 additions & 1 deletion

@@ -11,7 +11,8 @@

 # TODO(woosuk): Lazy-load the model classes.
 _MODEL_REGISTRY = {
-    "BaiChuanForCausalLM": BaiChuanForCausalLM,
+    "BaiChuanForCausalLM": BaiChuanForCausalLM,  # baichuan-7b
+    "BaichuanForCausalLM": BaichuanForCausalLM,  # baichuan-13b
     "BloomForCausalLM": BloomForCausalLM,
     "GPT2LMHeadModel": GPT2LMHeadModel,
     "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,

vllm/model_executor/models/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -1,4 +1,4 @@
-from vllm.model_executor.models.baichuan import BaiChuanForCausalLM
+from vllm.model_executor.models.baichuan import BaiChuanForCausalLM, BaichuanForCausalLM
 from vllm.model_executor.models.bloom import BloomForCausalLM
 from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
 from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
@@ -10,6 +10,7 @@

 __all__ = [
     "BaiChuanForCausalLM",
+    "BaichuanForCausalLM",
     "BloomForCausalLM",
     "GPT2LMHeadModel",
     "GPTBigCodeForCausalLM",

vllm/model_executor/models/baichuan.py

Lines changed: 72 additions & 15 deletions

@@ -22,6 +22,7 @@
 The input of the model is flattened to a 1D tensor of tokens. The model uses
 InputMetadata to extract the original 2D shape of the input.
 """
+import math
 from typing import Dict, List, Optional, Tuple

 import torch
@@ -31,7 +32,7 @@
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttentionWithRoPE, PagedAttentionWithALiBi
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
                                               load_tensor_parallel_weights)
@@ -44,6 +45,31 @@
 KVCache = Tuple[torch.Tensor, torch.Tensor]


+def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
+    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
+    base = torch.tensor(
+        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
+        dtype=torch.float32,
+    )
+    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
+    slopes = torch.pow(base, powers)
+
+    if closest_power_of_2 != total_num_heads:
+        extra_base = torch.tensor(
+            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
+            dtype=torch.float32,
+        )
+        num_remaining_heads = min(closest_power_of_2,
+                                  total_num_heads - closest_power_of_2)
+        extra_powers = torch.arange(start=1,
+                                    end=1 + 2 * num_remaining_heads,
+                                    step=2,
+                                    dtype=torch.int32)
+        slopes = torch.cat(
+            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
+    return slopes
+
+
 class BaiChuanMLP(nn.Module):

     def __init__(
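
The new helper follows the usual ALiBi slope schedule (essentially the same construction used in BLOOM-style ALiBi implementations): for a power-of-two head count n the slopes form the geometric sequence 2^(-8/n), 2^(-16/n), ..., and for other counts extra slopes are taken from the next power of two. A quick sanity-check sketch; the head counts below are illustrative, not read from a Baichuan config:

# Sanity check for the slope schedule added above.
import torch
from vllm.model_executor.models.baichuan import _get_alibi_slopes

slopes8 = _get_alibi_slopes(8)
# Power-of-two head count: slopes are 1/2, 1/4, ..., 1/256.
assert torch.allclose(slopes8, torch.tensor([2.0**-i for i in range(1, 9)]))

slopes12 = _get_alibi_slopes(12)
# Non-power-of-two count: the first 8 slopes match the n=8 case; the remaining
# 4 are odd powers of 2**-0.5 (the base for the next power of two, 16).
assert slopes12.shape == (12,)
assert torch.allclose(
    slopes12[8:],
    torch.tensor([2.0**-0.5, 2.0**-1.5, 2.0**-2.5, 2.0**-3.5]))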
@@ -82,6 +108,7 @@ def __init__(
         self,
         hidden_size: int,
         num_heads: int,
+        position_embedding: str,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -92,7 +119,7 @@ def __init__(
         self.num_heads = (self.total_num_heads //
                           tensor_model_parallel_world_size)
         self.head_dim = hidden_size // self.total_num_heads
-        self.scaling = self.head_dim**-0.5
+        self.postion_embedding = position_embedding

         # pylint: disable=invalid-name
         self.W_pack = ColumnParallelLinear(
@@ -109,11 +136,23 @@ def __init__(
             input_is_parallel=True,
             perform_initialization=False,
         )
-
-        self.attn = PagedAttentionWithRoPE(self.num_heads,
-                                           self.head_dim,
-                                           self.scaling,
-                                           rotary_dim=self.head_dim)
+        # Create the alibi slopes and slice them.
+        if self.postion_embedding == "ALIBI":
+            tp_rank = get_tensor_model_parallel_rank()
+            head_start = tp_rank * self.num_heads
+            head_end = (tp_rank + 1) * self.num_heads
+            alibi_slopes = _get_alibi_slopes(self.total_num_heads)
+            alibi_slopes = alibi_slopes[head_start:head_end].tolist()
+
+            scaling = self.head_dim**-0.5
+            self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
+                                                scaling, alibi_slopes)
+        else:
+            self.scaling = self.head_dim**-0.5
+            self.attn = PagedAttentionWithRoPE(self.num_heads,
+                                               self.head_dim,
+                                               self.scaling,
+                                               rotary_dim=self.head_dim)

     def forward(
         self,
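
In the ALiBi branch the slope table is built for all heads and then sliced so each tensor-parallel rank keeps only the slopes for its own shard of query heads. An illustrative walk-through of that arithmetic (the head count and world size are made-up numbers, not read from a Baichuan config):

# Illustrative slicing of per-head ALiBi slopes across tensor-parallel ranks.
total_num_heads = 40      # hypothetical total attention heads
tp_world_size = 2         # hypothetical tensor-parallel world size
num_heads = total_num_heads // tp_world_size  # 20 heads per rank

for tp_rank in range(tp_world_size):
    head_start = tp_rank * num_heads
    head_end = (tp_rank + 1) * num_heads
    # rank 0 keeps slopes[0:20], rank 1 keeps slopes[20:40]
    print(f"rank {tp_rank}: slopes[{head_start}:{head_end}]")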
@@ -126,20 +165,26 @@ def forward(
         qkv, _ = self.W_pack(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
-                                input_metadata, cache_event)
+        if self.postion_embedding == "ALIBI":
+            attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                    cache_event)
+        else:
+            attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
+                                    input_metadata, cache_event)
+
         output, _ = self.o_proj(attn_output)
         return output


 class BaiChuanDecoderLayer(nn.Module):

-    def __init__(self, config: BaiChuanConfig):
+    def __init__(self, config: BaiChuanConfig, position_embedding: str):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.self_attn = BaiChuanAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
+            position_embedding=position_embedding,
         )
         self.mlp = BaiChuanMLP(
             hidden_size=self.hidden_size,
@@ -181,7 +226,7 @@ def forward(

 class BaiChuanModel(nn.Module):

-    def __init__(self, config: BaiChuanConfig):
+    def __init__(self, config: BaiChuanConfig, position_embedding: str):
         super().__init__()
         self.config = config
         self.padding_idx = config.pad_token_id
@@ -192,7 +237,7 @@ def __init__(self, config: BaiChuanConfig):
                                            config.hidden_size,
                                            perform_initialization=False)
         self.layers = nn.ModuleList([
-            BaiChuanDecoderLayer(config)
+            BaiChuanDecoderLayer(config, position_embedding)
             for _ in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -223,12 +268,12 @@ def forward(
         return hidden_states


-class BaiChuanForCausalLM(nn.Module):
+class BaiChuanBaseForCausalLM(nn.Module):

-    def __init__(self, config):
+    def __init__(self, config, position_embedding: str):
         super().__init__()
         self.config = config
-        self.model = BaiChuanModel(config)
+        self.model = BaiChuanModel(config, position_embedding)
         self.lm_head = ColumnParallelLinear(config.hidden_size,
                                             config.vocab_size,
                                             bias=False,
@@ -318,3 +363,15 @@ def load_weights(self,
             self._row_parallel_weights,
             tp_rank,
         )
+
+
+class BaichuanForCausalLM(BaiChuanBaseForCausalLM):  # baichuan 13b
+
+    def __init__(self, config):
+        super().__init__(config, "ALIBI")
+
+
+class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):  # baichuan 7b
+
+    def __init__(self, config):
+        super().__init__(config, "ROPE")
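
With both subclasses registered, either model family can be served through vLLM's normal entry point; the class (and thus RoPE vs. ALiBi) is picked from the architecture name in the checkpoint's config. A minimal usage sketch, assuming a Baichuan checkpoint on the Hugging Face Hub (the repo id below is illustrative) and trust_remote_code for Baichuan's custom config/tokenizer code:

# Usage sketch: a 13B checkpoint resolves to BaichuanForCausalLM (ALiBi),
# while a 7B checkpoint resolves to BaiChuanForCausalLM (RoPE).
from vllm import LLM, SamplingParams

llm = LLM(model="baichuan-inc/Baichuan-13B-Chat",  # illustrative repo id
          trust_remote_code=True)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)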
