
Commit 9ef7f64

add usp for flash attn
1 parent f96f3f6 commit 9ef7f64

File tree

7 files changed: +639, -338 lines


scripts/regenerate_train_data.py

Lines changed: 5 additions & 3 deletions
@@ -292,9 +292,11 @@ def main():
     error_samples = 0
 
     # Create progress bar
-    with open(args.input_file_path, "r") as input_file, open(
-        args.output_file_path, "w"
-    ) as output_file_handle, open(error_file_path, "w") as error_file_handle:
+    with (
+        open(args.input_file_path, "r") as input_file,
+        open(args.output_file_path, "w") as output_file_handle,
+        open(error_file_path, "w") as error_file_handle,
+    ):
         executor = ThreadPoolExecutor(
             max_workers=args.concurrency * len(valid_server_addresses)
         )
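The parenthesized with-statement used above requires Python 3.10 or newer. For reference only (not part of this commit), an equivalent for older interpreters could use contextlib.ExitStack; args and error_file_path are the same names used in the script:

    from contextlib import ExitStack

    with ExitStack() as stack:
        # Enter the three file contexts one by one; all are closed together on exit.
        input_file = stack.enter_context(open(args.input_file_path, "r"))
        output_file_handle = stack.enter_context(open(args.output_file_path, "w"))
        error_file_handle = stack.enter_context(open(error_file_path, "w"))
        ...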

scripts/train_eagle3.py

Lines changed: 4 additions & 2 deletions
@@ -335,7 +335,7 @@ def sanity_check(args: Namespace) -> None:
     args.draft_accumulation_steps = (
         args.draft_accumulation_steps * args.sp_ulysses_size * args.sp_ring_size
     )
-    if args.attention_backend == "usp":
+    if args.attention_backend in ("usp", "usp_fa"):
         assert (
             args.train_hidden_states_path is not None
         ), "train_hidden_states_path should not be None for usp"
@@ -443,7 +443,9 @@ def build_dataloaders(
         num_workers=args.dataloader_num_workers,
         shuffle=True,
         process_group=(
-            get_draft_dp_group() if args.attention_backend == "usp" else get_dp_group()
+            get_draft_dp_group()
+            if args.attention_backend == "usp"
+            else get_dp_group()
         ),
         is_vlm=args.is_vlm,
     )
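In sanity_check, both the usp and the new usp_fa attention backends require train_hidden_states_path to be set, while build_dataloaders selects the draft data-parallel group only when the backend is exactly "usp".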

specforge/layers/ring/__init__.py

Lines changed: 7 additions & 0 deletions
# adapted from https://github.com/feifeibear/long-context-attention/tree/main/yunchang
from .ring_flash_attn import (
    ring_flash_attn_func,
    ring_flash_attn_kvpacked_func,
    ring_flash_attn_qkvpacked_func,
)
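The ring_flash_attn module below imports RingComm and update_out_and_lse from a sibling .utils module; that helper is presumably among the other files in this commit that are not shown on this page.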
specforge/layers/ring/ring_flash_attn.py

Lines changed: 317 additions & 0 deletions
import torch

from .utils import RingComm, update_out_and_lse
from yunchang.kernels import select_flash_attn_impl, AttnType


def ring_flash_attn_forward(
    process_group,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale,
    dropout_p=0,
    causal=True,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    attn_type: AttnType = AttnType.FA,
    attn_processor=None,
):
    comm = RingComm(process_group)

    out = None
    lse = None

    next_k, next_v = None, None

    # Ring schedule: overlap the async exchange of the next K/V block with
    # attention on the block currently held by this rank.
    for step in range(comm.world_size):
        if step + 1 != comm.world_size:
            next_k: torch.Tensor = comm.send_recv(k)
            next_v: torch.Tensor = comm.send_recv(v)
            comm.commit()

        # Under causal masking, K/V blocks originating from later ranks are
        # fully masked out and can be skipped.
        if not causal or step <= comm.rank:
            fn = select_flash_attn_impl(
                attn_type, stage="fwd-only", attn_processor=attn_processor
            )
            block_out, block_lse = fn(
                q,
                k,
                v,
                dropout_p=dropout_p,
                softmax_scale=softmax_scale,
                causal=causal and step == 0,  # only the local block needs the causal mask
                window_size=window_size,
                softcap=softcap,
                alibi_slopes=alibi_slopes,
                return_softmax=True and dropout_p > 0,
            )
            if attn_type == AttnType.SPARSE_SAGE:
                out, lse = block_out, block_lse
            else:
                # Merge this block's partial output into the running result
                # via the log-sum-exp correction.
                out, lse = update_out_and_lse(out, lse, block_out, block_lse)

        if step + 1 != comm.world_size:
            comm.wait()
            k = next_k
            v = next_v

    out = out.to(q.dtype)
    if attn_type != AttnType.SPARSE_SAGE:
        lse = lse.squeeze(dim=-1).transpose(1, 2)
    return out, lse


def ring_flash_attn_backward(
    process_group,
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    softmax_scale,
    dropout_p=0,
    causal=True,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    attn_type: AttnType = AttnType.FA,
):
    # kv_comm rotates K/V around the ring; d_kv_comm rotates the partial
    # dK/dV along with them so that, after the final exchange, each rank
    # holds the gradients for its own K/V shard.
    kv_comm = RingComm(process_group)
    d_kv_comm = RingComm(process_group)
    dq, dk, dv = None, None, None
    next_dk, next_dv = None, None

    block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device)
    block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device)
    block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device)

    next_dk, next_dv = None, None
    next_k, next_v = None, None

    for step in range(kv_comm.world_size):
        if step + 1 != kv_comm.world_size:
            next_k = kv_comm.send_recv(k)
            next_v = kv_comm.send_recv(v)
            kv_comm.commit()
        if step <= kv_comm.rank or not causal:
            bwd_causal = causal and step == 0
            fn = select_flash_attn_impl(attn_type, stage="bwd-only")
            fn(
                dout,
                q,
                k,
                v,
                out,
                softmax_lse,
                block_dq_buffer,
                block_dk_buffer,
                block_dv_buffer,
                dropout_p,
                softmax_scale,
                bwd_causal,
                window_size,
                softcap,
                alibi_slopes,
                deterministic,
                rng_state=None,
            )

            if dq is None:
                dq = block_dq_buffer.to(torch.float32)
                dk = block_dk_buffer.to(torch.float32)
                dv = block_dv_buffer.to(torch.float32)
            else:
                dq += block_dq_buffer
                d_kv_comm.wait()
                dk = block_dk_buffer + next_dk
                dv = block_dv_buffer + next_dv
        elif step != 0:
            d_kv_comm.wait()
            dk = next_dk
            dv = next_dv

        if step + 1 != kv_comm.world_size:
            kv_comm.wait()
            k = next_k
            v = next_v

        next_dk = d_kv_comm.send_recv(dk)
        next_dv = d_kv_comm.send_recv(dv)
        d_kv_comm.commit()

    d_kv_comm.wait()

    return dq.to(torch.bfloat16), next_dk.to(q.dtype), next_dv.to(q.dtype)


class RingFlashAttnFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        group,
        attn_type,
        attn_processor,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        assert alibi_slopes is None
        k = k.contiguous()
        v = v.contiguous()
        out, softmax_lse = ring_flash_attn_forward(
            group,
            q,
            k,
            v,
            softmax_scale=softmax_scale,
            dropout_p=dropout_p,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            deterministic=False,
            attn_type=attn_type,
            attn_processor=attn_processor,
        )
        # this should be out_padded
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        ctx.group = group
        ctx.attn_type = attn_type
        ctx.attn_processor = attn_processor
        return out if not return_softmax else (out, softmax_lse, None)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        dq, dk, dv = ring_flash_attn_backward(
            ctx.group,
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            softmax_scale=ctx.softmax_scale,
            dropout_p=ctx.dropout_p,
            causal=ctx.causal,
            window_size=ctx.window_size,
            softcap=ctx.softcap,
            alibi_slopes=ctx.alibi_slopes,
            deterministic=ctx.deterministic,
            attn_type=ctx.attn_type,
        )
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None


def ring_flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
    attn_type: AttnType = AttnType.FA,
):
    return RingFlashAttnFunc.apply(
        qkv[:, :, 0],
        qkv[:, :, 1],
        qkv[:, :, 2],
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        group,
        attn_type,
        None,  # attn_processor
    )


def ring_flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
    attn_type: AttnType = AttnType.FA,
):
    return RingFlashAttnFunc.apply(
        q,
        kv[:, :, 0],
        kv[:, :, 1],
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        group,
        attn_type,
        None,  # attn_processor
    )


def ring_flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
    attn_type: AttnType = AttnType.FA,
    attn_processor=None,
):
    return RingFlashAttnFunc.apply(
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        group,
        attn_type,
        attn_processor,
    )
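
For reference, here is a minimal usage sketch of the exported API. It is not part of the commit: the launch setup, shapes, and dtype are illustrative assumptions, and it presumes flash-attn and yunchang are installed (the default attn_type is AttnType.FA). Each rank holds one contiguous shard of the sequence in flash-attn layout, and the K/V blocks rotate around the ring inside the call.

    # Usage sketch (assumptions: one GPU per rank, flash-attn + yunchang installed,
    # launched with `torchrun --nproc_per_node=<N> demo.py`).
    import torch
    import torch.distributed as dist

    from specforge.layers.ring import ring_flash_attn_func

    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    # Each rank owns one contiguous sequence shard in flash-attn layout:
    # (batch, local_seq_len, num_heads, head_dim).
    batch, local_seq, heads, head_dim = 1, 1024, 8, 128
    q = torch.randn(batch, local_seq, heads, head_dim,
                    device="cuda", dtype=torch.bfloat16, requires_grad=True)
    k = torch.randn_like(q, requires_grad=True)
    v = torch.randn_like(q, requires_grad=True)

    # Causal ring attention over the default (world) process group; K/V blocks
    # rotate around the ring while each rank keeps its own Q shard.
    out = ring_flash_attn_func(q, k, v, causal=True, group=dist.group.WORLD)
    out.sum().backward()

    dist.destroy_process_group()

The packed variants take the same keyword arguments but expect stacked inputs instead: ring_flash_attn_qkvpacked_func consumes qkv of shape (batch, local_seq, 3, num_heads, head_dim), and ring_flash_attn_kvpacked_func consumes q plus kv of shape (batch, local_seq, 2, num_heads, head_dim).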
