-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathnative_sparse_attention.py
More file actions
354 lines (294 loc) · 12.5 KB
/
native_sparse_attention.py
File metadata and controls
354 lines (294 loc) · 12.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
from __future__ import annotations
import math
from dataclasses import dataclass
from functools import lru_cache
from typing import Optional, Literal
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
try:
    RMSNorm = nn.RMSNorm  # built into PyTorch >= 2.4
except AttributeError:

    class RMSNorm(nn.Module):
        """Fallback root-mean-square normalisation for PyTorch builds without ``nn.RMSNorm``.

        Mirrors ``nn.RMSNorm(dim, eps)``: the learnable gain is registered as
        ``weight`` so state dicts interchange with the built-in module, and
        ``eps=None`` means ``torch.finfo(x.dtype).eps``.

        Args:
            dim: size of the last (normalised) dimension.
            eps: value added to the mean square before the reciprocal sqrt.
        """

        def __init__(self, dim: int, eps: Optional[float] = None):
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(dim))
            self.dim = dim

        @property
        def scale(self) -> nn.Parameter:
            """Legacy alias for ``weight`` (the gain's former attribute name)."""
            return self.weight

        def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
            # Accept checkpoints written by the older fallback, which named the gain "scale".
            legacy_key, key = prefix + "scale", prefix + "weight"
            if legacy_key in state_dict and key not in state_dict:
                state_dict[key] = state_dict.pop(legacy_key)
            super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            eps = torch.finfo(x.dtype).eps if self.eps is None else self.eps
            # Reduce in float32: squaring fp16/bf16 activations can overflow or lose precision.
            x32 = x.float()
            inv_rms = x32.pow(2).mean(-1, keepdim=True).add(eps).rsqrt()
            return (x32 * inv_rms).to(x.dtype) * self.weight
# Inverse RoPE frequencies, keyed by (device, half head dim, base); always float32.
_freq_cache: dict[tuple[torch.device, int, float], torch.Tensor] = {}


