|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | + |
| 4 | +# This source code is licensed under the license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | +from pathlib import Path |
| 7 | +from typing import Dict, Optional |
| 8 | + |
| 9 | +import torch |
| 10 | +import torch.nn as nn |
| 11 | + |
| 12 | +from torch import Tensor |
| 13 | +from torch.nn import functional as F |
| 14 | +from torch.distributed._tensor import DTensor, Replicate |
| 15 | +from torch.distributed.device_mesh import _mesh_resources, DeviceMesh |
| 16 | +from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel |
| 17 | + |
| 18 | +from build.utils import find_multiple |
| 19 | + |
| 20 | +from build.model import TransformerArgs, KVCache, apply_rotary_emb, precompute_freqs_cis |
| 21 | + |
| 22 | +config_path = Path(f"{str(Path(__file__).parent)}/known_model_params") |
| 23 | + |
| 24 | + |
| 25 | +# Use DTensor as output, by default |
| 26 | +Colwise = ColwiseParallel(use_local_output=False) |
| 27 | +Rowwise = RowwiseParallel(use_local_output=False) |
| 28 | + |
| 29 | +# Device mesh context |
| 30 | +device_mesh = None |
| 31 | + |
| 32 | + |
| 33 | +class Transformer(nn.Module): |
| 34 | + def __init__(self, config: TransformerArgs) -> None: |
| 35 | + super().__init__() |
| 36 | + self.config = config |
| 37 | + |
| 38 | + # Get device mesh |
| 39 | + global device_mesh |
| 40 | + if device_mesh is None: |
| 41 | + device_mesh = _mesh_resources.get_current_mesh() |
| 42 | + |
| 43 | + tok_embeddings = nn.Embedding(config.vocab_size, config.dim) |
| 44 | + self.tok_embeddings = parallelize_module( |
| 45 | + tok_embeddings, |
| 46 | + device_mesh, |
| 47 | + RowwiseParallel(input_layouts=Replicate()), |
| 48 | + ) |
| 49 | + self.layers = nn.ModuleList( |
| 50 | + TransformerBlock(config) for _ in range(config.n_layers) |
| 51 | + ) |
| 52 | + self.norm = RMSNorm(config.dim, eps=config.norm_eps) |
| 53 | + self.output = nn.Linear(config.dim, config.vocab_size, bias=False) |
| 54 | + |
| 55 | + # self.freqs_cis: Optional[Tensor] = None |
| 56 | + # self.mask_cache: Optional[Tensor] = None |
| 57 | + self.max_batch_size = -1 |
| 58 | + self.max_seq_length = -1 |
| 59 | + |
| 60 | + def setup_caches(self, max_batch_size, max_seq_length): |
| 61 | + if ( |
| 62 | + self.max_seq_length >= max_seq_length |
| 63 | + and self.max_batch_size >= max_batch_size |
| 64 | + ): |
| 65 | + return |
| 66 | + head_dim = self.config.dim // self.config.n_heads |
| 67 | + max_seq_length = find_multiple(max_seq_length, 8) |
| 68 | + self.max_seq_length = max_seq_length |
| 69 | + self.max_batch_size = max_batch_size |
| 70 | + for b in self.layers: |
| 71 | + b.attention.kv_cache = KVCache( |
| 72 | + max_batch_size, max_seq_length, self.config.n_local_heads, head_dim |
| 73 | + ) |
| 74 | + |
| 75 | + freqs_cis = precompute_freqs_cis( |
| 76 | + self.config.dim // self.config.n_heads, |
| 77 | + self.config.block_size * 2, |
| 78 | + self.config.rope_base, |
| 79 | + use_scaled = self.config.use_scaled_rope, |
| 80 | + ) |
| 81 | + self.register_buffer("freqs_cis", freqs_cis, persistent=True) |
| 82 | + causal_mask = torch.tril( |
| 83 | + torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool) |
| 84 | + ) |
| 85 | + self.register_buffer("causal_mask", causal_mask, persistent=True) |
| 86 | + |
| 87 | + def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: |
| 88 | + assert self.freqs_cis is not None, "Caches must be initialized first" |
| 89 | + mask = self.causal_mask[None, None, input_pos] |
| 90 | + freqs_cis = self.freqs_cis[input_pos] |
| 91 | + x: DTensor = self.tok_embeddings(idx) |
| 92 | + # TODO: sequence parallelize this |
| 93 | + |
| 94 | + for _, layer in enumerate(self.layers): |
| 95 | + x = layer(x, input_pos, freqs_cis, mask) |
| 96 | + x = self.norm(x) |
| 97 | + logits = self.output(x) |
| 98 | + # print(f"logits shape: {logits.shape}") |
| 99 | + return logits |
| 100 | + |
| 101 | + @classmethod |
| 102 | + def from_name(cls, name: str): |
| 103 | + return cls(TransformerArgs.from_name(name)) |
| 104 | + |
| 105 | + @classmethod |
| 106 | + def from_table(cls, name: str): |
| 107 | + return cls(TransformerArgs.from_table(name)) |
| 108 | + |
| 109 | + @classmethod |
| 110 | + def from_params(cls, params_path: str): |
| 111 | + return cls(TransformerArgs.from_params(params_path)) |
| 112 | + |
| 113 | + @classmethod |
| 114 | + def from_gguf(cls, gguf_path: str, **kwargs): |
| 115 | + from build.gguf_loader import load_model_and_state_dict |
| 116 | + |
| 117 | + model, state_dict = load_model_and_state_dict(gguf_path, **kwargs) |
| 118 | + if state_dict != {}: |
| 119 | + model.load_state_dict(state_dict, assign=True) |
| 120 | + return model |
| 121 | + |
| 122 | + |
| 123 | +class TransformerBlock(nn.Module): |
| 124 | + def __init__(self, config: TransformerArgs) -> None: |
| 125 | + super().__init__() |
| 126 | + self.attention = Attention(config) |
| 127 | + self.feed_forward = FeedForward(config) |
| 128 | + self.ffn_norm = RMSNorm(config.dim, config.norm_eps) |
| 129 | + self.attention_norm = RMSNorm(config.dim, config.norm_eps) |
| 130 | + |
| 131 | + def forward( |
| 132 | + self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor |
| 133 | + ) -> Tensor: |
| 134 | + h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) |
| 135 | + out = h + self.feed_forward(self.ffn_norm(h)) |
| 136 | + return out |
| 137 | + |
| 138 | + |
| 139 | +class Attention(nn.Module): |
| 140 | + def __init__(self, config: TransformerArgs): |
| 141 | + super().__init__() |
| 142 | + assert config.dim % config.n_heads == 0 |
| 143 | + |
| 144 | + # key, query, value projections for all heads, but in a batch |
| 145 | + # total_head_dim = (config.n_heads + 2 * config.n_local_heads) * config.head_dim |
| 146 | + # self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) |
| 147 | + wq = nn.Linear(config.dim, config.n_heads * config.head_dim, bias=False) |
| 148 | + wk = nn.Linear( |
| 149 | + config.dim, config.n_local_heads * config.head_dim, bias=False |
| 150 | + ) |
| 151 | + wv = nn.Linear( |
| 152 | + config.dim, config.n_local_heads * config.head_dim, bias=False |
| 153 | + ) |
| 154 | + wo = nn.Linear(config.dim, config.dim, bias=False) |
| 155 | + |
| 156 | + self.wq = parallelize_module(wq, device_mesh, Colwise) |
| 157 | + self.wk = parallelize_module(wk, device_mesh, Colwise) |
| 158 | + self.wv = parallelize_module(wv, device_mesh, Colwise) |
| 159 | + self.wo = parallelize_module(wo, device_mesh, Rowwise) |
| 160 | + |
| 161 | + self.kv_cache = None |
| 162 | + |
| 163 | + self.n_heads = config.n_heads |
| 164 | + self.head_dim = config.head_dim |
| 165 | + self.n_local_heads = config.n_local_heads |
| 166 | + self.dim = config.dim |
| 167 | + self._register_load_state_dict_pre_hook(self.load_hook) |
| 168 | + |
| 169 | + def load_hook(self, state_dict, prefix, *args): |
| 170 | + # if prefix + "wq.weight" in state_dict: |
| 171 | + # wq = state_dict.pop(prefix + "wq.weight") |
| 172 | + # wk = state_dict.pop(prefix + "wk.weight") |
| 173 | + # wv = state_dict.pop(prefix + "wv.weight") |
| 174 | + # state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) |
| 175 | + |
| 176 | + if prefix + "wqkv.weight" in state_dict: |
| 177 | + wqkv = state_dict.pop(prefix + "wqkv.weight") |
| 178 | + q_size = self.n_heads * self.head_dim |
| 179 | + kv_size = self.n_local_heads * self.head_dim |
| 180 | + wq, wk, wv = torch.split(wqkv, (q_size, kv_size, kv_size), dim=0) |
| 181 | + state_dict[prefix + "wq.weight"] = wq |
| 182 | + state_dict[prefix + "wk.weight"] = wk |
| 183 | + state_dict[prefix + "wv.weight"] = wv |
| 184 | + |
| 185 | + return |
| 186 | + |
| 187 | + def _unfuse_wqkv_state_dict( |
| 188 | + state_dict: Dict[str, torch.Tensor], |
| 189 | + dim: int, |
| 190 | + ): |
| 191 | + for key in list(state_dict): |
| 192 | + if key.endswith("wqkv.weight"): |
| 193 | + tensor = state_dict[key] |
| 194 | + wq_key = key.replace("wqkv.weight", "wq.weight") |
| 195 | + state_dict[wq_key] = tensor[:dim] |
| 196 | + wk_key = key.replace("wqkv.weight", "wk.weight") |
| 197 | + wv_key = key.replace("wqkv.weight", "wv.weight") |
| 198 | + wk, wv = tensor[dim:].chunk(2, 0) |
| 199 | + state_dict[wk_key] = wk |
| 200 | + state_dict[wv_key] = wv |
| 201 | + state_dict.pop(key) |
| 202 | + else: |
| 203 | + continue |
| 204 | + |
| 205 | + _unfuse_wqkv_state_dict(state_dict, self.dim) |
| 206 | + |
| 207 | + def forward( |
| 208 | + self, |
| 209 | + x: Tensor, |
| 210 | + freqs_cis: Tensor, |
| 211 | + mask: Tensor, |
| 212 | + input_pos: Optional[Tensor] = None, |
| 213 | + ) -> Tensor: |
| 214 | + bsz, seqlen, _ = x.shape |
| 215 | + |
| 216 | + q: DTensor = self.wq(x) |
| 217 | + k: DTensor = self.wk(x) |
| 218 | + v: DTensor = self.wv(x) |
| 219 | + # We use `to_local()` to convert DTensor back to regular Tensor |
| 220 | + q, k, v = q.to_local(), k.to_local(), v.to_local() |
| 221 | + # kv_size = self.n_local_heads * self.head_dim |
| 222 | + # q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) |
| 223 | + |
| 224 | + q = q.view(bsz, seqlen, -1, self.head_dim) |
| 225 | + k = k.view(bsz, seqlen, -1, self.head_dim) |
| 226 | + v = v.view(bsz, seqlen, -1, self.head_dim) |
| 227 | + |
| 228 | + q = apply_rotary_emb(q, freqs_cis) |
| 229 | + k = apply_rotary_emb(k, freqs_cis) |
| 230 | + |
| 231 | + q, k, v = (x.transpose(1, 2) for x in (q, k, v)) |
| 232 | + |
| 233 | + # TODO: enable kv cache |
| 234 | + #if self.kv_cache is not None: |
| 235 | + # k, v = self.kv_cache.update(input_pos, k, v) |
| 236 | + |
| 237 | + k = k.repeat_interleave(self.n_heads // self.n_local_heads, dim=1) |
| 238 | + v = v.repeat_interleave(self.n_heads // self.n_local_heads, dim=1) |
| 239 | + y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) |
| 240 | + |
| 241 | + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, -1) |
| 242 | + |
| 243 | + y: DTensor = self.wo(y) |
| 244 | + # TODO: sequence parallelize this |
| 245 | + return y.full_tensor() |
| 246 | + |
| 247 | + |
| 248 | +class FeedForward(nn.Module): |
| 249 | + def __init__(self, config: TransformerArgs) -> None: |
| 250 | + super().__init__() |
| 251 | + w1 = nn.Linear(config.dim, config.hidden_dim, bias=False) |
| 252 | + w2 = nn.Linear(config.hidden_dim, config.dim, bias=False) |
| 253 | + w3 = nn.Linear(config.dim, config.hidden_dim, bias=False) |
| 254 | + self.w1 = parallelize_module(w1, device_mesh, Colwise) |
| 255 | + self.w2 = parallelize_module(w2, device_mesh, Rowwise) |
| 256 | + self.w3 = parallelize_module(w3, device_mesh, Colwise) |
| 257 | + |
| 258 | + def forward(self, x: Tensor) -> Tensor: |
| 259 | + y: DTensor = self.w2(F.silu(self.w1(x)) * self.w3(x)) |
| 260 | + # y is a DTensor with Partial placement; |
| 261 | + # we convert its placement to Replicate and convert it back to a regular |
| 262 | + # Tensor. `full_tensor` is the API that does both. |
| 263 | + # TODO: sequence parallelize this |
| 264 | + return y.full_tensor() |
| 265 | + |
| 266 | + |
| 267 | +class RMSNorm(nn.Module): |
| 268 | + def __init__(self, dim: int, eps: float = 1e-5): |
| 269 | + super().__init__() |
| 270 | + self.eps = eps |
| 271 | + self.weight = nn.Parameter(torch.ones(dim)) |
| 272 | + |
| 273 | + def _norm(self, x): |
| 274 | + return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) |
| 275 | + |
| 276 | + def forward(self, x: Tensor) -> Tensor: |
| 277 | + output = self._norm(x.float()).type_as(x) |
| 278 | + return output * self.weight |
0 commit comments