|
1 | 1 | import math |
| 2 | +import warnings |
2 | 3 | from typing import List, Optional, Tuple |
3 | 4 |
|
4 | 5 | import torch |
|
23 | 24 | from ...distributed import get_sp_ring_group, get_sp_ulysses_group |
24 | 25 | from .base import Eagle3DraftModel |
25 | 26 |
|
| 27 | +try: |
| 28 | + from flash_attn import flash_attn_func |
| 29 | +except ImportError: |
| 30 | + warnings.warn( |
| 31 | + "flash_attn is not found, please install flash_attn if you want to use the flash attention backend" |
| 32 | + ) |
| 33 | + flash_attn_func = None |
| 34 | + |
26 | 35 |
|
27 | 36 | # Copied from transformers.models.bart.modeling_bart._make_causal_mask |
28 | 37 | def _make_causal_mask( |
@@ -94,12 +103,12 @@ def rotate_half(x): |
94 | 103 |
|
95 | 104 |
|
96 | 105 | @torch.compile(dynamic=True) |
97 | | -def apply_rotary_pos_emb(q, k, cos, sin, position_ids): |
| 106 | +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): |
98 | 107 | # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. |
99 | 108 | cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] |
100 | 109 | sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] |
101 | | - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] |
102 | | - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] |
| 110 | + cos = cos[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim] or [bs, seq_len, 1, dim] |
| 111 | + sin = sin[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim] or [bs, seq_len, 1, dim] |
103 | 112 | q_embed = (q * cos) + (rotate_half(q) * sin) |
104 | 113 | k_embed = (k * cos) + (rotate_half(k) * sin) |
105 | 114 | return q_embed, k_embed |
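
Context for the `unsqueeze_dim` change above: `flash_attn_func` consumes tensors in `(bsz, seq_len, num_heads, head_dim)` layout, so the cos/sin tables have to be unsqueezed at dim 2 instead of dim 1 to broadcast correctly. A minimal shape check, using made-up sizes (the names below are illustrative only, not from this file):

```python
import torch

# Hypothetical sizes chosen only to illustrate broadcasting; not taken from this diff.
bsz, seq_len, num_heads, head_dim = 2, 8, 4, 16
position_ids = torch.arange(seq_len).unsqueeze(0).expand(bsz, -1)

cos = torch.randn(1, 1, seq_len, head_dim)  # rotary table as produced by the embedding
cos_flat = cos.squeeze(1).squeeze(0)        # [seq_len, head_dim]

# sdpa layout: q is [bsz, num_heads, seq_len, head_dim] -> unsqueeze at dim 1
cos_sdpa = cos_flat[position_ids].unsqueeze(1)  # [bsz, 1, seq_len, head_dim]
# flash-attn layout: q is [bsz, seq_len, num_heads, head_dim] -> unsqueeze at dim 2
cos_fa = cos_flat[position_ids].unsqueeze(2)    # [bsz, seq_len, 1, head_dim]

q_sdpa = torch.randn(bsz, num_heads, seq_len, head_dim)
q_fa = q_sdpa.transpose(1, 2)
assert (q_sdpa * cos_sdpa).shape == q_sdpa.shape
assert (q_fa * cos_fa).shape == q_fa.shape
```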
@@ -1170,6 +1179,120 @@ def forward( |
1170 | 1179 | return attn_output |
1171 | 1180 |
|
1172 | 1181 |
|
| 1182 | +class LlamaFlashAttention(LlamaAttention): |
| 1183 | + """ |
| 1184 | + Attention layer implemented with flash attention. The signature is kept consistent with LlamaAttention, |
| 1185 | + but only the following arguments are used: |
| 1186 | + - hidden_states: input hidden states |
| 1187 | + - position_ids: position ids |
| 1188 | + - cache_hidden: manual cache used for storing past key and value states |
| 1189 | + """ |
| 1190 | + |
| 1191 | + def forward( |
| 1192 | + self, |
| 1193 | + hidden_states: torch.Tensor, |
| 1194 | + cache_hidden: Optional[List[torch.Tensor]] = None, |
| 1195 | + attention_mask: Optional[torch.Tensor] = None, |
| 1196 | + position_ids: Optional[torch.LongTensor] = None, |
| 1197 | + past_key_values: Optional[Cache] = None, |
| 1198 | + output_attentions: bool = False, |
| 1199 | + use_cache: bool = False, |
| 1200 | + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
| 1201 | + bsz, q_len, _ = hidden_states.size() |
| 1202 | + |
| 1203 | + query_states = self.q_proj(hidden_states) |
| 1204 | + key_states = self.k_proj(hidden_states) |
| 1205 | + value_states = self.v_proj(hidden_states) |
| 1206 | + |
| 1207 | + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) |
| 1208 | + key_states = key_states.view( |
| 1209 | + bsz, q_len, self.num_key_value_heads, self.head_dim |
| 1210 | + ) |
| 1211 | + value_states = value_states.view( |
| 1212 | + bsz, q_len, self.num_key_value_heads, self.head_dim |
| 1213 | + ) |
| 1214 | + |
| 1215 | + lck = 0 if cache_hidden is None else len(cache_hidden[0]) |
| 1216 | + if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding): |
| 1217 | + cos, sin = self.rotary_emb(query_states, position_ids + lck) |
| 1218 | + cos, sin = cos.to(query_states.device), sin.to(query_states.device) |
| 1219 | + query_states, key_states = apply_multimodal_rotary_pos_emb( |
| 1220 | + query_states, |
| 1221 | + key_states, |
| 1222 | + cos, |
| 1223 | + sin, |
| 1224 | + self.config.rope_scaling["mrope_section"], |
| 1225 | + unsqueeze_dim=2, |
| 1226 | + ) |
| 1227 | + else: |
| 1228 | + cos, sin = self.rotary_emb(query_states, seq_len=q_len + lck) |
| 1229 | + cos, sin = cos.to(query_states.device), sin.to(query_states.device) |
| 1230 | + query_states, key_states = apply_rotary_pos_emb( |
| 1231 | + query_states, key_states, cos, sin, position_ids + lck, unsqueeze_dim=2 |
| 1232 | + ) |
| 1233 | + |
| 1234 | + if cache_hidden is not None: |
| 1235 | + cache_hidden[0] = cache_hidden[0] + [key_states] |
| 1236 | + cache_hidden[1] = cache_hidden[1] + [value_states] |
| 1237 | + |
| 1238 | + cache_k = cache_hidden[0] |
| 1239 | + cache_v = cache_hidden[1] |
| 1240 | + else: |
| 1241 | + cache_k = [key_states] |
| 1242 | + cache_v = [value_states] |
| 1243 | + |
| 1244 | + k0 = cache_k[0] |
| 1245 | + v0 = cache_v[0] |
| 1246 | + |
| 1247 | + assert ( |
| 1248 | + flash_attn_func is not None |
| 1249 | + ), "flash_attn is not installed, please install flash_attn if you want to use the flash attention backend" |
| 1250 | + attn_output, lse, _ = flash_attn_func( |
| 1251 | + query_states, |
| 1252 | + k0, |
| 1253 | + v0, |
| 1254 | + dropout_p=0.0, |
| 1255 | + softmax_scale=1.0 / math.sqrt(self.head_dim), |
| 1256 | + causal=True, |
| 1257 | + return_attn_probs=True, |
| 1258 | + ) |
| 1259 | + lse = lse.transpose(1, 2) # [bsz, num_heads, q_len] -> [bsz, q_len, num_heads] |
| 1260 | + |
| 1261 | + lck = len(cache_k) |
| 1262 | + if lck > 1: # merge the flash_attn output with attention over the extra cached key/value states |
| 1263 | + q_shape_expanded = ( |
| 1264 | + bsz, |
| 1265 | + q_len, |
| 1266 | + self.num_key_value_heads, |
| 1267 | + self.num_key_value_groups, |
| 1268 | + self.head_dim, |
| 1269 | + ) |
| 1270 | + attn_outputs = [attn_output.view(q_shape_expanded)] |
| 1271 | + lses = [lse.view(q_shape_expanded[:-1])] |
| 1272 | + |
| 1273 | + for i in range(1, lck): |
| 1274 | + ki = cache_k[i].unsqueeze(-2) |
| 1275 | + qi = query_states.view(q_shape_expanded) |
| 1276 | + vi = cache_v[i].unsqueeze(-2) |
| 1277 | + |
| 1278 | + attn_outputs.append(vi) # one key/value per query position, so the partial output is just v_i |
| 1279 | + lses.append((qi * ki).sum(-1) / math.sqrt(self.head_dim)) # and its log-sum-exp is the single scaled logit |
| 1280 | + |
| 1281 | + lse = torch.logsumexp(torch.stack(lses, dim=-1), dim=-1) |
| 1282 | + attn_output = sum( |
| 1283 | + attn_outputi * torch.exp(lsei - lse).unsqueeze(-1) |
| 1284 | + for attn_outputi, lsei in zip(attn_outputs, lses) |
| 1285 | + ) |
| 1286 | + # lse is fp32, downcast attn_output back |
| 1287 | + attn_output = attn_output.to(self.o_proj.weight.dtype) |
| 1288 | + |
| 1289 | + attn_output = attn_output.reshape(bsz, q_len, self.head_dim * self.num_heads) |
| 1290 | + |
| 1291 | + attn_output = self.o_proj(attn_output) |
| 1292 | + |
| 1293 | + return attn_output |
| 1294 | + |
| 1295 | + |
1173 | 1296 | class LlamaMLP(nn.Module): |
1174 | 1297 | def __init__(self, config): |
1175 | 1298 | super().__init__() |
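
The multi-step merge in `LlamaFlashAttention` uses the standard log-sum-exp trick for combining partial softmax-attention results: given each block's partial output `o_i` and log-sum-exp `l_i`, the exact result is `sum_i o_i * exp(l_i - logsumexp(l))`. A small self-contained sketch (single query, two key blocks, toy sizes; none of the names come from this file) that checks the merge against full attention:

```python
import math
import torch

torch.manual_seed(0)
head_dim = 8
q = torch.randn(1, head_dim)
k = torch.randn(6, head_dim)
v = torch.randn(6, head_dim)

def partial_attn(q, k, v):
    # Partial softmax attention over one key block, returning the output and its
    # log-sum-exp, which is what flash_attn_func exposes via return_attn_probs=True.
    logits = q @ k.T / math.sqrt(head_dim)   # [1, n_keys]
    lse = torch.logsumexp(logits, dim=-1)    # [1]
    out = torch.softmax(logits, dim=-1) @ v  # [1, head_dim]
    return out, lse

# Attend to two key blocks separately, then merge with log-sum-exp weights.
o1, l1 = partial_attn(q, k[:4], v[:4])
o2, l2 = partial_attn(q, k[4:], v[4:])
lse = torch.logsumexp(torch.stack([l1, l2], dim=-1), dim=-1)
merged = o1 * torch.exp(l1 - lse).unsqueeze(-1) + o2 * torch.exp(l2 - lse).unsqueeze(-1)

# The merged result matches attention over all keys at once.
full, _ = partial_attn(q, k, v)
assert torch.allclose(merged, full, atol=1e-6)
```

This is the same rule the class applies: the flash_attn call supplies the (output, lse) pair for the dense causal prefix, while each extra cache entry supplies a one-key partial result.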
@@ -1245,6 +1368,8 @@ def __init__(self, config, attention_backend: str = "sdpa"): |
1245 | 1368 | elif attention_backend == "flex_attention": |
1246 | 1369 | print_with_rank("Using flex attention on draft model training!") |
1247 | 1370 | self.self_attn = LlamaFlexAttention(config=config) |
| 1371 | + elif attention_backend == "fa": |
| 1372 | + self.self_attn = LlamaFlashAttention(config=config) |
1248 | 1373 | else: |
1249 | 1374 | raise ValueError(f"Unknown attention backend {attention_backend}") |
1250 | 1375 |
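
Selecting the new backend is just a matter of passing "fa" to the layer constructor. A minimal sketch of guarding that choice on flash_attn availability; the decoder-layer class name is not shown in this hunk, so the constructor call is only indicated in a comment:

```python
# Minimal sketch: pick the attention backend based on flash_attn availability.
# The decoder-layer class name below is hypothetical; only its __init__ signature
# (config, attention_backend: str = "sdpa") appears in this diff.
try:
    from flash_attn import flash_attn_func  # noqa: F401
    attention_backend = "fa"
except ImportError:
    attention_backend = "sdpa"  # fall back to the default SDPA backend

# layer = DecoderLayer(config, attention_backend=attention_backend)
```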
|
|