def apply_rope(x: torch.Tensor, pos: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Return a copy of *x* with rotary position embedding applied to its last dim.

    *x* is NOT modified; callers must use the returned tensor.

    The last dimension ``d`` is split into halves ``(x1, x2)`` and each pair
    ``(x1[i], x2[i])`` at position ``p`` is rotated by the angle
    ``p / base ** (i / (d / 2))``.

    Args:
        x: tensor of shape ``(..., n, d)`` with ``d`` even.
        pos: 1-D tensor of ``n`` positions, one per entry of the second-to-last dim of *x*.
        base: RoPE frequency base.

    Returns:
        A new tensor with the same shape and dtype as *x*.

    Raises:
        ValueError: if ``d`` is odd or *pos* is not 1-D.
    """
    d = x.size(-1)
    if d % 2:
        raise ValueError(f"RoPE requires even head dim, got {d}")
    if pos.dim() != 1:
        raise ValueError(f"pos must be 1-D, got shape {tuple(pos.shape)}")
    half = d // 2

    key = (x.device, half, float(base))
    inv_freq = _freq_cache.get(key)
    if inv_freq is None:
        # Build frequencies in float32: doing it in fp16/bf16 loses precision at high indices.
        exponents = torch.arange(half, device=x.device, dtype=torch.float32) / half
        inv_freq = 1.0 / (base ** exponents)
        _freq_cache[key] = inv_freq

    theta = pos.to(device=x.device, dtype=torch.float32)[:, None] * inv_freq[None, :]  # (n, half)
    cos = theta.cos().to(x.dtype)
    sin = theta.sin().to(x.dtype)
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
@dataclass
class NSAConfig:
    """Hyper-parameters for ``NativeSparseAttention`` / ``NSABlock``.

    ``__post_init__`` fills ``dim_head`` (default ``dim // heads``) and raises
    ``ValueError`` for inconsistent settings.
    """

    dim: int = 512
    heads: int = 8
    dim_head: Optional[int] = None  # defaults to dim // heads; if given it must equal that
    seq_len: int = 2048
    # Branch hyper-params
    local_window: int = 128  # W
    block_size: int = 32  # B
    stride: int = 32  # S
    topk_blocks: int = 4  # K
    compression: Literal["grouped_mlp", "conv1d", "avgpool"] = "grouped_mlp"
    dropout: float = 0.1
    use_flash: bool = True
    # Gate options
    gate_mode: Literal["static", "q_cond"] = "static"
    gate_init: tuple[float, float, float] = (2.0, -2.0, -2.0)  # local window favored as per DeepSeek paper

    def __post_init__(self):
        # heads must be checked before it is used as a divisor below.
        if self.heads <= 0:
            raise ValueError(f"heads must be positive, got {self.heads}")
        if self.dim % self.heads != 0:
            raise ValueError(f"dim {self.dim} must be divisible by heads {self.heads}")
        if self.dim_head is None:
            self.dim_head = self.dim // self.heads
        # The attention module splits dim evenly across heads, so any other
        # dim_head would disagree with the actual per-head width.
        if self.dim_head * self.heads != self.dim:
            raise ValueError(
                f"dim_head {self.dim_head} * heads {self.heads} must equal dim {self.dim}"
            )
        if self.dim_head % 2:
            raise ValueError(f"dim_head {self.dim_head} must be even for rotary embeddings")
        if self.block_size <= 0 or self.stride <= 0:
            raise ValueError(
                f"block_size and stride must be positive, got {self.block_size} and {self.stride}"
            )
        if self.block_size % self.stride:
            raise ValueError(f"block_size {self.block_size} must be divisible by stride {self.stride}")
        if self.topk_blocks <= 0:
            raise ValueError(f"topk_blocks must be positive, got {self.topk_blocks}")
        if self.local_window <= 0:
            raise ValueError(f"local_window must be positive, got {self.local_window}")
        if self.local_window % 2 != 0:
            raise ValueError(f"local_window {self.local_window} must be even")
        if self.gate_mode not in ("static", "q_cond"):
            raise ValueError(f"gate_mode must be 'static' or 'q_cond', got {self.gate_mode!r}")
        if len(self.gate_init) != 3:
            raise ValueError(f"gate_init needs one logit per branch (3), got {len(self.gate_init)}")
class GroupedMLP(nn.Module):
    """Compress each run of ``block`` consecutive tokens into a single vector.

    Every token in a block receives a learned per-position offset.  The block is
    then flattened to ``block * dim`` features, linearly projected (no bias) back
    to ``dim``, and RMS-normalised.  A ragged tail is right-padded by repeating
    the final token.
    """

    def __init__(self, dim: int, block: int):
        super().__init__()
        self.B = block
        self.proj = nn.Linear(dim * block, dim, bias=False)
        self.norm = RMSNorm(dim)
        # One learned offset per intra-block position, broadcast over batch and blocks.
        self.bias = nn.Parameter(torch.zeros(1, block, dim))

    def forward(self, x: torch.Tensor):
        """Map ``(b, n, d)`` to ``(b, ceil(n / B), d)``."""
        batch, length, width = x.shape
        missing = -length % self.B
        if missing:
            tail = x[:, -1:, :].expand(-1, missing, -1)
            x = torch.cat([x, tail], dim=1)
        n_blocks = x.size(1) // self.B
        blocks = x.reshape(batch, n_blocks, self.B, width) + self.bias
        flat = blocks.reshape(batch, n_blocks, self.B * width)
        return self.norm(self.proj(flat))
class Conv1dCompression(nn.Module):
    """Depthwise strided convolution producing one vector per ``block`` tokens, then RMSNorm."""

    def __init__(self, dim: int, block: int):
        super().__init__()
        self.B = block
        # groups=dim -> depthwise; stride == kernel_size -> non-overlapping windows.
        self.conv = nn.Conv1d(dim, dim, kernel_size=block, stride=block, groups=dim)
        self.norm = RMSNorm(dim)

    def forward(self, x: torch.Tensor):
        """Map ``(b, n, d)`` to ``(b, ceil(n / B), d)``; the tail is padded by repeating the last token."""
        _, length, _ = x.shape
        shortfall = -length % self.B
        if shortfall:
            x = torch.cat([x, x[:, -1:, :].expand(-1, shortfall, -1)], dim=1)
        channels_first = x.transpose(1, 2)  # Conv1d expects (b, d, n)
        summary = self.conv(channels_first).transpose(1, 2)
        return self.norm(summary)
class AvgPoolCompression(nn.Module):
    """Average every ``block`` consecutive tokens, then RMS-normalise the means."""

    def __init__(self, dim: int, block: int):
        super().__init__()
        self.B = block
        self.pool = nn.AvgPool1d(kernel_size=block, stride=block)
        self.norm = RMSNorm(dim)

    def forward(self, x: torch.Tensor):
        """Map ``(b, n, d)`` to ``(b, ceil(n / B), d)``; the tail is padded by repeating the last token."""
        _, seq, _ = x.shape
        fill = -seq % self.B
        if fill:
            last = x[:, -1:, :]
            x = torch.cat([x, last.expand(-1, fill, -1)], dim=1)
        means = self.pool(x.transpose(1, 2))  # AvgPool1d pools over the trailing axis
        return self.norm(means.transpose(1, 2))
class TokenCompressor(nn.Module):
    """Wrap the block-compression operator chosen by ``cfg.compression``.

    Every operator maps ``(b, n, cfg.dim)`` to ``(b, ceil(n / cfg.block_size), cfg.dim)``.

    Raises:
        ValueError: if ``cfg.compression`` names no known operator.
    """

    def __init__(self, cfg: NSAConfig):
        super().__init__()
        # All three operators share the (dim, block) constructor signature.  The
        # avgpool entry previously omitted dim, which raised TypeError.
        operators = {
            "grouped_mlp": GroupedMLP,
            "conv1d": Conv1dCompression,
            "avgpool": AvgPoolCompression,
        }
        if cfg.compression not in operators:
            raise ValueError(cfg.compression)
        self.op = operators[cfg.compression](cfg.dim, cfg.block_size)

    def forward(self, x):
        return self.op(x)
class SparseMasking:
    """Boolean attention masks for the three NSA branches (True = query may attend to key).

    ``local`` and ``compressed`` are cached in bounded LRU caches keyed by their
    arguments.  The cached tensors are shared, so callers must not modify them in place.
    """

    @staticmethod
    @lru_cache(maxsize=32)
    def local(seq: int, W: int, device):
        """(seq, seq) causal sliding window: query i sees key j iff 0 <= i - j < W."""
        idx = torch.arange(seq, device=device)
        diff = idx[:, None] - idx[None, :]
        return (diff >= 0) & (diff < W)

    @staticmethod
    @lru_cache(maxsize=32)
    def compressed(seq: int, comp: int, S: int, device):
        """(seq, comp) mask: query t sees compressed token c iff t // S > c.

        When each compressed token summarises S consecutive tokens, this admits
        only blocks lying strictly before the query's own block.
        """
        t = torch.arange(seq, device=device)[:, None]
        c = torch.arange(comp, device=device)[None, :]
        return t // S > c

    @staticmethod
    def selected(q: torch.Tensor, kc: torch.Tensor, B: int, S: int, K: int):
        """Return a (b, h, n, n) mask of the keys inside each query's Top-K past blocks.

        Compressed keys are grouped ``B // S`` at a time into selection blocks.
        Each compressed key is taken to summarise ``B`` consecutive tokens, as the
        compressors in this file do, so one selection block spans ``(B // S) * B``
        tokens.  A block's score is the dot product between the query and the
        block's mean compressed key.

        Only blocks strictly before the query's own block are eligible; the
        query's own block is left to the local branch.  When fewer than K blocks
        are eligible, only the eligible ones are selected.  In particular, queries
        in the first block select nothing.

        q: (b, h, n, d) queries
        kc: (b, h, n_c, d) compressed keys
        B: block size; S: stride; K: blocks kept per query
        """
        b, h, n, _ = q.shape
        group = B // S  # compressed keys per selection block
        span = group * B  # tokens per selection block
        n_comp = kc.size(2)
        blk_total = math.ceil(n_comp / group)

        pad = blk_total * group - n_comp
        if pad:
            kc = F.pad(kc, (0, 0, 0, pad))  # zero keys complete the last group
        kc_mean = kc.reshape(*kc.shape[:2], blk_total, group, kc.size(-1)).mean(dim=3)
        logits = torch.einsum("bhid,bhjd->bhij", q, kc_mean)  # (b, h, n, blk_total)

        tok_blk = torch.arange(n, device=q.device) // span
        blk_id = torch.arange(blk_total, device=q.device)
        eligible = tok_blk[:, None] > blk_id[None, :]  # (n, blk_total): strictly past blocks
        logits = logits.masked_fill(~eligible, torch.finfo(logits.dtype).min)

        top_idx = logits.topk(k=min(K, blk_total), dim=-1).indices
        # topk always returns k indices.  Keep only picks that hit eligible blocks;
        # a score threshold cannot reject them when every block is masked.
        keep = eligible.expand_as(logits).gather(-1, top_idx)
        blk_mask = torch.zeros_like(logits, dtype=torch.bool)
        blk_mask.scatter_(-1, top_idx, keep)

        # Expand each block decision to the tokens it covers: (b, h, n, blk_total * span).
        tok_mask = blk_mask.unsqueeze(-1).expand(*blk_mask.shape, span).flatten(-2)
        if tok_mask.size(-1) < n:
            # Keys beyond the last compressed block can never be selected.
            tail = tok_mask.new_zeros((*tok_mask.shape[:-1], n - tok_mask.size(-1)))
            tok_mask = torch.cat([tok_mask, tail], dim=-1)
        return tok_mask[..., :n]
class NativeSparseAttention(nn.Module):
    """Native Sparse Attention: three causal attention branches mixed by a learned gate.

    Per head:

    * local      -- causal sliding window of ``cfg.local_window`` tokens (RoPE on q/k);
    * compressed -- attention over ``TokenCompressor`` summaries of ``cfg.block_size``
      tokens, restricted to blocks lying entirely before the query's block;
    * selected   -- full-resolution attention (RoPE on q/k) over the tokens of the
      Top-K past blocks picked by ``SparseMasking.selected``.

    Compressed keys carry no rotary embedding.  The compressed branch, the block
    scoring and the query-conditioned gate therefore use the un-rotated queries.
    Branch outputs are mixed with softmax weights.  The weights are a learned
    per-head constant (``gate_mode="static"``), or that constant plus a linear
    function of each token's own query (``gate_mode="q_cond"``).
    """

    def __init__(self, cfg: NSAConfig):
        super().__init__()
        self.cfg = cfg
        self.scale = cfg.dim_head ** -0.5
        # projections
        self.to_qkv = nn.Linear(cfg.dim, cfg.dim * 3, bias=False)  # q, k, v
        self.comp = TokenCompressor(cfg)
        self.to_kvc = nn.Linear(cfg.dim, cfg.dim * 2, bias=False)  # kc, vc
        self.out_proj = nn.Linear(cfg.dim, cfg.dim, bias=False)
        self.dropout = nn.Dropout(cfg.dropout)
        # gating
        head_width = cfg.dim // cfg.heads  # last dim of q after the head split in forward
        if cfg.gate_mode == "static":
            init = torch.tensor(cfg.gate_init, dtype=torch.float32)
            self.gate = nn.Parameter(init.repeat(cfg.heads, 1))  # (H, 3)
            self.gate_proj = None
        else:  # query-conditioned
            self.gate = nn.Parameter(torch.zeros(cfg.heads, 3))  # per-head bias term
            self.gate_proj = nn.Linear(head_width, 3, bias=False)

    def forward(self, x: torch.Tensor):
        """Input (b, n, dim) -> output (b, n, dim)."""
        _, n, _ = x.shape
        cfg, H = self.cfg, self.cfg.heads
        device = x.device
        pos = torch.arange(n, device=device)

        # shared QKV
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q = rearrange(q, "b n (h d) -> b h n d", h=H)
        k = rearrange(k, "b n (h d) -> b h n d", h=H)
        v = rearrange(v, "b n (h d) -> b h n d", h=H)
        q_nope = q  # un-rotated queries, used wherever keys have no RoPE
        # apply_rope returns a new tensor; the result must be kept.
        q = apply_rope(q, pos)
        k = apply_rope(k, pos)

        # compressed tokens
        xc = self.comp(x)
        kc, vc = self.to_kvc(xc).chunk(2, dim=-1)
        kc = rearrange(kc, "b n (h d) -> b h n d", h=H)
        vc = rearrange(vc, "b n (h d) -> b h n d", h=H)
        n_c = kc.size(2)

        # branch masks
        local_mask = SparseMasking.local(n, cfg.local_window, device)[None, None]
        # Compressed token c summarises tokens [c*B, (c+1)*B) (the compressors use
        # stride == block_size).  The causal cut-off must therefore use block_size:
        # using cfg.stride leaks future tokens whenever stride < block_size.
        compressed_mask = SparseMasking.compressed(n, n_c, cfg.block_size, device)[None, None]
        selected_mask = SparseMasking.selected(q_nope, kc, cfg.block_size, cfg.stride, cfg.topk_blocks)

        # branch outputs
        local_o = self._attend(q, k, v, local_mask)
        comp_o = self._attend(q_nope, kc, vc, compressed_mask)
        sel_o = self._attend(q, k, v, selected_mask)

        # gating: (1, H, 1, 3) static, or (b, H, n, 3) per token
        g = self.gate.view(1, H, 1, 3)
        if self.gate_proj is not None:
            # Each position is gated by its own query only.  Pooling q over the
            # sequence would let future tokens influence the gate (a causality leak).
            g = g + self.gate_proj(q_nope)
        w = g.softmax(dim=-1)
        out = w[..., 0:1] * local_o + w[..., 1:2] * comp_o + w[..., 2:3] * sel_o
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.out_proj(out)

    def _attend(self, q, k, v, mask: Optional[torch.Tensor]):
        """Scaled dot-product attention of *q* over (*k*, *v*).

        Args:
            q, k, v: (b, h, n_q, d), (b, h, n_k, d), (b, h, n_k, d).
            mask: boolean mask broadcastable to (b, h, n_q, n_k), True = allowed;
                ``None`` allows every key.

        Query rows whose mask allows no key return zeros.  Without this, they
        would attend uniformly to every key, including future ones.
        """
        cfg = self.cfg
        p_drop = cfg.dropout if self.training else 0.0
        if mask is None and cfg.use_flash and torch.backends.cuda.flash_sdp_enabled():
            # mask=None means unrestricted attention, matching the einsum path
            # below, so is_causal stays False.  SDPA's default scale is
            # 1/sqrt(d), which equals self.scale when dim_head == dim // heads.
            return F.scaled_dot_product_attention(q, k, v, dropout_p=p_drop)
        scores = torch.einsum("bhid,bhjd->bhij", q, k) * self.scale
        if mask is None:
            attn = scores.softmax(dim=-1)
        else:
            scores = scores.masked_fill(~mask, torch.finfo(scores.dtype).min)
            attn = scores.softmax(dim=-1) * mask.any(dim=-1, keepdim=True)
        attn = self.dropout(attn)
        return torch.einsum("bhij,bhjd->bhid", attn, v)
class NSABlock(nn.Module):
    """Pre-norm transformer block.

    The block applies NativeSparseAttention and then a 4x-wide GELU feed-forward.
    Each sub-layer reads an RMS-normalised input and is added back residually.
    """

    def __init__(self, cfg: NSAConfig):
        super().__init__()
        self.cfg = cfg
        self.norm1 = RMSNorm(cfg.dim)
        self.attn = NativeSparseAttention(cfg)
        self.norm2 = RMSNorm(cfg.dim)
        hidden = 4 * cfg.dim
        self.ff = nn.Sequential(
            nn.Linear(cfg.dim, hidden),
            nn.GELU(),
            nn.Dropout(cfg.dropout),
            nn.Linear(hidden, cfg.dim),
            nn.Dropout(cfg.dropout),
        )

    def forward(self, x: torch.Tensor):
        attended = self.attn(self.norm1(x))
        x = x + attended
        mixed = self.ff(self.norm2(x))
        return x + mixed
# =============================================
# self-tests
# =============================================
def _smoke():
    """Shape check: a default NSABlock maps (2, 128, dim) random input to the same shape."""
    config = NSAConfig(seq_len=128)
    inputs = torch.randn(2, config.seq_len, config.dim)
    model = NSABlock(config)
    result = model(inputs)
    assert result.shape == inputs.shape
def _causal():
    """Causality check: appending one token must not change the output at the old last position."""
    config = NSAConfig(seq_len=64, use_flash=False, dropout=0.0)
    model = NSABlock(config)
    model.eval()
    prefix = torch.randn(1, config.seq_len, config.dim)
    with torch.no_grad():
        before = model(prefix)[:, -1, :]
        extra = torch.randn(1, 1, config.dim)
        extended = torch.cat([prefix, extra], 1)
        after = model(extended)[:, -2, :]
    assert torch.allclose(before, after, atol=1e-5)
def _flash_parity():
if not torch.backends.cuda.flash_sdp_enabled():
return
cfg = NSAConfig(seq_len=128, use_flash=False, dropout=0.0)
x = torch.randn(1, cfg.seq_len, cfg.dim, device="cuda")
block = NSABlock(cfg).cuda()
block.eval()
with torch.no_grad():
ref = block(x)
block.cfg.use_flash = True
test = block(x)
assert (ref - test).abs().max() < 1e-4
if __name__ == "__main__":
    # Run the self-tests in order; any failure raises before the success message.
    for check in (_smoke, _causal, _flash_parity):
        check()
    print("All NSA tests passed")