
[Feature] USP: Replace SDPA with Flash Attention for memory optimization & Add Online Mode#425

Merged
sleepcoo merged 6 commits into sgl-project:main from uygnef:tmp/sp
Jan 14, 2026
Conversation

@uygnef
Collaborator

@uygnef uygnef commented Jan 13, 2026

waiting for #400 to be merged

Motivation

To accelerate long-context training, this PR integrates Flash Attention into the Unified Sequence Parallelism (USP) framework. By replacing standard PyTorch operations with optimized kernels inside the Ring Attention loop, this implementation significantly improves memory efficiency for Eagle3 draft models.

Additionally, this PR adds support for online mode within the USP framework.

Modifications

  • Implemented LlamaUSPFlashAttention: Added a hybrid sequence parallel attention layer that combines Ulysses, Ring Attention, and Flash Attention (flash_attn_func).
  • Online mode for Sequence Parallelism: Enabled online-mode training within the USP framework.
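The key numerical trick behind combining Ring Attention with Flash Attention is merging per-block partial attention outputs with a running log-sum-exp, so no block ever needs the full softmax denominator. A minimal single-process numpy sketch of that merge is below; `flash_attn_func` itself is a fused GPU kernel, so plain matmuls stand in for it here, and all function names are illustrative rather than the PR's actual code:

```python
import numpy as np

def attn_reference(q, k, v):
    # Standard softmax attention for one head: q, k, v are (seq, dim).
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)
    return p @ v

def ring_attn_sketch(q, k_blocks, v_blocks):
    # Process K/V one block at a time (one block per ring step) and
    # merge partial outputs with a running max / running denominator,
    # the same online-softmax trick Flash and Ring Attention rely on.
    d = q.shape[-1]
    out = np.zeros_like(q)
    m = np.full((q.shape[0], 1), -np.inf)   # running row-wise max
    l = np.zeros((q.shape[0], 1))           # running softmax denominator
    for kb, vb in zip(k_blocks, v_blocks):
        s = q @ kb.T / np.sqrt(d)
        m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
        p = np.exp(s - m_new)
        scale = np.exp(m - m_new)           # rescale old accumulators
        l = l * scale + p.sum(axis=-1, keepdims=True)
        out = out * scale + p @ vb
        m = m_new
    return out / l

rng = np.random.default_rng(0)
q = rng.standard_normal((8, 16))
k = rng.standard_normal((8, 16))
v = rng.standard_normal((8, 16))
blockwise = ring_attn_sketch(q, np.split(k, 4), np.split(v, 4))
assert np.allclose(blockwise, attn_reference(q, k, v), atol=1e-6)
```

Because the merge is exact (not an approximation), the blockwise result matches full attention to floating-point tolerance, which is why the accuracy test below only needs to bound the bf16/fp16 kernel difference.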

Usage

torchrun \
    ...
    scripts/train_eagle3.py \
    ...
    --attention-backend usp_fa \
    --sp-ulysses-size $ULYSSES_SIZE \
    --sp-ring-size $RING_SIZE

Related Issues

Accuracy Test

python tests/test_layers/test_decoder.py

Compared to the SDPA implementation, the diff is less than 2e-2 for bf16 and less than 5e-3 for fp16.
(screenshots of accuracy test output)

Benchmark & Profiling

todo

TODO

  1. Optimal Loss Aggregation: Currently uses all_gather within the SP group. Will be optimized to local calculation + reduce_sum to save VRAM.
  2. Enhance Online Mode: Set draft micro-batch size.
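The VRAM saving in TODO 1 follows from summation commuting with concatenation: instead of every rank materializing the full gathered loss vector, each rank can reduce its local shard to a scalar and combine scalars with `reduce_sum`. A minimal sketch, simulating SP ranks in one process with no actual collectives (shard sizes and values are illustrative):

```python
import numpy as np

# Simulate an SP group of 4 ranks, each holding the per-token loss
# for its local sequence shard.
rng = np.random.default_rng(1)
local_losses = [rng.random(16) for _ in range(4)]  # one array per "rank"

# Current approach: all_gather the full loss vectors, then sum.
# Every rank holds a copy of the concatenated (4 * 16,) vector.
gathered = np.concatenate(local_losses)
loss_via_gather = gathered.sum()

# Planned approach: each rank sums locally, then only scalars are
# combined (a reduce_sum over one float per rank).
loss_via_reduce = sum(shard.sum() for shard in local_losses)

# The totals are identical; only the communicated/stored volume differs.
assert np.isclose(loss_via_gather, loss_via_reduce)
```

Per rank, this shrinks the gathered buffer from O(total sequence length) to O(1), which is where the VRAM saving comes from.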

Checklist


@sleepcoo sleepcoo marked this pull request as ready for review January 13, 2026 06:24

@uygnef uygnef force-pushed the tmp/sp branch 2 times, most recently from 11a2e13 to bf7f659 Compare January 13, 2026 12:13
@uygnef uygnef changed the title [feature] Sequence Parallelism: Replace SDPA with LlamaUSPFlashAttention for memory optimization [Feature] USP: Replace SDPA with Flash Attention for memory optimization & Add Online Mode Jan 13, 2026
@sleepcoo sleepcoo merged commit e515403 into sgl-project:main Jan 14, 2026
2 of 5 checks passed
@jiapingW
Collaborator

I tried training with seqlen=65536, ring=8, ulysses=2 and it costs 94G. With ring=4, ulysses=2 it also costs 94G. Is this normal?
