
Commit 298695b

GPTBigCode (StarCoder, SantaCoder Support) (#209)
1 parent 83658c8 commit 298695b

File tree: 4 files changed, +298 -2 lines changed

vllm/model_executor/layers/activation.py
Lines changed: 1 addition & 0 deletions

@@ -8,6 +8,7 @@
     "gelu": nn.GELU(),
     "gelu_new": nn.GELU(approximate="tanh"),  # NOTE: This may introduce small rounding errors.
     "gelu_fast": nn.GELU(approximate="tanh"),  # NOTE: This may introduce small rounding errors.
+    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),  # NOTE: This may introduce small rounding errors.
     "relu": nn.ReLU(),
 }
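
The new entry lets get_act_fn resolve "gelu_pytorch_tanh", the activation name that GPTBigCode checkpoints declare in their config (as far as we can tell, it is the default activation_function of transformers' GPTBigCodeConfig). A minimal sketch, not part of the commit, of why every tanh-mapped entry carries the rounding-error NOTE:

```python
# The tanh approximation used for "gelu_pytorch_tanh" (and "gelu_new"/"gelu_fast")
# is numerically close to, but not identical with, the exact erf-based GELU.
import torch
from torch import nn

x = torch.linspace(-4.0, 4.0, steps=101)
exact = nn.GELU()(x)                     # erf-based GELU
approx = nn.GELU(approximate="tanh")(x)  # what the registry returns here
print((exact - approx).abs().max())      # small but nonzero, hence the NOTEs
```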

vllm/model_executor/model_loader.py
Lines changed: 2 additions & 1 deletion

@@ -6,13 +6,14 @@
 from transformers import PretrainedConfig

 from vllm.config import ModelConfig
-from vllm.model_executor.models import (GPT2LMHeadModel, GPTNeoXForCausalLM,
+from vllm.model_executor.models import (GPT2LMHeadModel, GPTBigCodeForCausalLM, GPTNeoXForCausalLM,
                                         LlamaForCausalLM, OPTForCausalLM)
 from vllm.model_executor.weight_utils import initialize_dummy_weights

 # TODO(woosuk): Lazy-load the model classes.
 _MODEL_REGISTRY = {
     "GPT2LMHeadModel": GPT2LMHeadModel,
+    "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
     "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
     "LlamaForCausalLM": LlamaForCausalLM,
     "OPTForCausalLM": OPTForCausalLM,

vllm/model_executor/models/__init__.py
Lines changed: 4 additions & 1 deletion

@@ -1,11 +1,14 @@
-from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
 from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
+from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
+from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.model_executor.models.opt import OPTForCausalLM


+
 __all__ = [
     "GPT2LMHeadModel",
+    "GPTBigCodeForCausalLM",
     "GPTNeoXForCausalLM",
     "LlamaForCausalLM",
     "OPTForCausalLM",
vllm/model_executor/models/gpt_bigcode.py (new file)
Lines changed: 291 additions & 0 deletions

@@ -0,0 +1,291 @@
# coding=utf-8
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2023 CTranslate2, and Michael Feil
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights.

The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn
import numpy as np
from transformers import GPTBigCodeConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
                                              load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs

KVCache = Tuple[torch.Tensor, torch.Tensor]


class GPTBigCodeAttention(nn.Module):

    def __init__(self, config: GPTBigCodeConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        assert total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = self.hidden_size // total_num_heads
        self.scale = self.head_dim ** -0.5

        self.c_attn = ColumnParallelLinear(self.hidden_size, 3 * self.hidden_size,
                                           bias=True, gather_output=False,
                                           perform_initialization=False)
        self.c_proj = RowParallelLinear(self.hidden_size, self.hidden_size,
                                        bias=True, input_is_parallel=True,
                                        perform_initialization=False)
        self.attn = PagedAttention(self.num_heads, self.head_dim,
                                   scale=self.scale)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        qkv, _ = self.c_attn(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        key_cache, value_cache = kv_cache
        attn_output = self.attn(
            q, k, v, key_cache, value_cache, input_metadata, cache_event)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output


class GPTBigMLP(nn.Module):

    def __init__(
        self,
        intermediate_size: int,
        config: GPTBigCodeConfig,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.c_fc = ColumnParallelLinear(hidden_size, intermediate_size,
                                         bias=True, gather_output=False,
                                         perform_initialization=False)
        self.c_proj = RowParallelLinear(intermediate_size, hidden_size,
                                        bias=True, input_is_parallel=True,
                                        perform_initialization=False)
        self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states


class GPTBigCodeBlock(nn.Module):

    def __init__(self, config: GPTBigCodeConfig):
        super().__init__()
        hidden_size = config.hidden_size
        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = GPTBigCodeAttention(config)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = GPTBigMLP(inner_dim, config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        # residual connection
        hidden_states = attn_output + residual

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states
        return hidden_states


class GPTBigCodeModel(nn.Module):

    def __init__(self, config: GPTBigCodeConfig):
        super().__init__()
        self.config = config
        assert config.add_cross_attention == False

        self.embed_dim = config.hidden_size

        # Optimization: While the vocab size of GPT-2 is 50257, we extend it
        # to 50304 in order to make it divisible by 64.
        # This improves performance since GPUs are faster if the dimension
        # is divisible by 64. In addition, it allows us to shard the embedding
        # layer across 2, 4, 8, or more GPUs.
        vocab_size = ((config.vocab_size + 63) // 64) * 64
        self.wte = VocabParallelEmbedding(vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        self.h = nn.ModuleList(
            [GPTBigCodeBlock(config) for _ in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds

        for i in range(len(self.h)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.h[i]
            hidden_states = layer(
                hidden_states, kv_caches[i], input_metadata, cache_event)

        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class GPTBigCodeForCausalLM(nn.Module):

    def __init__(self, config: GPTBigCodeConfig):
        super().__init__()
        self.config = config
        self.transformer = GPTBigCodeModel(config)
        # TODO(zhuohan): create a new weight after implementing pipeline
        # parallelism
        self.lm_head_weight = self.transformer.wte.weight
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.transformer(
            input_ids, positions, kv_caches, input_metadata, cache_events)
        next_tokens = self.sampler(
            self.lm_head_weight, hidden_states, input_metadata)
        return next_tokens

    _column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
    _row_parallel_weights = ["c_proj.weight"]

    def load_weights(self, model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     use_np_cache: bool = False):
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()

        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, use_np_cache):
            if "lm_head.weight" in name:
                # GPT-2 ties the weights of the embedding layer and the final
                # linear layer.
                continue
            if ".attn.bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue

            param = state_dict[name]

            def _expand_mqa_mha(qkv_array, n_head, head_dim):
                """Manipulates along axis=0 from MQA to MHA.
                inputs: qkv_array.shape = ((n_heads + 2) * head_dim, hidden_dim)
                    with n_heads for q, then 1 for k, 1 for v, times head_dim
                return: qkv_array.shape = (3 * n_heads * head_dim, hidden_dim)

                TODO: this function is no longer needed once vllm supports MQA.
                """
                qkv_array = qkv_array.numpy()

                dims_q = n_head * head_dim
                q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim), axis=0)
                # q is fine, but k & v have not been replicated along the first axis.
                # As long as MQA is not natively supported, increase memory use and
                # replicate (head_dim, hidden_dim) to (n_heads * head_dim, hidden_dim).
                if k.ndim == 2 and v.ndim == 2:
                    replication = (n_head, 1)  # weights
                else:
                    replication = n_head  # biases
                # replicate k and v n_head times
                k, v = np.tile(k, replication), np.tile(v, replication)
                # concat q, k, v along the first axis:
                # (n_heads * head_dim, hidden_dim) to (3 * n_heads * head_dim, hidden_dim)
                qkv_array = np.concatenate((q, k, v), axis=0)
                return torch.from_numpy(qkv_array)

            # For the fused QKV linear layer, manually shard the weights.
            if "c_attn" in name:
                # GPT-2's fused QKV has the shape of [3 * num_heads * head_size, hidden_size].
                # When tensor parallelism is used, we shard the weights along the head dimension.
                total_num_heads = self.config.num_attention_heads
                hidden_size = self.config.hidden_size
                head_size = hidden_size // total_num_heads
                num_heads = total_num_heads // tensor_model_parallel_world_size
                head_start = tensor_model_parallel_rank * num_heads
                head_end = (tensor_model_parallel_rank + 1) * num_heads

                if name.endswith(".weight"):
                    loaded_weight = _expand_mqa_mha(loaded_weight, n_head=total_num_heads, head_dim=head_size)
                    loaded_weight = loaded_weight.view(3, total_num_heads, head_size, hidden_size)
                    loaded_weight = loaded_weight[:, head_start:head_end, :, :]
                    loaded_weight = loaded_weight.reshape(-1, hidden_size)
                elif name.endswith(".bias"):
                    loaded_weight = _expand_mqa_mha(loaded_weight, n_head=total_num_heads, head_dim=head_size)
                    loaded_weight = loaded_weight.view(3, total_num_heads, head_size)
                    loaded_weight = loaded_weight[:, head_start:head_end, :]
                    loaded_weight = loaded_weight.reshape(-1)
                else:
                    raise ValueError(f"Unexpected parameter name {name}")
            load_tensor_parallel_weights(param, loaded_weight, name,
                                         self._column_parallel_weights,
                                         self._row_parallel_weights,
                                         tensor_model_parallel_rank)
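
A toy-sized shape check, not part of the commit, may help clarify what _expand_mqa_mha does above: GPTBigCode stores a multi-query c_attn weight of shape ((n_heads + 2) * head_dim, hidden_dim), and the helper replicates the single shared key and value blocks so the result matches the (3 * n_heads * head_dim, hidden_dim) layout expected by the fused multi-head QKV projection.

```python
# Demonstrates the MQA -> MHA expansion on small tensors.
import numpy as np
import torch

n_head, head_dim, hidden_dim = 4, 8, 32

# MQA layout: n_head query blocks, then one shared key block and one shared value block.
mqa_weight = torch.randn((n_head + 2) * head_dim, hidden_dim)

dims_q = n_head * head_dim
q, k, v = np.split(mqa_weight.numpy(), (dims_q, dims_q + head_dim), axis=0)
k, v = np.tile(k, (n_head, 1)), np.tile(v, (n_head, 1))
mha_weight = torch.from_numpy(np.concatenate((q, k, v), axis=0))

assert mha_weight.shape == (3 * n_head * head_dim, hidden_dim)
# Every expanded key head is a copy of the single shared MQA key block.
assert torch.equal(mha_weight[dims_q:dims_q + head_dim],
                   mha_weight[dims_q + head_dim:dims_q + 2 * head_dim])
```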

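With GPTBigCodeForCausalLM registered, GPTBigCode checkpoints such as bigcode/santacoder or bigcode/starcoder should load through vLLM's normal entry point. A hedged usage sketch against the vLLM Python API of that era (LLM and SamplingParams):

```python
from vllm import LLM, SamplingParams

# Any checkpoint whose config lists "GPTBigCodeForCausalLM" should now resolve
# through _MODEL_REGISTRY; santacoder is used here only as a small example.
llm = LLM(model="bigcode/santacoder")
sampling_params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["def fibonacci(n):"], sampling_params)
print(outputs[0].outputs[0].text)
```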