# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rotary Positional Embeddings."""
from typing import Tuple, Union

import torch
import torch.nn as nn

from vllm import pos_encoding_ops


class RotaryEmbedding(nn.Module):
    """Original rotary positional embedding."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style

        cache = self._compute_cos_sin_cache()
        cache = cache.to(torch.get_default_dtype())
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        """Compute the inverse frequency."""
        # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
        # However, we use `torch.arange(..., dtype=torch.float)` instead to
        # avoid numerical issues with large base values (e.g., 10000000).
        # This may cause a slight numerical difference between the HF
        # implementation and ours.
        # NOTE(woosuk): To exactly match the HF implementation, we need to
        # use CPU to compute the cache and then move it to GPU. However, we
        # create the cache on GPU for faster initialization. This may cause
        # a slight numerical difference between the HF implementation and ours.
        inv_freq = 1.0 / (base**(torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
                                 self.rotary_dim))
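        # Example: with rotary_dim=128 and base=10000, this yields 64
        # frequencies 10000 ** (-2 * i / 128) for i = 0..63, ranging from
        # 1.0 down to roughly 1.2e-4.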
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos and sin cache."""
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings,
                         dtype=torch.float,
                         device="cuda")

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
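        # Cache shape: [max_position_embeddings, rotary_dim]; the first half
        # of the last dim holds cos(pos * inv_freq), the second half holds
        # sin(pos * inv_freq).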
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # pos_encoding_ops.rotary_embedding() is an in-place operation that
        # updates the query and key tensors.
        pos_encoding_ops.rotary_embedding(positions, query, key,
                                          self.head_size, self.cos_sin_cache,
                                          self.is_neox_style)
        return query, key


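# NOTE: Illustrative sketch only, not used by the classes above. It shows, in
# plain PyTorch, the NeoX-style rotation that the fused CUDA op is expected to
# apply in place, assuming flattened query/key tensors of shape
# [num_tokens, num_heads * head_size] and the [max_position, rotary_dim]
# cos/sin cache computed above. The helper name is hypothetical.
def _neox_rotary_reference(
    positions: torch.Tensor,      # [num_tokens], integer dtype
    x: torch.Tensor,              # [num_tokens, num_heads * head_size]
    head_size: int,
    rotary_dim: int,
    cos_sin_cache: torch.Tensor,  # [max_position, rotary_dim]
) -> torch.Tensor:
    num_tokens = x.shape[0]
    x = x.view(num_tokens, -1, head_size)
    # Only the first `rotary_dim` dims of each head are rotated.
    rot, rest = x[..., :rotary_dim], x[..., rotary_dim:]
    cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)  # broadcast over heads
    x1, x2 = rot.chunk(2, dim=-1)
    rot = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
    return torch.cat((rot, rest), dim=-1).view(num_tokens, -1)

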
class LinearScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with linear scaling.

    Credits to the Reddit user /u/kaiokendev
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
    ) -> None:
        self.scaling_factor = scaling_factor
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style)

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        inv_freq = self._compute_inv_freq(self.base)
        # NOTE(woosuk): self.max_position_embeddings is the original
        # maximum length before applying the rope scaling.
        # Thus, the maximum length after applying the rope scaling is
        # self.max_position_embeddings * self.scaling_factor.
        max_len = self.max_position_embeddings * self.scaling_factor
        t = torch.arange(max_len, dtype=torch.float, device="cuda")
        t = t / self.scaling_factor
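        # Example: with max_position_embeddings=4096 and scaling_factor=4.0,
        # the cache covers 16384 positions, and position 8192 is rotated with
        # the angles of (unscaled) position 2048.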

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache


class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with Dynamic NTK scaling.

    Credits to the Reddit users /u/bloc97 and /u/emozilla
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
    ) -> None:
        self.scaling_factor = scaling_factor
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style)

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        # NOTE(woosuk): self.max_position_embeddings is the original
        # maximum length before applying the rope scaling.
        # Thus, the maximum length after applying the rope scaling is
        # self.max_position_embeddings * self.scaling_factor.
        max_len = self.max_position_embeddings * self.scaling_factor
        base = self.base * (
            (self.scaling_factor * max_len / self.max_position_embeddings) -
            (self.scaling_factor - 1))**(self.rotary_dim /
                                         (self.rotary_dim - 2))
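        # Example: with base=10000, rotary_dim=128, max_position_embeddings=4096
        # and scaling_factor=2.0, max_len is 8192 and the adjusted base is
        # 10000 * (2 * 8192 / 4096 - 1) ** (128 / 126), roughly 30500. A larger
        # base lowers every frequency so that the longer range of positions
        # still spans the original range of rotation angles.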
        inv_freq = self._compute_inv_freq(base)
        t = torch.arange(max_len, dtype=torch.float, device="cuda")

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache
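

# Illustrative usage sketch (assumes the vLLM CUDA extension providing
# pos_encoding_ops is built, and flattened query/key tensors of shape
# [num_tokens, num_heads * head_size]):
#
#     rope = RotaryEmbedding(head_size=128, rotary_dim=128,
#                            max_position_embeddings=4096, base=10000,
#                            is_neox_style=True)
#     positions = torch.arange(num_tokens, dtype=torch.long, device="cuda")
#     query, key = rope(positions, query, key)  # rotated in place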