|
| 1 | +# coding=utf-8 |
| 2 | +# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py |
| 3 | +# Copyright 2023 The vLLM team. |
| 4 | +# Copyright 2023 CTranslate2, and Michael Feil |
| 5 | +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. |
| 6 | +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. |
| 7 | +# |
| 8 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 9 | +# you may not use this file except in compliance with the License. |
| 10 | +# You may obtain a copy of the License at |
| 11 | +# |
| 12 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 13 | +# |
| 14 | +# Unless required by applicable law or agreed to in writing, software |
| 15 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 16 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 17 | +# See the License for the specific language governing permissions and |
| 18 | +# limitations under the License. |
| 19 | +"""Inference-only GPTBigCode model compatible with HuggingFace weights. |
| 20 | +
|
| 21 | +The input of the model is flattened to a 1D tensor of tokens. The model uses |
| 22 | +InputMetadata to extract the original 2D shape of the input. |
| 23 | +""" |
| 24 | +from typing import Dict, List, Optional, Tuple |
| 25 | + |
| 26 | +import torch |
| 27 | +from torch import nn |
| 28 | +import numpy as np |
| 29 | +from transformers import GPTBigCodeConfig |
| 30 | + |
| 31 | +from vllm.model_executor.input_metadata import InputMetadata |
| 32 | +from vllm.model_executor.layers.activation import get_act_fn |
| 33 | +from vllm.model_executor.layers.attention import PagedAttention |
| 34 | +from vllm.model_executor.layers.sampler import Sampler |
| 35 | +from vllm.model_executor.weight_utils import (hf_model_weights_iterator, |
| 36 | + load_tensor_parallel_weights) |
| 37 | +from vllm.model_executor.parallel_utils.parallel_state import ( |
| 38 | + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) |
| 39 | +from vllm.model_executor.parallel_utils.tensor_parallel import ( |
| 40 | + VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) |
| 41 | +from vllm.sequence import SequenceOutputs |
| 42 | + |
| 43 | +KVCache = Tuple[torch.Tensor, torch.Tensor] |
| 44 | + |
| 45 | + |
| 46 | +class GPTBigCodeAttention(nn.Module): |
| 47 | + |
| 48 | + def __init__(self, config: GPTBigCodeConfig): |
| 49 | + super().__init__() |
| 50 | + self.hidden_size = config.hidden_size |
| 51 | + total_num_heads = config.num_attention_heads |
| 52 | + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() |
| 53 | + assert total_num_heads % tensor_model_parallel_world_size == 0 |
| 54 | + self.num_heads = total_num_heads // tensor_model_parallel_world_size |
| 55 | + self.head_dim = self.hidden_size // total_num_heads |
| 56 | + self.scale = self.head_dim ** -0.5 |
| 57 | + |
| 58 | + self.c_attn = ColumnParallelLinear(self.hidden_size, 3 * self.hidden_size, |
| 59 | + bias=True, gather_output=False, |
| 60 | + perform_initialization=False) |
| 61 | + self.c_proj = RowParallelLinear(self.hidden_size, self.hidden_size, |
| 62 | + bias=True, input_is_parallel=True, |
| 63 | + perform_initialization=False) |
| 64 | + self.attn = PagedAttention(self.num_heads, self.head_dim, |
| 65 | + scale=self.scale) |
| 66 | + |
| 67 | + def forward( |
| 68 | + self, |
| 69 | + hidden_states: torch.Tensor, |
| 70 | + kv_cache: KVCache, |
| 71 | + input_metadata: InputMetadata, |
| 72 | + cache_event: Optional[torch.cuda.Event], |
| 73 | + ) -> torch.Tensor: |
| 74 | + qkv, _ = self.c_attn(hidden_states) |
| 75 | + q, k, v = qkv.chunk(chunks=3, dim=-1) |
| 76 | + key_cache, value_cache = kv_cache |
| 77 | + attn_output = self.attn( |
| 78 | + q, k, v, key_cache, value_cache, input_metadata, cache_event) |
| 79 | + attn_output, _ = self.c_proj(attn_output) |
| 80 | + return attn_output |
| 81 | + |
| 82 | + |
| 83 | +class GPTBigMLP(nn.Module): |
| 84 | + |
| 85 | + def __init__( |
| 86 | + self, |
| 87 | + intermediate_size: int, |
| 88 | + config: GPTBigCodeConfig, |
| 89 | + ): |
| 90 | + super().__init__() |
| 91 | + hidden_size = config.hidden_size |
| 92 | + self.c_fc = ColumnParallelLinear(hidden_size, intermediate_size, |
| 93 | + bias=True, gather_output=False, |
| 94 | + perform_initialization=False) |
| 95 | + self.c_proj = RowParallelLinear(intermediate_size, hidden_size, |
| 96 | + bias=True, input_is_parallel=True, |
| 97 | + perform_initialization=False) |
| 98 | + self.act = get_act_fn(config.activation_function) |
| 99 | + |
| 100 | + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
| 101 | + hidden_states, _ = self.c_fc(hidden_states) |
| 102 | + hidden_states = self.act(hidden_states) |
| 103 | + hidden_states, _ = self.c_proj(hidden_states) |
| 104 | + return hidden_states |
| 105 | + |
| 106 | + |
| 107 | +class GPTBigCodeBlock(nn.Module): |
| 108 | + |
| 109 | + def __init__(self, config: GPTBigCodeConfig): |
| 110 | + super().__init__() |
| 111 | + hidden_size = config.hidden_size |
| 112 | + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size |
| 113 | + |
| 114 | + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) |
| 115 | + self.attn = GPTBigCodeAttention(config) |
| 116 | + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) |
| 117 | + self.mlp = GPTBigMLP(inner_dim, config) |
| 118 | + |
| 119 | + def forward( |
| 120 | + self, |
| 121 | + hidden_states: torch.Tensor, |
| 122 | + kv_cache: KVCache, |
| 123 | + input_metadata: InputMetadata, |
| 124 | + cache_event: Optional[torch.cuda.Event], |
| 125 | + ) -> torch.Tensor: |
| 126 | + residual = hidden_states |
| 127 | + hidden_states = self.ln_1(hidden_states) |
| 128 | + attn_output = self.attn( |
| 129 | + hidden_states=hidden_states, |
| 130 | + kv_cache=kv_cache, |
| 131 | + input_metadata=input_metadata, |
| 132 | + cache_event=cache_event, |
| 133 | + ) |
| 134 | + # residual connection |
| 135 | + hidden_states = attn_output + residual |
| 136 | + |
| 137 | + residual = hidden_states |
| 138 | + hidden_states = self.ln_2(hidden_states) |
| 139 | + feed_forward_hidden_states = self.mlp(hidden_states) |
| 140 | + # residual connection |
| 141 | + hidden_states = residual + feed_forward_hidden_states |
| 142 | + return hidden_states |
| 143 | + |
| 144 | + |
| 145 | +class GPTBigCodeModel(nn.Module): |
| 146 | + |
| 147 | + def __init__(self, config: GPTBigCodeConfig): |
| 148 | + super().__init__() |
| 149 | + self.config = config |
| 150 | + assert config.add_cross_attention == False |
| 151 | + |
| 152 | + self.embed_dim = config.hidden_size |
| 153 | + |
| 154 | + # Optimization: While the vocab size of GPT-2 is 50257, we extend it |
| 155 | + # to 50304 in order to make it divisible by 64. |
| 156 | + # This improves performance since GPUs are faster if the dimension |
| 157 | + # is divisible by 64. In addition, it allows us to shard the embedding |
| 158 | + # layer across 2, 4, 8, or more GPUs. |
| 159 | + vocab_size = ((config.vocab_size + 63) // 64) * 64 |
| 160 | + self.wte = VocabParallelEmbedding(vocab_size, self.embed_dim) |
| 161 | + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) |
| 162 | + self.h = nn.ModuleList( |
| 163 | + [GPTBigCodeBlock(config) for _ in range(config.num_hidden_layers)]) |
| 164 | + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) |
| 165 | + |
| 166 | + def forward( |
| 167 | + self, |
| 168 | + input_ids: torch.Tensor, |
| 169 | + position_ids: torch.Tensor, |
| 170 | + kv_caches: List[KVCache], |
| 171 | + input_metadata: InputMetadata, |
| 172 | + cache_events: Optional[List[torch.cuda.Event]], |
| 173 | + ) -> torch.Tensor: |
| 174 | + inputs_embeds = self.wte(input_ids) |
| 175 | + position_embeds = self.wpe(position_ids) |
| 176 | + hidden_states = inputs_embeds + position_embeds |
| 177 | + |
| 178 | + for i in range(len(self.h)): |
| 179 | + if cache_events is None: |
| 180 | + cache_event = None |
| 181 | + else: |
| 182 | + cache_event = cache_events[i] |
| 183 | + layer = self.h[i] |
| 184 | + hidden_states = layer( |
| 185 | + hidden_states, kv_caches[i], input_metadata, cache_event) |
| 186 | + |
| 187 | + hidden_states = self.ln_f(hidden_states) |
| 188 | + return hidden_states |
| 189 | + |
| 190 | + |
| 191 | +class GPTBigCodeForCausalLM(nn.Module): |
| 192 | + |
| 193 | + def __init__(self, config: GPTBigCodeConfig): |
| 194 | + super().__init__() |
| 195 | + self.config = config |
| 196 | + self.transformer = GPTBigCodeModel(config) |
| 197 | + # TODO(zhuohan): create a new weight after implementing pipeline |
| 198 | + # parallelism |
| 199 | + self.lm_head_weight = self.transformer.wte.weight |
| 200 | + self.sampler = Sampler(config.vocab_size) |
| 201 | + |
| 202 | + def forward( |
| 203 | + self, |
| 204 | + input_ids: torch.Tensor, |
| 205 | + positions: torch.Tensor, |
| 206 | + kv_caches: List[KVCache], |
| 207 | + input_metadata: InputMetadata, |
| 208 | + cache_events: Optional[List[torch.cuda.Event]], |
| 209 | + ) -> Dict[int, SequenceOutputs]: |
| 210 | + hidden_states = self.transformer( |
| 211 | + input_ids, positions, kv_caches, input_metadata, cache_events) |
| 212 | + next_tokens = self.sampler( |
| 213 | + self.lm_head_weight, hidden_states, input_metadata) |
| 214 | + return next_tokens |
| 215 | + |
| 216 | + _column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"] |
| 217 | + _row_parallel_weights = ["c_proj.weight"] |
| 218 | + |
| 219 | + def load_weights(self, model_name_or_path: str, |
| 220 | + cache_dir: Optional[str] = None, |
| 221 | + use_np_cache: bool = False): |
| 222 | + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() |
| 223 | + tensor_model_parallel_rank = get_tensor_model_parallel_rank() |
| 224 | + state_dict = self.state_dict() |
| 225 | + |
| 226 | + for name, loaded_weight in hf_model_weights_iterator( |
| 227 | + model_name_or_path, cache_dir, use_np_cache): |
| 228 | + if "lm_head.weight" in name: |
| 229 | + # GPT-2 ties the weights of the embedding layer and the final |
| 230 | + # linear layer. |
| 231 | + continue |
| 232 | + if ".attn.bias" in name: |
| 233 | + # Skip attention mask. |
| 234 | + # NOTE: "c_attn.bias" should not be skipped. |
| 235 | + continue |
| 236 | + |
| 237 | + param = state_dict[name] |
| 238 | + |
| 239 | + def _expand_mqa_mha(qkv_array, n_head, head_dim): |
| 240 | + """manipulates along axis=0 from MQA to MHA |
| 241 | + inputs: qkv_array.shape=((n_heads + 2) * head_dim, hidden_dim) |
| 242 | + with n_heads for q, then 1 for k, 1 for 1 v, times head dim |
| 243 | + return: qkv_array.shape=(3 * n_heads * head_dim, hidden_dim) |
| 244 | + |
| 245 | + TODO: this function is no longer needed once vllm supports MQA. |
| 246 | + """ |
| 247 | + qkv_array = qkv_array.numpy() |
| 248 | + |
| 249 | + dims_q = n_head * head_dim |
| 250 | + q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim), axis=0) |
| 251 | + # q is fine, but k & v have not replicated shape along the first axis |
| 252 | + # as long as MQA is not nativly supported, increase memory and replicated |
| 253 | + # (head_dim, hidden_dim) to (n_heads * head_dim, hidden_dim) |
| 254 | + if k.ndim == 2 and v.ndim == 2: |
| 255 | + replication = (n_head, 1) # weights |
| 256 | + else: |
| 257 | + replication = n_head # biases |
| 258 | + # replicate n_head times for q, v |
| 259 | + k, v = np.tile(k, replication), np.tile(v, replication) |
| 260 | + # concat q, k, v along the first axis (n_heads * head_dim, hidden_dim) |
| 261 | + # to (3 * n_heads * head_dim, hidden_dim) |
| 262 | + qkv_array = np.concatenate((q, k, v), axis=0) |
| 263 | + return torch.from_numpy(qkv_array) |
| 264 | + |
| 265 | + # For the fused QKV linear layer, manually shard the weights. |
| 266 | + if "c_attn" in name: |
| 267 | + # GPT-2's fused QKV has the shape of [3 * num_heads * head_size, hidden_size]. |
| 268 | + # When tensor parallelism is used, we shard the weights along the head dimension. |
| 269 | + total_num_heads = self.config.num_attention_heads |
| 270 | + hidden_size = self.config.hidden_size |
| 271 | + head_size = hidden_size // total_num_heads |
| 272 | + num_heads = total_num_heads // tensor_model_parallel_world_size |
| 273 | + head_start = tensor_model_parallel_rank * num_heads |
| 274 | + head_end = (tensor_model_parallel_rank + 1) * num_heads |
| 275 | + |
| 276 | + if name.endswith(".weight"): |
| 277 | + loaded_weight = _expand_mqa_mha(loaded_weight, n_head=total_num_heads, head_dim=head_size) |
| 278 | + loaded_weight = loaded_weight.view(3, total_num_heads, head_size, hidden_size) |
| 279 | + loaded_weight = loaded_weight[:, head_start:head_end, :, :] |
| 280 | + loaded_weight = loaded_weight.reshape(-1, hidden_size) |
| 281 | + elif name.endswith(".bias"): |
| 282 | + loaded_weight = _expand_mqa_mha(loaded_weight, n_head=total_num_heads, head_dim=head_size) |
| 283 | + loaded_weight = loaded_weight.view(3, total_num_heads, head_size) |
| 284 | + loaded_weight = loaded_weight[:, head_start:head_end, :] |
| 285 | + loaded_weight = loaded_weight.reshape(-1) |
| 286 | + else: |
| 287 | + raise ValueError(f"Unexpected parameter name {name}") |
| 288 | + load_tensor_parallel_weights(param, loaded_weight, name, |
| 289 | + self._column_parallel_weights, |
| 290 | + self._row_parallel_weights, |
| 291 | + tensor_model_parallel_rank) |
0 commit comments