
Commit e515403

uygnef and timmy-feng authored
[Feature] USP: Replace SDPA with Flash Attention for memory optimization & Add Online Mode (#425)
* added flash_attn backend
* fix pre commit
* add usp for flash attn
* online mode
* clean up
* modify test case

---------

Co-authored-by: timmy-feng <timothy@modal.com>
1 parent 3e34e19 commit e515403
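For context on the backend swap named in the title: the change routes attention through the flash-attn kernels instead of torch.nn.functional.scaled_dot_product_attention. The snippet below is only an illustration of that swap, not code from this commit; the shapes B, H, S, D are made up, and the two calls differ mainly in tensor layout ([B, H, S, D] for SDPA vs. [B, S, H, D] for flash_attn_func).

import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func  # requires the flash-attn package and a CUDA GPU

# Illustrative shapes only
B, H, S, D = 2, 8, 1024, 64
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

# SDPA takes [batch, heads, seq, head_dim]
out_sdpa = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# flash_attn_func takes [batch, seq, heads, head_dim], so transpose on the way in and out
out_fa = flash_attn_func(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=True
).transpose(1, 2)

# The two agree up to bf16 numerics
torch.testing.assert_close(out_sdpa, out_fa, atol=2e-2, rtol=2e-2)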

File tree

7 files changed: +912 -467 lines


requirements.txt

Whitespace-only changes.

scripts/train_eagle3.py

Lines changed: 57 additions & 12 deletions
@@ -35,6 +35,7 @@
     destroy_distributed,
     get_dp_group,
     get_draft_dp_group,
+    get_draft_sp_group,
     get_tp_group,
     init_distributed,
 )
@@ -335,10 +336,6 @@ def sanity_check(args: Namespace) -> None:
     args.draft_accumulation_steps = (
         args.draft_accumulation_steps * args.sp_ulysses_size * args.sp_ring_size
     )
-    if args.attention_backend == "usp":
-        assert (
-            args.train_hidden_states_path is not None
-        ), "train_hidden_states_path should not be None for usp"


 def build_draft_model(args: Namespace) -> Tuple[AutoDraftModelConfig, nn.Module]:
@@ -410,6 +407,9 @@ def build_dataloaders(
     )
     cache_key = hashlib.md5(cache_params_string.encode()).hexdigest()
     train_dataset = load_dataset("json", data_files=args.train_data_path)["train"]
+    is_online = (
+        args.train_data_path is not None and args.train_hidden_states_path is None
+    )
     with rank_0_priority():
         train_eagle3_dataset = build_eagle3_dataset(
             dataset=train_dataset,
@@ -431,7 +431,7 @@
             cache_key=cache_key,
         )

-    if args.train_hidden_states_path is not None:
+    if not is_online:
         train_eagle3_dataset = build_offline_eagle3_dataset(
             args.train_hidden_states_path,
             args.max_length,
@@ -443,7 +443,9 @@
         num_workers=args.dataloader_num_workers,
         shuffle=True,
         process_group=(
-            get_draft_dp_group() if args.attention_backend == "usp" else get_dp_group()
+            get_draft_dp_group()
+            if args.attention_backend == "usp" and not is_online
+            else get_dp_group()
         ),
         is_vlm=args.is_vlm,
     )
@@ -473,7 +475,7 @@
         shuffle=False,
         process_group=(
             get_draft_dp_group()
-            if args.attention_backend == "usp"
+            if args.attention_backend == "usp" and not is_online
             else get_dp_group()
         ),
         is_vlm=args.is_vlm,
@@ -630,13 +632,56 @@ def record_metrcs(
     tracker.log(logdict, step=global_step)


-def get_dp_data_shard_from_tp(tensor: torch.Tensor) -> torch.Tensor:
+def get_dp_data_shard_from_tp(tensor: torch.Tensor, sp_dim: int = 1) -> torch.Tensor:
     """
-    Get the data shard from the tensor.
+    Process: TP split -> Pad to Max Len -> SP gather.
     """
-    tp_size = dist.get_world_size(get_tp_group())
-    tp_rank = dist.get_rank(get_tp_group())
-    return tensor.chunk(tp_size, dim=0)[tp_rank]
+    # 1. TP: slice the tensor along the batch dimension
+    tp_group = get_tp_group()
+    tp_size = dist.get_world_size(tp_group)
+    tp_rank = dist.get_rank(tp_group)
+
+    local_tp_shard = tensor.chunk(tp_size, dim=0)[tp_rank]
+
+    # 2. SP: handle dynamic sequence lengths and gather
+    sp_group = get_draft_sp_group()
+
+    if sp_group is not None and dist.get_world_size(sp_group) > 1:
+        sp_world_size = dist.get_world_size(sp_group)
+        local_seq_len = local_tp_shard.size(sp_dim)
+
+        # Find the global max sequence length in the SP group
+        len_tensor = torch.tensor(
+            [local_seq_len], device=local_tp_shard.device, dtype=torch.long
+        )
+        dist.all_reduce(len_tensor, op=dist.ReduceOp.MAX, group=sp_group)
+        max_seq_len = len_tensor.item()
+
+        # Pad the local tensor if necessary
+        # Shape is [Batch, Seq, Hidden] or [Batch, Seq], and sp_dim=1
+        if local_seq_len < max_seq_len:
+            pad_size = max_seq_len - local_seq_len
+
+            pad_config = [0] * (local_tp_shard.ndim * 2)
+
+            pad_idx = (local_tp_shard.ndim - 1 - sp_dim) * 2 + 1
+            pad_config[pad_idx] = pad_size
+
+            # Pad value: 0 is standard; ensure it matches your pad_token_id logic if needed
+            local_tp_shard_padded = nn.functional.pad(local_tp_shard, pad_config, value=0)
+        else:
+            local_tp_shard_padded = local_tp_shard
+
+        gathered_shards = [
+            torch.empty_like(local_tp_shard_padded) for _ in range(sp_world_size)
+        ]
+        dist.all_gather(
+            gathered_shards, local_tp_shard_padded.contiguous(), group=sp_group
+        )

+        return torch.cat(gathered_shards, dim=sp_dim)
+
+    return local_tp_shard


 def main():
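As a quick standalone sanity check of the pad_config index arithmetic in the new get_dp_data_shard_from_tp (F.pad orders its (left, right) pad pairs from the last dimension backwards, so the right-pad slot for sp_dim is (ndim - 1 - sp_dim) * 2 + 1), here is a toy example with made-up shapes and no distributed setup:

import torch
import torch.nn.functional as F

x = torch.randn(2, 5, 8)            # [Batch, Seq, Hidden] with a local seq len of 5
sp_dim, max_seq_len = 1, 7          # pretend another SP rank holds a length-7 sequence
pad_size = max_seq_len - x.size(sp_dim)

pad_config = [0] * (x.ndim * 2)
pad_config[(x.ndim - 1 - sp_dim) * 2 + 1] = pad_size   # right-pad slot for dim 1 -> [0, 0, 0, 2, 0, 0]
padded = F.pad(x, pad_config, value=0)

assert padded.shape == (2, 7, 8)                            # sequence dim grew to the group max
assert torch.equal(padded[:, 5:, :], torch.zeros(2, 2, 8))  # the tail is zero padding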

specforge/layers/ring/__init__.py

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+# adapted from https://github.com/feifeibear/long-context-attention/tree/main/yunchang
+from .ring_flash_attn import (
+    ring_flash_attn_func,
+    ring_flash_attn_kvpacked_func,
+    ring_flash_attn_qkvpacked_func,
+)
+
+__all__ = [
+    "ring_flash_attn_func",
+    "ring_flash_attn_kvpacked_func",
+    "ring_flash_attn_qkvpacked_func",
+]
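The ring_flash_attn_* functions re-exported here (adapted from yunchang) compute exact attention while each sequence-parallel rank holds only one block of K/V at a time, passing blocks around the ring and merging partial results with the usual online-softmax correction. The single-process toy below sketches that merging rule on chunked K/V; it is only an illustration of the math, not the repo's kernel, and every name and shape in it is made up.

import torch

def blockwise_attention(q, k, v, block_size):
    """Exact (non-causal, single-head) attention computed one KV block at a time."""
    scale = q.shape[-1] ** -0.5
    m = torch.full((q.shape[0], 1), float("-inf"))   # running row-wise max of the logits
    l = torch.zeros(q.shape[0], 1)                   # running softmax denominator
    acc = torch.zeros_like(q)                        # running unnormalized output
    for k_blk, v_blk in zip(k.split(block_size), v.split(block_size)):
        s = (q @ k_blk.T) * scale
        m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
        correction = torch.exp(m - m_new)            # rescale earlier blocks to the new max
        p = torch.exp(s - m_new)
        l = l * correction + p.sum(dim=-1, keepdim=True)
        acc = acc * correction + p @ v_blk
        m = m_new
    return acc / l

q, k, v = (torch.randn(16, 8) for _ in range(3))
ref = torch.softmax((q @ k.T) * q.shape[-1] ** -0.5, dim=-1) @ v
torch.testing.assert_close(blockwise_attention(q, k, v, block_size=4), ref, atol=1e-5, rtol=1e-5)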
