|
| 1 | +# coding=utf-8 |
| 2 | +# Copyright 2024 BigCode and the HuggingFace Inc. team. All rights reserved. |
| 3 | +# |
| 4 | +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX |
| 5 | +# and OPT implementations in this library. It has been modified from its |
| 6 | +# original forms to accommodate minor architectural differences compared |
| 7 | +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. |
| 8 | +# |
| 9 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 10 | +# you may not use this file except in compliance with the License. |
| 11 | +# You may obtain a copy of the License at |
| 12 | +# |
| 13 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 14 | +# |
| 15 | +# Unless required by applicable law or agreed to in writing, software |
| 16 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 17 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 18 | +# See the License for the specific language governing permissions and |
| 19 | +# limitations under the License. |
| 20 | +""" PyTorch Starcoder2 model.""" |
| 21 | +from typing import List, Optional, Tuple |
| 22 | + |
| 23 | +import torch |
| 24 | +from torch import nn |
| 25 | + |
| 26 | +from vllm.model_executor.input_metadata import InputMetadata |
| 27 | +from vllm.model_executor.sampling_metadata import SamplingMetadata |
| 28 | +from vllm.model_executor.layers.attention import PagedAttention |
| 29 | +from vllm.model_executor.layers.activation import get_act_fn |
| 30 | +from vllm.model_executor.layers.rotary_embedding import get_rope |
| 31 | +from vllm.model_executor.layers.linear import (ColumnParallelLinear, |
| 32 | + LinearMethodBase, |
| 33 | + QKVParallelLinear, |
| 34 | + RowParallelLinear) |
| 35 | +from vllm.model_executor.layers.sampler import Sampler |
| 36 | +from vllm.model_executor.layers.vocab_parallel_embedding import ( |
| 37 | + VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE) |
| 38 | +from vllm.model_executor.parallel_utils.parallel_state import get_tensor_model_parallel_world_size |
| 39 | +from vllm.model_executor.weight_utils import (default_weight_loader, |
| 40 | + hf_model_weights_iterator) |
| 41 | +from vllm.sequence import SamplerOutput |
| 42 | + |
| 43 | +try: |
| 44 | + from transformers import Starcoder2Config |
| 45 | +except ImportError: |
| 46 | + # fallback to PretrainedConfig |
| 47 | + # NOTE: Please install transformers from source or use transformers>=4.39.0 |
| 48 | + from transformers import PretrainedConfig as Starcoder2Config |
| 49 | + |
| 50 | +KVCache = Tuple[torch.Tensor, torch.Tensor] |
| 51 | + |
| 52 | + |
| 53 | +class Starcoder2Attention(nn.Module): |
| 54 | + |
| 55 | + def __init__(self, |
| 56 | + config: Starcoder2Config, |
| 57 | + linear_method: Optional[LinearMethodBase] = None): |
| 58 | + super().__init__() |
| 59 | + self.config = config |
| 60 | + |
| 61 | + self.hidden_size = config.hidden_size |
| 62 | + tp_size = get_tensor_model_parallel_world_size() |
| 63 | + self.total_num_heads = config.num_attention_heads |
| 64 | + assert self.total_num_heads % tp_size == 0 |
| 65 | + self.num_heads = self.total_num_heads // tp_size |
| 66 | + self.total_num_kv_heads = config.num_key_value_heads |
| 67 | + if self.total_num_kv_heads >= tp_size: |
| 68 | + # Number of KV heads is greater than TP size, so we partition |
| 69 | + # the KV heads across multiple tensor parallel GPUs. |
| 70 | + assert self.total_num_kv_heads % tp_size == 0 |
| 71 | + else: |
| 72 | + # Number of KV heads is less than TP size, so we replicate |
| 73 | + # the KV heads across multiple tensor parallel GPUs. |
| 74 | + assert tp_size % self.total_num_kv_heads == 0 |
| 75 | + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) |
| 76 | + self.head_dim = self.hidden_size // self.total_num_heads |
| 77 | + self.q_size = self.num_heads * self.head_dim |
| 78 | + self.kv_size = self.num_kv_heads * self.head_dim |
| 79 | + self.scaling = self.head_dim**-0.5 |
| 80 | + self.rope_theta = config.rope_theta |
| 81 | + self.max_position_embeddings = config.max_position_embeddings |
| 82 | + self.use_bias = config.use_bias |
| 83 | + self.sliding_window = config.sliding_window |
| 84 | + |
| 85 | + self.qkv_proj = QKVParallelLinear( |
| 86 | + self.hidden_size, |
| 87 | + self.head_dim, |
| 88 | + self.total_num_heads, |
| 89 | + self.total_num_kv_heads, |
| 90 | + bias=self.use_bias, |
| 91 | + linear_method=linear_method, |
| 92 | + ) |
| 93 | + self.o_proj = RowParallelLinear( |
| 94 | + self.total_num_heads * self.head_dim, |
| 95 | + self.hidden_size, |
| 96 | + bias=self.use_bias, |
| 97 | + linear_method=linear_method, |
| 98 | + ) |
| 99 | + self.rotary_emb = get_rope( |
| 100 | + self.head_dim, |
| 101 | + rotary_dim=self.head_dim, |
| 102 | + max_position=self.max_position_embeddings, |
| 103 | + base=int(self.rope_theta), |
| 104 | + is_neox_style=True, |
| 105 | + ) |
| 106 | + self.attn = PagedAttention( |
| 107 | + self.num_heads, |
| 108 | + self.head_dim, |
| 109 | + self.scaling, |
| 110 | + num_kv_heads=self.num_kv_heads, |
| 111 | + sliding_window=self.sliding_window, |
| 112 | + ) |
| 113 | + |
| 114 | + def forward( |
| 115 | + self, |
| 116 | + positions: torch.Tensor, |
| 117 | + hidden_states: torch.Tensor, |
| 118 | + kv_cache: KVCache, |
| 119 | + input_metadata: InputMetadata, |
| 120 | + ) -> torch.Tensor: |
| 121 | + qkv, _ = self.qkv_proj(hidden_states) |
| 122 | + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) |
| 123 | + q, k = self.rotary_emb(positions, q, k) |
| 124 | + k_cache, v_cache = kv_cache |
| 125 | + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) |
| 126 | + output, _ = self.o_proj(attn_output) |
| 127 | + return output |
| 128 | + |
| 129 | + |
| 130 | +class Starcoder2MLP(nn.Module): |
| 131 | + |
| 132 | + def __init__(self, |
| 133 | + config: Starcoder2Config, |
| 134 | + linear_method: Optional[LinearMethodBase] = None): |
| 135 | + super().__init__() |
| 136 | + self.c_fc = ColumnParallelLinear( |
| 137 | + config.hidden_size, |
| 138 | + config.intermediate_size, |
| 139 | + bias=config.use_bias, |
| 140 | + linear_method=linear_method, |
| 141 | + ) |
| 142 | + self.c_proj = RowParallelLinear( |
| 143 | + config.intermediate_size, |
| 144 | + config.hidden_size, |
| 145 | + bias=config.use_bias, |
| 146 | + linear_method=linear_method, |
| 147 | + ) |
| 148 | + self.act = get_act_fn(config.hidden_act, |
| 149 | + intermediate_size=config.intermediate_size) |
| 150 | + |
| 151 | + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
| 152 | + hidden_states, _ = self.c_fc(hidden_states) |
| 153 | + hidden_states = self.act(hidden_states) |
| 154 | + hidden_states, _ = self.c_proj(hidden_states) |
| 155 | + return hidden_states |
| 156 | + |
| 157 | + |
| 158 | +class Starcoder2DecoderLayer(nn.Module): |
| 159 | + |
| 160 | + def __init__(self, |
| 161 | + config: Starcoder2Config, |
| 162 | + linear_method: Optional[LinearMethodBase] = None): |
| 163 | + super().__init__() |
| 164 | + self.hidden_size = config.hidden_size |
| 165 | + self.self_attn = Starcoder2Attention(config, |
| 166 | + linear_method=linear_method) |
| 167 | + self.mlp = Starcoder2MLP(config, linear_method=linear_method) |
| 168 | + self.input_layernorm = nn.LayerNorm(config.hidden_size, |
| 169 | + eps=config.norm_epsilon) |
| 170 | + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, |
| 171 | + eps=config.norm_epsilon) |
| 172 | + |
| 173 | + def forward( |
| 174 | + self, |
| 175 | + positions: torch.Tensor, |
| 176 | + hidden_states: torch.Tensor, |
| 177 | + kv_cache: KVCache, |
| 178 | + input_metadata: InputMetadata, |
| 179 | + ) -> torch.Tensor: |
| 180 | + # Self Attention |
| 181 | + residual = hidden_states |
| 182 | + hidden_states = self.input_layernorm(hidden_states) |
| 183 | + hidden_states = self.self_attn( |
| 184 | + positions=positions, |
| 185 | + hidden_states=hidden_states, |
| 186 | + kv_cache=kv_cache, |
| 187 | + input_metadata=input_metadata, |
| 188 | + ) |
| 189 | + hidden_states = residual + hidden_states |
| 190 | + |
| 191 | + # Fully Connected |
| 192 | + residual = hidden_states |
| 193 | + hidden_states = self.post_attention_layernorm(hidden_states) |
| 194 | + hidden_states = self.mlp(hidden_states) |
| 195 | + hidden_states = residual + hidden_states |
| 196 | + |
| 197 | + return hidden_states |
| 198 | + |
| 199 | + |
| 200 | +class Starcoder2Model(nn.Module): |
| 201 | + |
| 202 | + def __init__(self, |
| 203 | + config: Starcoder2Config, |
| 204 | + linear_method: Optional[LinearMethodBase] = None): |
| 205 | + super().__init__() |
| 206 | + self.config = config |
| 207 | + self.padding_idx = config.pad_token_id |
| 208 | + self.vocab_size = config.vocab_size |
| 209 | + |
| 210 | + # TODO: consider padding_idx (currently removed) |
| 211 | + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, |
| 212 | + config.hidden_size) |
| 213 | + self.layers = nn.ModuleList([ |
| 214 | + Starcoder2DecoderLayer(config, linear_method=linear_method) |
| 215 | + for _ in range(config.num_hidden_layers) |
| 216 | + ]) |
| 217 | + self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) |
| 218 | + |
| 219 | + def forward( |
| 220 | + self, |
| 221 | + input_ids: torch.Tensor, |
| 222 | + positions: torch.Tensor, |
| 223 | + kv_caches: List[KVCache], |
| 224 | + input_metadata: InputMetadata, |
| 225 | + ) -> torch.Tensor: |
| 226 | + hidden_states = self.embed_tokens(input_ids) |
| 227 | + for i in range(len(self.layers)): |
| 228 | + layer = self.layers[i] |
| 229 | + hidden_states = layer(positions, hidden_states, kv_caches[i], |
| 230 | + input_metadata) |
| 231 | + hidden_states = self.norm(hidden_states) |
| 232 | + return hidden_states |
| 233 | + |
| 234 | + |
| 235 | +class Starcoder2ForCausalLM(nn.Module): |
| 236 | + |
| 237 | + def __init__(self, |
| 238 | + config: Starcoder2Config, |
| 239 | + linear_method: Optional[LinearMethodBase] = None): |
| 240 | + super().__init__() |
| 241 | + self.config = config |
| 242 | + self.model = Starcoder2Model(config, linear_method=linear_method) |
| 243 | + self.vocab_size = config.vocab_size |
| 244 | + self.unpadded_vocab_size = config.vocab_size |
| 245 | + if config.tie_word_embeddings: |
| 246 | + self.lm_head_weight = self.model.embed_tokens.weight |
| 247 | + else: |
| 248 | + self.unpadded_vocab_size = config.vocab_size |
| 249 | + self.lm_head = ParallelLMHead( |
| 250 | + self.unpadded_vocab_size, |
| 251 | + config.hidden_size, |
| 252 | + org_num_embeddings=config.vocab_size, |
| 253 | + padding_size=DEFAULT_VOCAB_PADDING_SIZE, |
| 254 | + ) |
| 255 | + self.lm_head_weight = self.lm_head.weight |
| 256 | + self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size) |
| 257 | + |
| 258 | + def forward( |
| 259 | + self, |
| 260 | + input_ids: torch.Tensor, |
| 261 | + positions: torch.Tensor, |
| 262 | + kv_caches: List[KVCache], |
| 263 | + input_metadata: InputMetadata, |
| 264 | + ) -> torch.Tensor: |
| 265 | + hidden_states = self.model(input_ids, positions, kv_caches, |
| 266 | + input_metadata) |
| 267 | + return hidden_states |
| 268 | + |
| 269 | + def sample( |
| 270 | + self, |
| 271 | + hidden_states: Optional[torch.Tensor], |
| 272 | + sampling_metadata: SamplingMetadata, |
| 273 | + ) -> Optional[SamplerOutput]: |
| 274 | + next_tokens = self.sampler(self.lm_head_weight, hidden_states, |
| 275 | + sampling_metadata) |
| 276 | + return next_tokens |
| 277 | + |
| 278 | + def load_weights(self, |
| 279 | + model_name_or_path: str, |
| 280 | + cache_dir: Optional[str] = None, |
| 281 | + load_format: str = "auto", |
| 282 | + revision: Optional[str] = None): |
| 283 | + stacked_params_mapping = [ |
| 284 | + # (param_name, shard_name, shard_id) |
| 285 | + ("qkv_proj", "q_proj", "q"), |
| 286 | + ("qkv_proj", "k_proj", "k"), |
| 287 | + ("qkv_proj", "v_proj", "v"), |
| 288 | + ] |
| 289 | + |
| 290 | + params_dict = dict(self.named_parameters(remove_duplicate=False)) |
| 291 | + for name, loaded_weight in hf_model_weights_iterator( |
| 292 | + model_name_or_path, cache_dir, load_format, revision): |
| 293 | + if "rotary_emb.inv_freq" in name: |
| 294 | + continue |
| 295 | + |
| 296 | + for (param_name, weight_name, shard_id) in stacked_params_mapping: |
| 297 | + if weight_name not in name: |
| 298 | + continue |
| 299 | + name = name.replace(weight_name, param_name) |
| 300 | + param = params_dict[name] |
| 301 | + weight_loader = param.weight_loader |
| 302 | + weight_loader(param, loaded_weight, shard_id) |
| 303 | + break |
| 304 | + else: |
| 305 | + if self.config.tie_word_embeddings and "lm_head.weight" in name: |
| 306 | + continue |
| 307 | + param = params_dict[name] |
| 308 | + weight_loader = getattr(param, "weight_loader", |
| 309 | + default_weight_loader) |
| 310 | + weight_loader(param, loaded_weight) |
0 commit comments