Skip to content

Commit bfdcfa6

Browse files
authored
Support starcoder2 architecture (#3089)
1 parent 9289e57 commit bfdcfa6

File tree

7 files changed

+452
-0
lines changed

7 files changed

+452
-0
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
7878
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
7979
- Qwen2 (`Qwen/Qwen2-7B-beta`, `Qwen/Qwen-7B-Chat-beta`, etc.)
8080
- StableLM (`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.)
81+
- Starcoder2 (`bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc.)
8182
- Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.)
8283

8384
Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):

tests/models/test_models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
"microsoft/phi-2",
2020
"stabilityai/stablelm-3b-4e1t",
2121
"allenai/OLMo-1B",
22+
"bigcode/starcoder2-3b",
2223
]
2324

2425

vllm/model_executor/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
4646
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
4747
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
48+
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
4849
}
4950

5051
# Models not supported by ROCm.
Lines changed: 310 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,310 @@
1+
# coding=utf-8
2+
# Copyright 2024 BigCode and the HuggingFace Inc. team. All rights reserved.
3+
#
4+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5+
# and OPT implementations in this library. It has been modified from its
6+
# original forms to accommodate minor architectural differences compared
7+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8+
#
9+
# Licensed under the Apache License, Version 2.0 (the "License");
10+
# you may not use this file except in compliance with the License.
11+
# You may obtain a copy of the License at
12+
#
13+
# http://www.apache.org/licenses/LICENSE-2.0
14+
#
15+
# Unless required by applicable law or agreed to in writing, software
16+
# distributed under the License is distributed on an "AS IS" BASIS,
17+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18+
# See the License for the specific language governing permissions and
19+
# limitations under the License.
20+
""" PyTorch Starcoder2 model."""
21+
from typing import List, Optional, Tuple
22+
23+
import torch
24+
from torch import nn
25+
26+
from vllm.model_executor.input_metadata import InputMetadata
27+
from vllm.model_executor.sampling_metadata import SamplingMetadata
28+
from vllm.model_executor.layers.attention import PagedAttention
29+
from vllm.model_executor.layers.activation import get_act_fn
30+
from vllm.model_executor.layers.rotary_embedding import get_rope
31+
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
32+
LinearMethodBase,
33+
QKVParallelLinear,
34+
RowParallelLinear)
35+
from vllm.model_executor.layers.sampler import Sampler
36+
from vllm.model_executor.layers.vocab_parallel_embedding import (
37+
VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
38+
from vllm.model_executor.parallel_utils.parallel_state import get_tensor_model_parallel_world_size
39+
from vllm.model_executor.weight_utils import (default_weight_loader,
40+
hf_model_weights_iterator)
41+
from vllm.sequence import SamplerOutput
42+
43+
try:
44+
from transformers import Starcoder2Config
45+
except ImportError:
46+
# fallback to PretrainedConfig
47+
# NOTE: Please install transformers from source or use transformers>=4.39.0
48+
from transformers import PretrainedConfig as Starcoder2Config
49+
50+
# A layer's KV cache: a pair of tensors, unpacked as (k_cache, v_cache) by
# Starcoder2Attention.forward.
KVCache = Tuple[torch.Tensor, torch.Tensor]
51+
52+
53+
class Starcoder2Attention(nn.Module):
    """Self-attention layer for Starcoder2 under tensor parallelism.

    Combines a fused QKV projection, rotary position embedding,
    ``PagedAttention`` and an output projection. Head counts are derived
    from the config and the tensor-parallel world size.
    """

    def __init__(self,
                 config: Starcoder2Config,
                 linear_method: Optional[LinearMethodBase] = None):
        """Build the projections, rotary embedding and attention op.

        Args:
            config: Model configuration; read for hidden size, head counts,
                rope theta, max positions, bias flag and sliding window.
            linear_method: Optional linear method forwarded to the
                parallel linear layers.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size

        # Query heads have to split evenly across the tensor-parallel ranks.
        world_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size

        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads < world_size:
            # Fewer KV heads than ranks: the KV heads are replicated across
            # multiple tensor parallel GPUs.
            assert world_size % self.total_num_kv_heads == 0
        else:
            # At least as many KV heads as ranks: the KV heads are
            # partitioned across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % world_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // world_size)

        head_dim = self.hidden_size // self.total_num_heads
        self.head_dim = head_dim
        self.q_size = self.num_heads * head_dim
        self.kv_size = self.num_kv_heads * head_dim
        self.scaling = head_dim**-0.5
        self.rope_theta = config.rope_theta
        self.max_position_embeddings = config.max_position_embeddings
        self.use_bias = config.use_bias
        self.sliding_window = config.sliding_window

        # Both projections share the same bias / linear-method settings.
        proj_options = dict(bias=self.use_bias, linear_method=linear_method)
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            **proj_options,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * head_dim,
            self.hidden_size,
            **proj_options,
        )
        # NOTE(review): int() drops any fractional part of rope_theta --
        # confirm whether get_rope could take the float value directly.
        self.rotary_emb = get_rope(
            head_dim,
            rotary_dim=head_dim,
            max_position=self.max_position_embeddings,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = PagedAttention(
            self.num_heads,
            head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            sliding_window=self.sliding_window,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the projection.

        The fused QKV output is split along its last dimension into query,
        key and value; rotary embedding is applied to query and key before
        the paged attention call.
        """
        fused_qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        key_cache, value_cache = kv_cache
        context = self.attn(query, key, value, key_cache, value_cache,
                            input_metadata)
        projected, _ = self.o_proj(context)
        return projected
128+
129+
130+
class Starcoder2MLP(nn.Module):
    """Feed-forward block: ``c_fc`` -> activation -> ``c_proj``."""

    def __init__(self,
                 config: Starcoder2Config,
                 linear_method: Optional[LinearMethodBase] = None):
        """Create the two parallel linear layers and the activation.

        Args:
            config: Model configuration; read for hidden size, intermediate
                size, bias flag and activation name.
            linear_method: Optional linear method forwarded to both
                linear layers.
        """
        super().__init__()
        hidden_dim = config.hidden_size
        inner_dim = config.intermediate_size
        with_bias = config.use_bias

        # Expand hidden -> intermediate, then project back down.
        self.c_fc = ColumnParallelLinear(
            hidden_dim,
            inner_dim,
            bias=with_bias,
            linear_method=linear_method,
        )
        self.c_proj = RowParallelLinear(
            inner_dim,
            hidden_dim,
            bias=with_bias,
            linear_method=linear_method,
        )
        self.act = get_act_fn(config.hidden_act,
                              intermediate_size=inner_dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply ``c_fc``, the activation, then ``c_proj``."""
        expanded, _ = self.c_fc(hidden_states)
        activated = self.act(expanded)
        projected, _ = self.c_proj(activated)
        return projected
156+
157+
158+
class Starcoder2DecoderLayer(nn.Module):
    """One decoder layer: layer-normed attention and MLP, each with a
    residual connection around it."""

    def __init__(self,
                 config: Starcoder2Config,
                 linear_method: Optional[LinearMethodBase] = None):
        """Build the attention, MLP and the two layer norms.

        Args:
            config: Model configuration; read for hidden size and the
                layer-norm epsilon.
            linear_method: Optional linear method forwarded to the
                attention and MLP sub-modules.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Starcoder2Attention(config,
                                             linear_method=linear_method)
        self.mlp = Starcoder2MLP(config, linear_method=linear_method)

        norm_eps = config.norm_epsilon
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Return ``hidden_states`` after the attention and MLP sub-blocks."""
        # Attention sub-block: normalise, attend, add back the input.
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=self.input_layernorm(hidden_states),
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )
        hidden_states = hidden_states + attn_out

        # MLP sub-block: normalise, transform, add back the input.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out
198+
199+
200+
class Starcoder2Model(nn.Module):
    """Token embedding, a stack of decoder layers, and a final layer norm."""

    def __init__(self,
                 config: Starcoder2Config,
                 linear_method: Optional[LinearMethodBase] = None):
        """Build the embedding, ``config.num_hidden_layers`` decoder layers
        and the output norm.

        Args:
            config: Model configuration.
            linear_method: Optional linear method forwarded to every
                decoder layer.
        """
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # TODO: consider padding_idx (currently removed)
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        decoder_layers = [
            Starcoder2DecoderLayer(config, linear_method=linear_method)
            for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every decoder layer, apply the norm.

        ``kv_caches`` is indexed by layer position, so it needs one entry
        per decoder layer.
        """
        hidden_states = self.embed_tokens(input_ids)
        for layer_idx, decoder_layer in enumerate(self.layers):
            hidden_states = decoder_layer(positions, hidden_states,
                                          kv_caches[layer_idx],
                                          input_metadata)
        return self.norm(hidden_states)
233+
234+
235+
class Starcoder2ForCausalLM(nn.Module):
    """Starcoder2 decoder with an LM head weight, a sampler and a
    checkpoint weight loader."""

    def __init__(self,
                 config: Starcoder2Config,
                 linear_method: Optional[LinearMethodBase] = None):
        """Build the backbone, pick the LM head weight and the sampler.

        When ``config.tie_word_embeddings`` is set, the token embedding
        weight is reused as the LM head weight; otherwise a separate
        ``ParallelLMHead`` is created.
        """
        super().__init__()
        self.config = config
        self.model = Starcoder2Model(config, linear_method=linear_method)
        self.vocab_size = config.vocab_size
        self.unpadded_vocab_size = config.vocab_size

        if not config.tie_word_embeddings:
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            )
            self.lm_head_weight = self.lm_head.weight
        else:
            self.lm_head_weight = self.model.embed_tokens.weight
        self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Return the backbone's hidden states for the given inputs."""
        return self.model(input_ids, positions, kv_caches, input_metadata)

    def sample(
        self,
        hidden_states: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler with the LM head weight and ``hidden_states``."""
        return self.sampler(self.lm_head_weight, hidden_states,
                            sampling_metadata)

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load checkpoint tensors into this module's parameters.

        Tensors come from ``hf_model_weights_iterator``. Names containing
        ``q_proj``/``k_proj``/``v_proj`` are loaded as shards of the fused
        ``qkv_proj`` parameter; ``rotary_emb.inv_freq`` entries are skipped,
        as is ``lm_head.weight`` when the word embeddings are tied.

        Raises:
            KeyError: If a checkpoint name has no matching parameter.
        """
        # (fused parameter name, checkpoint shard name, shard id)
        fused_shards = (
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        )
        tied_embeddings = self.config.tie_word_embeddings

        params_dict = dict(self.named_parameters(remove_duplicate=False))
        checkpoint_weights = hf_model_weights_iterator(
            model_name_or_path, cache_dir, load_format, revision)
        for name, loaded_weight in checkpoint_weights:
            if "rotary_emb.inv_freq" in name:
                continue

            # First shard entry whose checkpoint name occurs in ``name``.
            shard = next(
                (entry for entry in fused_shards if entry[1] in name), None)
            if shard is not None:
                fused_name, shard_name, shard_id = shard
                param = params_dict[name.replace(shard_name, fused_name)]
                param.weight_loader(param, loaded_weight, shard_id)
                continue

            if tied_embeddings and "lm_head.weight" in name:
                continue
            param = params_dict[name]
            load_fn = getattr(param, "weight_loader", default_weight_loader)
            load_fn(param, loaded_weight)

vllm/transformers_utils/config.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,23 @@
99
"mpt": MPTConfig,
1010
"RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct)
1111
"RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct)
12+
"starcoder2": Starcoder2Config,
1213
}
1314

1415

1516
def get_config(model: str,
1617
trust_remote_code: bool,
1718
revision: Optional[str] = None,
1819
code_revision: Optional[str] = None) -> PretrainedConfig:
20+
# FIXME(woosuk): This is a temporary fix for StarCoder2.
21+
# Remove this when the model is supported by HuggingFace transformers.
22+
if "bigcode" in model and "starcoder2" in model:
23+
config_class = _CONFIG_REGISTRY["starcoder2"]
24+
config = config_class.from_pretrained(model,
25+
revision=revision,
26+
code_revision=code_revision)
27+
return config
28+
1929
try:
2030
config = AutoConfig.from_pretrained(
2131
model,

vllm/transformers_utils/configs/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
55
# `FalconConfig` class from the official HuggingFace transformers library.
66
from vllm.transformers_utils.configs.falcon import RWConfig
7+
from vllm.transformers_utils.configs.starcoder2 import Starcoder2Config
78

89
__all__ = [
910
"ChatGLMConfig",
1011
"MPTConfig",
1112
"RWConfig",
13+
"Starcoder2Config",
1214
]

0 commit comments

Comments
 (0)