|
| 1 | +# coding=utf-8 |
| 2 | +# Adapted from |
| 3 | +# https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_phi.py |
| 4 | +# Copyright 2023 The vLLM team. |
| 5 | +# Copyright (c) Microsoft Corporation. |
| 6 | +# Licensed under the MIT license. |
| 7 | +# |
| 8 | +# BSD 3-Clause License |
| 9 | +# |
| 10 | +# Copyright (c) 2022, Tri Dao, [email protected]. |
| 11 | +# All rights reserved. |
| 12 | +# |
| 13 | +# Redistribution and use in source and binary forms, with or without |
| 14 | +# modification, are permitted provided that the following conditions are met: |
| 15 | +# |
| 16 | +# * Redistributions of source code must retain the above copyright notice, this |
| 17 | +# list of conditions and the following disclaimer. |
| 18 | +# |
| 19 | +# * Redistributions in binary form must reproduce the above copyright notice, |
| 20 | +# this list of conditions and the following disclaimer in the documentation |
| 21 | +# and/or other materials provided with the distribution. |
| 22 | +# |
| 23 | +# * Neither the name of the copyright holder nor the names of its |
| 24 | +# contributors may be used to endorse or promote products derived from |
| 25 | +# this software without specific prior written permission. |
| 26 | +# |
| 27 | +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" |
| 28 | +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE |
| 29 | +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE |
| 30 | +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE |
| 31 | +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL |
| 32 | +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR |
| 33 | +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER |
| 34 | +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, |
| 35 | +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
| 36 | +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
| 37 | +"""Inference-only Phi-1.5 model compatible with HuggingFace weights. |
| 38 | +
|
| 39 | +The input of the model is flattened to a 1D tensor of tokens. The model uses |
| 40 | +InputMetadata to extract the original 2D shape of the input. |
| 41 | +""" |
| 42 | +from typing import List, Optional, Tuple |
| 43 | + |
| 44 | +import torch |
| 45 | +from torch import nn |
| 46 | +from transformers import PretrainedConfig |
| 47 | + |
| 48 | +from vllm.model_executor.input_metadata import InputMetadata |
| 49 | +from vllm.model_executor.layers.activation import get_act_fn |
| 50 | +from vllm.model_executor.layers.attention import PagedAttentionWithRoPE |
| 51 | +from vllm.model_executor.layers.linear import (ColumnParallelLinear, |
| 52 | + LinearMethodBase, |
| 53 | + QKVParallelLinear, |
| 54 | + RowParallelLinear) |
| 55 | +from vllm.model_executor.layers.sampler import Sampler |
| 56 | +from vllm.model_executor.layers.vocab_parallel_embedding import ( |
| 57 | + VocabParallelEmbedding, ParallelLMHead) |
| 58 | +from vllm.model_executor.parallel_utils.parallel_state import ( |
| 59 | + get_tensor_model_parallel_world_size) |
| 60 | +from vllm.model_executor.weight_utils import (default_weight_loader, |
| 61 | + hf_model_weights_iterator) |
| 62 | +from vllm.sequence import SamplerOutput |
| 63 | + |
| 64 | +KVCache = Tuple[torch.Tensor, torch.Tensor] |
| 65 | + |
| 66 | + |
| 67 | +class PhiEmbedding(nn.Module): |
| 68 | + |
| 69 | + def __init__(self, config: PretrainedConfig): |
| 70 | + super().__init__() |
| 71 | + |
| 72 | + self.wte = VocabParallelEmbedding( |
| 73 | + config.vocab_size, |
| 74 | + config.hidden_size, |
| 75 | + ) |
| 76 | + |
| 77 | + def forward(self, input_ids: torch.LongTensor): |
| 78 | + return self.wte(input_ids) |
| 79 | + |
| 80 | + |
| 81 | +class PhiAttention(nn.Module): |
| 82 | + |
| 83 | + def __init__(self, |
| 84 | + config: PretrainedConfig, |
| 85 | + linear_method: Optional[LinearMethodBase] = None): |
| 86 | + super().__init__() |
| 87 | + self.total_num_heads = config.num_attention_heads |
| 88 | + self.hidden_size = config.hidden_size |
| 89 | + self.head_size = self.hidden_size // self.total_num_heads |
| 90 | + |
| 91 | + tensor_model_parallel_world_size = ( |
| 92 | + get_tensor_model_parallel_world_size()) |
| 93 | + assert self.total_num_heads % tensor_model_parallel_world_size == 0 |
| 94 | + self.num_heads = (self.total_num_heads // |
| 95 | + tensor_model_parallel_world_size) |
| 96 | + |
| 97 | + # pylint: disable=C0103 |
| 98 | + self.Wqkv = QKVParallelLinear( |
| 99 | + self.hidden_size, |
| 100 | + self.head_size, |
| 101 | + self.total_num_heads, |
| 102 | + linear_method=linear_method, |
| 103 | + ) |
| 104 | + self.qkv_proj = QKVParallelLinear( |
| 105 | + config.hidden_size, |
| 106 | + self.head_size, |
| 107 | + self.total_num_heads, |
| 108 | + bias=False, |
| 109 | + linear_method=linear_method, |
| 110 | + ) |
| 111 | + self.out_proj = RowParallelLinear( |
| 112 | + self.hidden_size, |
| 113 | + self.hidden_size, |
| 114 | + linear_method=linear_method, |
| 115 | + ) |
| 116 | + |
| 117 | + scaling = self.head_size**-0.5 |
| 118 | + rotary_dim = config.rotary_dim |
| 119 | + assert rotary_dim % 2 == 0 |
| 120 | + |
| 121 | + # pylint: disable=C0301 |
| 122 | + # Refer to: |
| 123 | + # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518 |
| 124 | + rope_theta = 10000 |
| 125 | + max_position_embeddings = getattr(config, "n_positions", 2048) |
| 126 | + self.attn = PagedAttentionWithRoPE( |
| 127 | + self.num_heads, |
| 128 | + self.head_size, |
| 129 | + scaling, |
| 130 | + rotary_dim, |
| 131 | + base=rope_theta, |
| 132 | + max_position=max_position_embeddings) |
| 133 | + |
| 134 | + def forward( |
| 135 | + self, |
| 136 | + position_ids: torch.Tensor, |
| 137 | + hidden_states: torch.Tensor, |
| 138 | + kv_cache: KVCache, |
| 139 | + input_metadata: InputMetadata, |
| 140 | + cache_event: Optional[torch.cuda.Event], |
| 141 | + ) -> torch.Tensor: |
| 142 | + qkv, _ = self.Wqkv(hidden_states) |
| 143 | + q, k, v = qkv.chunk(chunks=3, dim=-1) |
| 144 | + k_cache, v_cache = kv_cache |
| 145 | + attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache, |
| 146 | + input_metadata, cache_event) |
| 147 | + output, _ = self.out_proj(attn_output) |
| 148 | + return output |
| 149 | + |
| 150 | + |
| 151 | +class PhiMLP(nn.Module): |
| 152 | + |
| 153 | + def __init__(self, |
| 154 | + config: PretrainedConfig, |
| 155 | + linear_method: Optional[LinearMethodBase] = None): |
| 156 | + super().__init__() |
| 157 | + |
| 158 | + n_inner = getattr(config, "n_inner", None) |
| 159 | + n_inner = n_inner if n_inner is not None else 4 * config.hidden_size |
| 160 | + |
| 161 | + self.fc1 = ColumnParallelLinear( |
| 162 | + config.hidden_size, |
| 163 | + n_inner, |
| 164 | + linear_method=linear_method, |
| 165 | + ) |
| 166 | + self.fc2 = RowParallelLinear( |
| 167 | + n_inner, |
| 168 | + config.hidden_size, |
| 169 | + linear_method=linear_method, |
| 170 | + ) |
| 171 | + self.act = get_act_fn(config.activation_function) |
| 172 | + |
| 173 | + def forward(self, hidden_states): |
| 174 | + hidden_states, _ = self.fc1(hidden_states) |
| 175 | + hidden_states = self.act(hidden_states) |
| 176 | + hidden_states, _ = self.fc2(hidden_states) |
| 177 | + return hidden_states |
| 178 | + |
| 179 | + |
| 180 | +class PhiLayer(nn.Module): |
| 181 | + |
| 182 | + def __init__(self, |
| 183 | + config: PretrainedConfig, |
| 184 | + linear_method: Optional[LinearMethodBase] = None): |
| 185 | + super().__init__() |
| 186 | + self.ln = nn.LayerNorm(config.hidden_size, |
| 187 | + eps=config.layer_norm_epsilon) |
| 188 | + self.mixer = PhiAttention(config, linear_method) |
| 189 | + self.mlp = PhiMLP(config, linear_method) |
| 190 | + |
| 191 | + def forward( |
| 192 | + self, |
| 193 | + position_ids: torch.Tensor, |
| 194 | + hidden_states: torch.Tensor, |
| 195 | + kv_cache: KVCache, |
| 196 | + input_metadata: InputMetadata, |
| 197 | + cache_event: Optional[torch.cuda.Event], |
| 198 | + ) -> torch.Tensor: |
| 199 | + residual = hidden_states |
| 200 | + hidden_states = self.ln(hidden_states) |
| 201 | + attn_outputs = self.mixer( |
| 202 | + position_ids=position_ids, |
| 203 | + hidden_states=hidden_states, |
| 204 | + kv_cache=kv_cache, |
| 205 | + input_metadata=input_metadata, |
| 206 | + cache_event=cache_event, |
| 207 | + ) |
| 208 | + feed_forward_hidden_states = self.mlp(hidden_states) |
| 209 | + hidden_states = attn_outputs + feed_forward_hidden_states + residual |
| 210 | + return hidden_states |
| 211 | + |
| 212 | + |
| 213 | +class PhiCausalLMHead(nn.Module): |
| 214 | + |
| 215 | + def __init__(self, config: PretrainedConfig): |
| 216 | + super().__init__() |
| 217 | + self.ln = nn.LayerNorm(config.hidden_size, |
| 218 | + eps=config.layer_norm_epsilon) |
| 219 | + self.linear = ParallelLMHead(config.vocab_size, |
| 220 | + config.hidden_size, |
| 221 | + bias=True) |
| 222 | + self.sampler = Sampler(config.vocab_size) |
| 223 | + |
| 224 | + def forward( |
| 225 | + self, |
| 226 | + hidden_states: torch.Tensor, |
| 227 | + input_metadata: InputMetadata, |
| 228 | + ): |
| 229 | + hidden_states = self.ln(hidden_states) |
| 230 | + next_tokens = self.sampler(self.linear.weight, hidden_states, |
| 231 | + input_metadata, self.linear.bias) |
| 232 | + return next_tokens |
| 233 | + |
| 234 | + |
| 235 | +class PhiModel(nn.Module): |
| 236 | + |
| 237 | + def __init__(self, |
| 238 | + config: PretrainedConfig, |
| 239 | + linear_method: Optional[LinearMethodBase] = None): |
| 240 | + super().__init__() |
| 241 | + self.config = config |
| 242 | + self.linear_method = linear_method |
| 243 | + self.embd = PhiEmbedding(config) |
| 244 | + self.h = nn.ModuleList([ |
| 245 | + PhiLayer(config, linear_method) |
| 246 | + for _ in range(config.num_hidden_layers) |
| 247 | + ]) |
| 248 | + |
| 249 | + def forward( |
| 250 | + self, |
| 251 | + input_ids: torch.Tensor, |
| 252 | + positions: torch.Tensor, |
| 253 | + kv_caches: List[KVCache], |
| 254 | + input_metadata: InputMetadata, |
| 255 | + cache_events: Optional[List[torch.cuda.Event]], |
| 256 | + ) -> SamplerOutput: |
| 257 | + hidden_states = self.embd(input_ids) |
| 258 | + for i in range(self.config.num_hidden_layers): |
| 259 | + if cache_events is None: |
| 260 | + cache_event = None |
| 261 | + else: |
| 262 | + cache_event = cache_events[i] |
| 263 | + layer = self.h[i] |
| 264 | + hidden_states = layer( |
| 265 | + positions, |
| 266 | + hidden_states, |
| 267 | + kv_caches[i], |
| 268 | + input_metadata, |
| 269 | + cache_event, |
| 270 | + ) |
| 271 | + return hidden_states |
| 272 | + |
| 273 | + |
| 274 | +class PhiForCausalLM(nn.Module): |
| 275 | + |
| 276 | + def __init__(self, |
| 277 | + config: PretrainedConfig, |
| 278 | + linear_method: Optional[LinearMethodBase] = None): |
| 279 | + super().__init__() |
| 280 | + self.config = config |
| 281 | + self.linear_method = linear_method |
| 282 | + |
| 283 | + self.transformer = PhiModel(config, linear_method) |
| 284 | + self.lm_head = PhiCausalLMHead(config) |
| 285 | + |
| 286 | + def forward( |
| 287 | + self, |
| 288 | + input_ids: torch.Tensor, |
| 289 | + positions: torch.Tensor, |
| 290 | + kv_caches: List[KVCache], |
| 291 | + input_metadata: InputMetadata, |
| 292 | + cache_events: Optional[List[torch.cuda.Event]], |
| 293 | + ) -> SamplerOutput: |
| 294 | + hidden_states = self.transformer(input_ids, positions, kv_caches, |
| 295 | + input_metadata, cache_events) |
| 296 | + lm_logits = self.lm_head(hidden_states, input_metadata) |
| 297 | + return lm_logits |
| 298 | + |
| 299 | + def load_weights(self, |
| 300 | + model_name_or_path: str, |
| 301 | + cache_dir: Optional[str] = None, |
| 302 | + load_format: str = "auto", |
| 303 | + revision: Optional[str] = None): |
| 304 | + params_dict = dict(self.named_parameters()) |
| 305 | + for name, loaded_weight in hf_model_weights_iterator( |
| 306 | + model_name_or_path, cache_dir, load_format, revision): |
| 307 | + if "rotary_emb.inv_freq" in name: |
| 308 | + continue |
| 309 | + |
| 310 | + # pylint: disable=E1136 |
| 311 | + param = params_dict[name] |
| 312 | + weight_loader = getattr(param, "weight_loader", |
| 313 | + default_weight_loader) |
| 314 | + weight_loader(param, loaded_weight) |
0 commit comments