|
| 1 | +# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main |
| 2 | +import math |
| 3 | +from typing import Dict, List, Optional, Tuple |
| 4 | + |
| 5 | +import torch |
| 6 | +import torch.nn as nn |
| 7 | + |
| 8 | +from vllm.model_executor.input_metadata import InputMetadata |
| 9 | +from vllm.model_executor.layers.activation import get_act_fn |
| 10 | +from vllm.model_executor.layers.attention import PagedAttentionWithALiBi |
| 11 | +from vllm.model_executor.layers.sampler import Sampler |
| 12 | +from vllm.model_executor.weight_utils import (hf_model_weights_iterator, |
| 13 | + load_tensor_parallel_weights) |
| 14 | +from vllm.model_executor.parallel_utils.parallel_state import ( |
| 15 | + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) |
| 16 | +from vllm.model_executor.parallel_utils.tensor_parallel import ( |
| 17 | + VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) |
| 18 | +from vllm.sequence import SequenceOutputs |
| 19 | +from vllm.transformers_utils.configs.mpt import MPTConfig |
| 20 | + |
| 21 | +KVCache = Tuple[torch.Tensor, torch.Tensor] |
| 22 | + |
| 23 | + |
| 24 | +def _get_alibi_slopes( |
| 25 | + total_num_heads: int, |
| 26 | + alibi_bias_max: int, |
| 27 | +) -> torch.Tensor: |
| 28 | + next_power_of_2 = 2**math.ceil(math.log2(total_num_heads)) |
| 29 | + m = torch.arange(1, next_power_of_2 + 1, dtype=torch.float32) |
| 30 | + m = m.mul(alibi_bias_max / next_power_of_2) |
| 31 | + slopes = 1.0 / torch.pow(2, m) |
| 32 | + if next_power_of_2 != total_num_heads: |
| 33 | + slopes = torch.concat([slopes[1::2], slopes[::2]])[:total_num_heads] |
| 34 | + return slopes |
| 35 | + |
| 36 | + |
| 37 | +class MPTAttention(nn.Module): |
| 38 | + |
| 39 | + def __init__(self, config: MPTConfig): |
| 40 | + super().__init__() |
| 41 | + self.d_model = config.d_model |
| 42 | + self.total_num_heads = config.n_heads |
| 43 | + self.clip_qkv = config.attn_config["clip_qkv"] |
| 44 | + self.qk_ln = config.attn_config["qk_ln"] |
| 45 | + self.alibi_bias_max = config.attn_config["alibi_bias_max"] |
| 46 | + assert not config.attn_config["prefix_lm"] |
| 47 | + assert config.attn_config["alibi"] |
| 48 | + |
| 49 | + self.qkv_proj = ColumnParallelLinear( |
| 50 | + self.d_model, |
| 51 | + 3 * self.d_model, |
| 52 | + bias=not config.no_bias, |
| 53 | + gather_output=False, |
| 54 | + perform_initialization=False, |
| 55 | + ) |
| 56 | + if self.qk_ln: |
| 57 | + self.q_ln = nn.LayerNorm(self.d_model) |
| 58 | + self.k_ln = nn.LayerNorm(self.d_model) |
| 59 | + self.out_proj = RowParallelLinear( |
| 60 | + self.d_model, |
| 61 | + self.d_model, |
| 62 | + bias=not config.no_bias, |
| 63 | + input_is_parallel=True, |
| 64 | + perform_initialization=False, |
| 65 | + ) |
| 66 | + |
| 67 | + tp_world_size = get_tensor_model_parallel_world_size() |
| 68 | + assert self.total_num_heads % tp_world_size == 0 |
| 69 | + self.num_heads = self.total_num_heads // tp_world_size |
| 70 | + |
| 71 | + # Create the alibi slopes and slice them. |
| 72 | + tp_rank = get_tensor_model_parallel_rank() |
| 73 | + head_start = tp_rank * self.num_heads |
| 74 | + head_end = (tp_rank + 1) * self.num_heads |
| 75 | + alibi_slopes = _get_alibi_slopes(self.total_num_heads, |
| 76 | + self.alibi_bias_max) |
| 77 | + alibi_slopes = alibi_slopes[head_start:head_end].tolist() |
| 78 | + |
| 79 | + self.head_dim = self.d_model // self.total_num_heads |
| 80 | + scaling = self.head_dim**-0.5 |
| 81 | + self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim, |
| 82 | + scaling, alibi_slopes) |
| 83 | + |
| 84 | + def forward( |
| 85 | + self, |
| 86 | + position_ids: torch.Tensor, |
| 87 | + hidden_states: torch.Tensor, |
| 88 | + kv_cache: KVCache, |
| 89 | + input_metadata: InputMetadata, |
| 90 | + cache_event: Optional[torch.cuda.Event], |
| 91 | + ) -> torch.Tensor: |
| 92 | + del position_ids # unused. |
| 93 | + qkv, _ = self.qkv_proj(hidden_states) |
| 94 | + if self.clip_qkv is not None: |
| 95 | + qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) |
| 96 | + q, k, v = qkv.chunk(chunks=3, dim=-1) |
| 97 | + if self.qk_ln: |
| 98 | + q = self.q_ln(q) |
| 99 | + k = self.k_ln(k) |
| 100 | + k_cache, v_cache = kv_cache |
| 101 | + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata, |
| 102 | + cache_event) |
| 103 | + output, _ = self.out_proj(attn_output) |
| 104 | + return output |
| 105 | + |
| 106 | + |
| 107 | +class MPTMLP(nn.Module): |
| 108 | + |
| 109 | + def __init__(self, config: MPTConfig): |
| 110 | + super().__init__() |
| 111 | + hidden_size = config.d_model |
| 112 | + expansion_ratio = config.expansion_ratio |
| 113 | + intermediate_size = expansion_ratio * hidden_size |
| 114 | + self.up_proj = ColumnParallelLinear(hidden_size, |
| 115 | + intermediate_size, |
| 116 | + bias=not config.no_bias, |
| 117 | + gather_output=False, |
| 118 | + perform_initialization=False) |
| 119 | + self.act = get_act_fn("gelu") |
| 120 | + self.down_proj = RowParallelLinear(intermediate_size, |
| 121 | + hidden_size, |
| 122 | + bias=not config.no_bias, |
| 123 | + input_is_parallel=True, |
| 124 | + perform_initialization=False) |
| 125 | + |
| 126 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 127 | + x, _ = self.up_proj(x) |
| 128 | + x = self.act(x) |
| 129 | + x, _ = self.down_proj(x) |
| 130 | + return x |
| 131 | + |
| 132 | + |
| 133 | +class MPTBlock(nn.Module): |
| 134 | + |
| 135 | + def __init__(self, config: MPTConfig): |
| 136 | + super().__init__() |
| 137 | + hidden_size = config.d_model |
| 138 | + self.norm_1 = nn.LayerNorm(hidden_size) |
| 139 | + self.attn = MPTAttention(config) |
| 140 | + self.norm_2 = nn.LayerNorm(hidden_size) |
| 141 | + self.ffn = MPTMLP(config) |
| 142 | + |
| 143 | + def forward( |
| 144 | + self, |
| 145 | + position_ids: torch.Tensor, |
| 146 | + hidden_states: torch.Tensor, |
| 147 | + kv_cache: KVCache, |
| 148 | + input_metadata: InputMetadata, |
| 149 | + cache_event: Optional[torch.cuda.Event], |
| 150 | + ) -> torch.Tensor: |
| 151 | + x = self.norm_1(hidden_states) |
| 152 | + x = self.attn( |
| 153 | + position_ids=position_ids, |
| 154 | + hidden_states=x, |
| 155 | + kv_cache=kv_cache, |
| 156 | + input_metadata=input_metadata, |
| 157 | + cache_event=cache_event, |
| 158 | + ) |
| 159 | + hidden_states = hidden_states + x |
| 160 | + x = self.norm_2(hidden_states) |
| 161 | + x = self.ffn(x) |
| 162 | + hidden_states = hidden_states + x |
| 163 | + return hidden_states |
| 164 | + |
| 165 | + |
| 166 | +class MPTModel(nn.Module): |
| 167 | + |
| 168 | + def __init__(self, config: MPTConfig): |
| 169 | + super().__init__() |
| 170 | + assert config.embedding_fraction == 1.0 |
| 171 | + assert config.norm_type == "low_precision_layernorm" |
| 172 | + |
| 173 | + self.wte = VocabParallelEmbedding(config.vocab_size, |
| 174 | + config.d_model, |
| 175 | + perform_initialization=False) |
| 176 | + self.blocks = nn.ModuleList( |
| 177 | + [MPTBlock(config) for _ in range(config.n_layers)]) |
| 178 | + self.norm_f = nn.LayerNorm(config.d_model) |
| 179 | + if config.no_bias: |
| 180 | + for module in self.modules(): |
| 181 | + if hasattr(module, "bias"): |
| 182 | + if isinstance(module.bias, nn.Parameter): |
| 183 | + # Remove the bias term in Linear and LayerNorm. |
| 184 | + module.register_parameter("bias", None) |
| 185 | + |
| 186 | + def forward( |
| 187 | + self, |
| 188 | + input_ids: torch.Tensor, |
| 189 | + position_ids: torch.Tensor, |
| 190 | + kv_caches: List[KVCache], |
| 191 | + input_metadata: InputMetadata, |
| 192 | + cache_events: Optional[List[torch.cuda.Event]], |
| 193 | + ) -> torch.Tensor: |
| 194 | + hidden_states = self.wte(input_ids) |
| 195 | + for i in range(len(self.blocks)): |
| 196 | + if cache_events is None: |
| 197 | + cache_event = None |
| 198 | + else: |
| 199 | + cache_event = cache_events[i] |
| 200 | + block = self.blocks[i] |
| 201 | + hidden_states = block( |
| 202 | + position_ids, |
| 203 | + hidden_states, |
| 204 | + kv_caches[i], |
| 205 | + input_metadata, |
| 206 | + cache_event, |
| 207 | + ) |
| 208 | + hidden_states = self.norm_f(hidden_states) |
| 209 | + return hidden_states |
| 210 | + |
| 211 | + |
| 212 | +class MPTForCausalLM(nn.Module): |
| 213 | + |
| 214 | + def __init__(self, config: MPTConfig): |
| 215 | + super().__init__() |
| 216 | + self.config = config |
| 217 | + assert config.tie_word_embeddings |
| 218 | + |
| 219 | + self.transformer = MPTModel(config) |
| 220 | + # TODO(zhuohan): create a new weight after implementing pipeline |
| 221 | + # parallelism |
| 222 | + self.lm_head_weight = self.transformer.wte.weight |
| 223 | + self.sampler = Sampler(config.vocab_size) |
| 224 | + |
| 225 | + def forward( |
| 226 | + self, |
| 227 | + input_ids: torch.Tensor, |
| 228 | + positions: torch.Tensor, |
| 229 | + kv_caches: List[KVCache], |
| 230 | + input_metadata: InputMetadata, |
| 231 | + cache_events: Optional[List[torch.cuda.Event]], |
| 232 | + ) -> Dict[int, SequenceOutputs]: |
| 233 | + hidden_states = self.transformer(input_ids, positions, kv_caches, |
| 234 | + input_metadata, cache_events) |
| 235 | + next_tokens = self.sampler(self.lm_head_weight, hidden_states, |
| 236 | + input_metadata) |
| 237 | + return next_tokens |
| 238 | + |
| 239 | + _column_parallel_weights = ["wte.weight", "up_proj.weight", "up_proj.bias"] |
| 240 | + _row_parallel_weights = ["out_proj.weight", "down_proj.weight"] |
| 241 | + |
| 242 | + def load_weights(self, |
| 243 | + model_name_or_path: str, |
| 244 | + cache_dir: Optional[str] = None, |
| 245 | + use_np_cache: bool = False): |
| 246 | + tp_world_size = get_tensor_model_parallel_world_size() |
| 247 | + tp_rank = get_tensor_model_parallel_rank() |
| 248 | + state_dict = self.state_dict() |
| 249 | + for name, loaded_weight in hf_model_weights_iterator( |
| 250 | + model_name_or_path, cache_dir, use_np_cache): |
| 251 | + if "Wqkv" in name: |
| 252 | + # NOTE(woosuk): MPT's fused QKV has the shape of |
| 253 | + # [3 * num_heads * head_size, hidden_size]. |
| 254 | + # When tensor model parallelism is used, we need to shard |
| 255 | + # the weight along the hidden dimension. |
| 256 | + total_num_heads = self.config.num_attention_heads |
| 257 | + hidden_size = self.config.hidden_size |
| 258 | + head_size = hidden_size // total_num_heads |
| 259 | + num_heads = total_num_heads // tp_world_size |
| 260 | + head_start = tp_rank * num_heads |
| 261 | + head_end = (tp_rank + 1) * num_heads |
| 262 | + |
| 263 | + if name.endswith(".weight"): |
| 264 | + loaded_weight = loaded_weight.view(3, total_num_heads, |
| 265 | + head_size, hidden_size) |
| 266 | + loaded_weight = loaded_weight[:, head_start:head_end, :, :] |
| 267 | + loaded_weight = loaded_weight.reshape(-1, hidden_size) |
| 268 | + elif name.endswith(".bias"): |
| 269 | + loaded_weight = loaded_weight.view(3, total_num_heads, |
| 270 | + head_size) |
| 271 | + loaded_weight = loaded_weight[:, head_start:head_end, :] |
| 272 | + loaded_weight = loaded_weight.reshape(-1) |
| 273 | + else: |
| 274 | + raise ValueError(f"Unexpected parameter name {name}") |
| 275 | + name = name.replace("Wqkv", "qkv_proj") |
| 276 | + param = state_dict[name] |
| 277 | + load_tensor_parallel_weights(param, loaded_weight, name, |
| 278 | + self._column_parallel_weights, |
| 279 | + self._row_parallel_weights, tp_rank) |
0 commit comments