
Commit 5185d1e

feat(dflash): add flex_attention with BlockMask support
- Add attention_backend parameter to OnlineDFlashModel
- Implement BlockMask creation for flex_attention optimization
- Default to flex_attention backend (~30% speedup)
- Add iter_time display in training progress bar
- Clean up comments and simplify code
1 parent ee29561 commit 5185d1e
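Background for the backend named in the message: PyTorch's flex_attention takes a BlockMask built from a boolean mask function and can skip fully-masked tiles, which is where the cited speedup comes from. A minimal, self-contained sketch of that pattern using the same mask rule this commit adds — the shapes, device, and dtype are illustrative assumptions, not the model's actual call site, and it needs a PyTorch build that ships torch.nn.attention.flex_attention (2.5+) plus a CUDA device:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

B, H, L, D, block_size = 1, 4, 64, 32, 16  # illustrative sizes only

def dflash_mask_fn(b, h, q_idx, kv_idx):
    # Keys are [K_ctx, K_noise]: a query sees the context of strictly earlier
    # blocks plus the noise tokens of its own block.
    is_ctx = kv_idx < L
    ctx_visible = is_ctx & (kv_idx // block_size < q_idx // block_size)
    noise_visible = (~is_ctx) & ((kv_idx - L) // block_size == q_idx // block_size)
    return ctx_visible | noise_visible

block_mask = create_block_mask(dflash_mask_fn, B=B, H=H, Q_LEN=L, KV_LEN=2 * L, device="cuda")
q = torch.randn(B, H, L, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, H, 2 * L, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, H, 2 * L, D, device="cuda", dtype=torch.bfloat16)
out = flex_attention(q, k, v, block_mask=block_mask)  # [B, H, L, D]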

File tree

3 files changed: +130 −135 lines changed
- examples/run_qwen3_8b_dflash_online.sh
- scripts/train_dflash.py
- specforge/core/dflash.py

examples/run_qwen3_8b_dflash_online.sh

Lines changed: 3 additions & 0 deletions
@@ -6,6 +6,8 @@ export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels
 export SPECFORGE_DATA_NUM_PROC=32
 NUM_GPUS=${1:-1}

+ATTENTION_BACKEND=${2:-flex_attention}
+
 torchrun \
     --standalone \
     --nproc_per_node $NUM_GPUS \
@@ -19,6 +21,7 @@ torchrun \
     --learning-rate 1e-4 \
     --max-length 2048 \
     --chat-template qwen \
+    --attention-backend $ATTENTION_BACKEND \
     --log-interval 50 \
     --save-interval 1000 \
     --report-to wandb \
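With this change the backend becomes the script's second positional argument: for example, bash examples/run_qwen3_8b_dflash_online.sh 8 sdpa trains on 8 GPUs with the SDPA backend, while omitting the argument keeps the flex_attention default.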

scripts/train_dflash.py

Lines changed: 13 additions & 0 deletions
@@ -57,6 +57,13 @@ def parse_args():
         default=None,
         help="MASK token ID. If not provided, auto-detect from tokenizer.",
     )
+    model_group.add_argument(
+        "--attention-backend",
+        type=str,
+        default="flex_attention",
+        choices=["eager", "sdpa", "flex_attention"],
+        help="Attention backend for draft model.",
+    )

     dataset_group = parser.add_argument_group("dataset")
     dataset_group.add_argument("--train-data-path", type=str, required=True)
@@ -133,6 +140,10 @@ def build_models(args) -> Tuple[DFlashTargetModel, DFlashDraftModel]:
         draft_config.num_target_layers = target_config.num_hidden_layers
         print_on_rank0("Auto-generated draft config from target model")

+    # Set attention implementation based on backend
+    draft_config._attn_implementation = args.attention_backend
+    print_on_rank0(f"Using attention backend: {args.attention_backend}")
+
     draft_model = DFlashDraftModel(draft_config).cuda().to(torch.bfloat16)

     # Set capture layers for target model based on draft model config
@@ -344,6 +355,7 @@ def main():
         target_embed_tokens=target_components.embed_tokens,
         block_size=draft_model.block_size,
         mask_token_id=mask_token_id,
+        attention_backend=args.attention_backend,
     )

     dflash_model = FSDP(
@@ -436,6 +448,7 @@ def main():
                 {
                     "loss": f"{loss.item():.4f}",
                     "acc": f"{accuracy.item():.4f}",
+                    "iter_time": f"{elapsed:.2f}s",
                 }
             )
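The script sets draft_config._attn_implementation directly from the flag and relies on the guarded import in specforge/core/dflash.py (below) to handle older PyTorch builds. A minimal sketch of doing that fallback at config time instead — resolve_attention_backend is a hypothetical helper, not part of this commit:

def resolve_attention_backend(requested: str) -> str:
    # Hypothetical helper (not in this commit): fall back to SDPA when the
    # installed torch does not provide flex_attention.
    if requested == "flex_attention":
        try:
            from torch.nn.attention.flex_attention import flex_attention  # noqa: F401
        except ImportError:
            return "sdpa"
    return requested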

specforge/core/dflash.py

Lines changed: 114 additions & 135 deletions
@@ -1,14 +1,23 @@
 # coding=utf-8
 """DFlash Training Wrapper."""

-from typing import Tuple
+from typing import Optional, Tuple

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

 from specforge.modeling.draft.dflash import DFlashDraftModel

+try:
+    from torch.nn.attention.flex_attention import BlockMask, create_block_mask
+
+    FLEX_ATTENTION_AVAILABLE = True
+except ImportError:
+    FLEX_ATTENTION_AVAILABLE = False
+    BlockMask = None
+    create_block_mask = None
+

 class OnlineDFlashModel(nn.Module):
     """DFlash online training wrapper with block-wise CE loss."""
@@ -20,17 +29,24 @@ def __init__(
         target_embed_tokens: nn.Module,
         mask_token_id: int,
         block_size: int = 16,
+        attention_backend: str = "flex_attention",
     ):
         super().__init__()
         self.draft_model = draft_model
         self.lm_head = target_lm_head
         self.embed_tokens = target_embed_tokens
         self.block_size = block_size
         self.mask_token_id = mask_token_id
+        self.attention_backend = attention_backend
+
+        # Cache for BlockMask
+        self._cached_block_mask: Optional[BlockMask] = None
+        self._cached_seq_len: Optional[int] = None
+        self._cached_bsz: Optional[int] = None
+        self._cached_num_heads: Optional[int] = None

     def prepare_noise_input(self, input_ids: torch.Tensor) -> torch.Tensor:
         """Prepare noise input: first token of each block is real, rest are MASK."""
-        # input_ids: [bsz, seq_len]
         seq_len = input_ids.shape[1]
         device = input_ids.device

@@ -42,88 +58,139 @@ def prepare_noise_input(self, input_ids: torch.Tensor) -> torch.Tensor:

         return noise_input_ids

+    def _get_or_create_block_mask(
+        self, bsz: int, num_heads: int, q_len: int, kv_len: int, device: torch.device
+    ) -> "BlockMask":
+        """Get cached BlockMask or create a new one."""
+        if (
+            self._cached_block_mask is not None
+            and self._cached_seq_len == q_len
+            and self._cached_bsz == bsz
+            and self._cached_num_heads == num_heads
+        ):
+            return self._cached_block_mask
+
+        block_size = self.block_size
+
+        def dflash_mask_fn(b, h, q_idx, kv_idx):
+            L = q_len
+            is_ctx = kv_idx < L
+            q_block = q_idx // block_size
+            k_block_ctx = kv_idx // block_size
+            k_block_noise = (kv_idx - L) // block_size
+            ctx_visible = is_ctx & (k_block_ctx < q_block)
+            noise_visible = (~is_ctx) & (k_block_noise == q_block)
+            return ctx_visible | noise_visible
+
+        block_mask = create_block_mask(
+            dflash_mask_fn,
+            B=bsz,
+            H=num_heads,
+            Q_LEN=q_len,
+            KV_LEN=kv_len,
+            device=device,
+        )
+
+        self._cached_block_mask = block_mask
+        self._cached_seq_len = q_len
+        self._cached_bsz = bsz
+        self._cached_num_heads = num_heads
+
+        return block_mask
+
+    def _create_parallel_attention_mask(
+        self, seq_len: int, device: torch.device
+    ) -> torch.Tensor:
+        """
+        Create [L, 2L] attention mask for parallel training.
+        - Left half (ctx): Q can see K_ctx if K's block < Q's block
+        - Right half (noise): Q can see K_noise if same block (bidirectional)
+        """
+        indices = torch.arange(seq_len, device=device)
+        block_ids = indices // self.block_size
+
+        q_block_ids = block_ids.unsqueeze(1)
+        k_block_ids = block_ids.unsqueeze(0)
+
+        ctx_mask = k_block_ids < q_block_ids
+        noise_mask = q_block_ids == k_block_ids
+
+        full_mask_bool = torch.cat([ctx_mask, noise_mask], dim=1)
+        full_mask = torch.zeros_like(full_mask_bool, dtype=torch.float32)
+        full_mask.masked_fill_(~full_mask_bool, torch.finfo(torch.float32).min)
+
+        return full_mask
+
     def forward(
         self,
         input_ids: torch.Tensor,
         attention_mask: torch.Tensor,
         hidden_states: torch.Tensor,
         loss_mask: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Parallel Block-wise training forward pass.
-
-        Uses a global attention mask to process all blocks in parallel without
-        modifying the underlying modeling code.
-        """
+        """Parallel block-wise training forward pass."""
         bsz, seq_len = input_ids.shape
         device = input_ids.device

         # Truncate to multiple of block_size
         n_blocks = seq_len // self.block_size
         effective_len = n_blocks * self.block_size
         input_ids = input_ids[:, :effective_len]
-        # hidden_states here is the RAW target hidden states (before projection)
         hidden_states = hidden_states[:, :effective_len, :]
         loss_mask = loss_mask[:, :effective_len]
-        # Original attention mask is typically just 1s for valid tokens
         attention_mask = attention_mask[:, :effective_len]

-        # 2. Prepare Inputs
+        # Prepare inputs
         noise_input_ids = self.prepare_noise_input(input_ids)
         noise_embedding = self.embed_tokens(noise_input_ids)

-        # 3. Construct Parallel Training Position IDs
-        # We need Position IDs for K which has length 2*L (Context + Noise)
-        # Context part: 0..L-1
-        # Noise part: 0..L-1
-        # This ensures that Noise at pos i uses the same RoPE embedding as Context at pos i
+        # Position IDs: [ctx_pos, noise_pos] both 0..L-1
         pos_seq = torch.arange(effective_len, device=device)
-        # shape: [1, 2*L] -> [bsz, 2*L]
         position_ids = torch.cat([pos_seq, pos_seq], dim=0).unsqueeze(0).expand(bsz, -1)

-        # 4. Construct Parallel Attention Mask
-        # The modeling code will internally concat K = [K_ctx, K_noise]
-        # So K has length 2*L. Q has length L (from Noise).
-        # We need a mask of shape [L, 2*L]
-        dflash_attn_mask = self._create_parallel_attention_mask(effective_len, device)
-        dflash_attn_mask = dflash_attn_mask.to(dtype=hidden_states.dtype)
-        # Expand to batch size: [bsz, 1, L, 2*L]
-        # Note: transformers usually expects [bsz, 1, Q, K]
-        dflash_attn_mask = (
-            dflash_attn_mask.unsqueeze(0).unsqueeze(0).expand(bsz, -1, -1, -1)
-        )
-
-        # 5. Parallel Forward Pass
-        # efficient: one single forward pass for the whole sequence
+        # Construct attention mask
+        if (
+            self.attention_backend == "flex_attention"
+            and FLEX_ATTENTION_AVAILABLE
+            and create_block_mask is not None
+        ):
+            num_heads = self.draft_model.config.num_attention_heads
+            dflash_attn_mask = self._get_or_create_block_mask(
+                bsz=bsz,
+                num_heads=num_heads,
+                q_len=effective_len,
+                kv_len=effective_len * 2,
+                device=device,
+            )
+        else:
+            dflash_attn_mask = self._create_parallel_attention_mask(
+                effective_len, device
+            )
+            dflash_attn_mask = dflash_attn_mask.to(dtype=hidden_states.dtype)
+            dflash_attn_mask = (
+                dflash_attn_mask.unsqueeze(0).unsqueeze(0).expand(bsz, -1, -1, -1)
+            )
+
+        # Forward pass
         hidden = self.draft_model(
-            position_ids=position_ids,  # [bsz, 2*L] (used for RoPE)
-            noise_embedding=noise_embedding,  # [bsz, L, H] (Query source)
-            target_hidden=hidden_states,  # [bsz, L, H] (Context source)
-            attention_mask=dflash_attn_mask,  # [bsz, 1, L, 2*L]
+            position_ids=position_ids,
+            noise_embedding=noise_embedding,
+            target_hidden=hidden_states,
+            attention_mask=dflash_attn_mask,
         )

-        # 6. Compute Loss
-        # Create mask for valid loss positions (skip block 0, skip block starts)
+        # Compute loss (skip block 0 and block starts)
         dflash_loss_mask_base = create_dflash_loss_mask(
             effective_len, self.block_size, device
         )
         combined_mask = loss_mask * dflash_loss_mask_base.unsqueeze(0)

-        # hidden[i] predicts input_ids[i] (based on DFlash design where noise[i] is input to predict target[i])
-        # However, check design:
-        # "hidden[i] predicts token[i] (directly corresponding), not token[i+1]"
-        # "noise_input[i] is MASK, we want to predict input_ids[i]"
-        # So logits at index i should be compared to labels at index i.
-
         logits = self.lm_head(hidden)

-        # Calculate Loss
-        # Flatten
         logits_flat = logits.reshape(-1, logits.size(-1))
         labels_flat = input_ids.reshape(-1)
         mask_flat = combined_mask.reshape(-1)

-        # Optimization: only compute CE on valid tokens
         active_indices = mask_flat > 0.5
         active_logits = logits_flat[active_indices]
         active_labels = labels_flat[active_indices]
@@ -138,99 +205,11 @@ def forward(

         return loss, accuracy

-    def _create_parallel_attention_mask(
-        self, seq_len: int, device: torch.device
-    ) -> torch.Tensor:
-        """
-        Creates the [L, 2L] mask for parallel training.
-        Rows: Query (Noise) indices 0..L-1
-        Cols: Key indices 0..2L-1 (First L are Context, Next L are Noise)
-
-        Logic for Query at index i (belonging to block B = i // block_size):
-        1. Can see Context (Cols 0..L-1):
-           - Can see context of PREVIOUS blocks.
-           - Range: [0, B * block_size)
-        2. Can see Noise (Cols L..2L-1):
-           - Can see noise of CURRENT block up to self.
-           - Range: [L + B * block_size, L + i]
-           - (Inclusive of i because causal mask usually allows seeing self)
-        """
-        # Block indices for each position [0, 0, ..., 1, 1, ...]
-        indices = torch.arange(seq_len, device=device)
-        block_ids = indices // self.block_size
-
-        # 1. Context Mask (L x L) - Left half of K
-        # Q[i] can see K_ctx[j] if Block(Q[i]) > Block(K_ctx[j])
-        # Actually, Block(Q[i]) can see Context of all previous blocks.
-        # It implies Block(K_ctx[j]) < Block(Q[i])
-        # Wait, strictly: Block B needs context from 0..(B*16).
-        # So it sees all K_ctx where index < B * 16.
-        # Which is equivalent to block_ids[j] < block_ids[i].
-
-        # Broadcast logic
-        # shape [L, 1]
-        q_block_ids = block_ids.unsqueeze(1)
-        # shape [1, L]
-        k_block_ids = block_ids.unsqueeze(0)
-
-        # Mask: 1 if K's block is strictly less than Q's block
-        # This gives access to all PREVIOUS blocks' context.
-        ctx_mask = k_block_ids < q_block_ids
-
-        # 2. Noise Mask (L x L) - Right half of K
-        # Standard Causal Mask WITHIN the same block.
-        # Q[i] can see K_noise[j] if:
-        # a) Same Block: Block(Q[i]) == Block(K_noise[j])
-        # b) Causal: j <= i
-        # Different blocks cannot see each other's noise.
-
-        same_block = q_block_ids == k_block_ids
-
-        noise_mask = same_block
-
-        # Combine [Ctx_Mask, Noise_Mask]
-        # Shape [L, 2L]
-        # We need float mask for attention: 0.0 for allow, -inf for mask
-        # Transformers usually handles boolean masks by converting them,
-        # but explicit MinValue is safer if passing to generic attention.
-        # However, most HF models accept boolean [batch, 1, Q, K].
-
-        full_mask_bool = torch.cat([ctx_mask, noise_mask], dim=1)
-
-        # Check standard HF format: usually 1 for keep, 0 for mask, or boolean.
-        # Qwen3 implementation likely uses typical attention_mask handling.
-        # Let's return boolean for now, the model wrapper usually handles conversion
-        # or we check Qwen3DFlashAttention source usage.
-        # Looking at Qwen3DFlashAttention: `attn_output = attn_fn(..., attention_mask, ...)`
-        # If using SDPA, it expects boolean or float mask.
-        # If we look at `modeling_qwen3.py` (standard), it usually employs `_prepare_4d_causal_attention_mask`.
-        # But here we pass it explicitly.
-        # To be safe with `eager_attention_forward` and `SDPA`, we typically want:
-        # 0.0 for unmasked, -inf for masked.
-
-        dtype = (
-            torch.bfloat16
-        )  # or get from device/model, but we return a tensor builder
-        # We will cast later or return boolean and let logic handle it?
-        # Safe bet: Return typical extended attention mask format: 0.0 for keep, min_dtype for remove.
-
-        # But wait, Qwen3DFlashAttention passes attention_mask directly to attn_fn.
-        # If attn_fn is SDPA, it handles boolean.
-        # Let's return a float mask: 0.0 for True, -inf for False.
-
-        full_mask = torch.zeros_like(full_mask_bool, dtype=torch.float32)
-        full_mask.masked_fill_(~full_mask_bool, torch.finfo(torch.float32).min)
-
-        return full_mask
-

 def create_dflash_loss_mask(
     seq_len: int, block_size: int, device: torch.device
 ) -> torch.Tensor:
-    """
-    Create DFlash-specific loss mask.
-    Excludes Block 0 and first position of each block.
-    """
+    """Create DFlash loss mask: excludes block 0 and first position of each block."""
     positions = torch.arange(seq_len, device=device)
     block_ids = positions // block_size
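Both mask paths above encode the same visibility rule: a query attends to the context of strictly earlier blocks plus the noise tokens of its own block. A standalone sanity check, written for this commentary rather than taken from the commit, that evaluates the mask_fn arithmetic on an index grid and compares it with the dense [L, 2L] boolean mask:

import torch

L, block_size = 32, 16
blk = torch.arange(L) // block_size
ctx = blk[None, :] < blk[:, None]        # context of strictly earlier blocks
noise = blk[None, :] == blk[:, None]     # noise within the same block
dense = torch.cat([ctx, noise], dim=1)   # [L, 2L] boolean visibility

q = torch.arange(L)[:, None]             # query indices 0..L-1
kv = torch.arange(2 * L)[None, :]        # key indices 0..2L-1 (ctx then noise)
is_ctx = kv < L
from_mask_fn = (is_ctx & (kv // block_size < q // block_size)) | (
    (~is_ctx) & ((kv - L) // block_size == q // block_size)
)
assert torch.equal(dense, from_mask_fn)  # both backends see the same pattern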
