# coding=utf-8
"""DFlash Training Wrapper."""

-from typing import Tuple
+from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from specforge.modeling.draft.dflash import DFlashDraftModel

+try:
+    from torch.nn.attention.flex_attention import BlockMask, create_block_mask
+
+    FLEX_ATTENTION_AVAILABLE = True
+except ImportError:
+    FLEX_ATTENTION_AVAILABLE = False
+    BlockMask = None
+    create_block_mask = None
+

class OnlineDFlashModel(nn.Module):
    """DFlash online training wrapper with block-wise CE loss."""
@@ -20,17 +29,24 @@ def __init__(
        target_embed_tokens: nn.Module,
        mask_token_id: int,
        block_size: int = 16,
+        attention_backend: str = "flex_attention",
    ):
        super().__init__()
        self.draft_model = draft_model
        self.lm_head = target_lm_head
        self.embed_tokens = target_embed_tokens
        self.block_size = block_size
        self.mask_token_id = mask_token_id
+        self.attention_backend = attention_backend
+
+        # Cache for BlockMask
+        self._cached_block_mask: Optional[BlockMask] = None
+        self._cached_seq_len: Optional[int] = None
+        self._cached_bsz: Optional[int] = None
+        self._cached_num_heads: Optional[int] = None

    def prepare_noise_input(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Prepare noise input: first token of each block is real, rest are MASK."""
-        # input_ids: [bsz, seq_len]
        seq_len = input_ids.shape[1]
        device = input_ids.device

@@ -42,88 +58,139 @@ def prepare_noise_input(self, input_ids: torch.Tensor) -> torch.Tensor:

        return noise_input_ids

+    def _get_or_create_block_mask(
+        self, bsz: int, num_heads: int, q_len: int, kv_len: int, device: torch.device
+    ) -> "BlockMask":
+        """Get cached BlockMask or create a new one."""
+        if (
+            self._cached_block_mask is not None
+            and self._cached_seq_len == q_len
+            and self._cached_bsz == bsz
+            and self._cached_num_heads == num_heads
+        ):
+            return self._cached_block_mask
+
+        block_size = self.block_size
+
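+        # Keys are laid out as [context 0..L-1, noise L..2L-1]: a query in block b may
+        # attend to context tokens of strictly earlier blocks and to the noise tokens
+        # of its own block.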
+        def dflash_mask_fn(b, h, q_idx, kv_idx):
+            L = q_len
+            is_ctx = kv_idx < L
+            q_block = q_idx // block_size
+            k_block_ctx = kv_idx // block_size
+            k_block_noise = (kv_idx - L) // block_size
+            ctx_visible = is_ctx & (k_block_ctx < q_block)
+            noise_visible = (~is_ctx) & (k_block_noise == q_block)
+            return ctx_visible | noise_visible
+
+        block_mask = create_block_mask(
+            dflash_mask_fn,
+            B=bsz,
+            H=num_heads,
+            Q_LEN=q_len,
+            KV_LEN=kv_len,
+            device=device,
+        )
+
+        self._cached_block_mask = block_mask
+        self._cached_seq_len = q_len
+        self._cached_bsz = bsz
+        self._cached_num_heads = num_heads
+
+        return block_mask
+
+    def _create_parallel_attention_mask(
+        self, seq_len: int, device: torch.device
+    ) -> torch.Tensor:
+        """
+        Create [L, 2L] attention mask for parallel training.
+        - Left half (ctx): Q can see K_ctx if K's block < Q's block
+        - Right half (noise): Q can see K_noise if same block (bidirectional)
+        """
+        indices = torch.arange(seq_len, device=device)
+        block_ids = indices // self.block_size
+
+        q_block_ids = block_ids.unsqueeze(1)
+        k_block_ids = block_ids.unsqueeze(0)
+
+        ctx_mask = k_block_ids < q_block_ids
+        noise_mask = q_block_ids == k_block_ids
+
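+        # additive float mask: 0.0 where attention is allowed, dtype-min where blocked
+        # (cast to the model dtype in forward)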
+        full_mask_bool = torch.cat([ctx_mask, noise_mask], dim=1)
+        full_mask = torch.zeros_like(full_mask_bool, dtype=torch.float32)
+        full_mask.masked_fill_(~full_mask_bool, torch.finfo(torch.float32).min)
+
+        return full_mask
+
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        hidden_states: torch.Tensor,
        loss_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Parallel Block-wise training forward pass.
-
-        Uses a global attention mask to process all blocks in parallel without
-        modifying the underlying modeling code.
-        """
+        """Parallel block-wise training forward pass."""
        bsz, seq_len = input_ids.shape
        device = input_ids.device

        # Truncate to multiple of block_size
        n_blocks = seq_len // self.block_size
        effective_len = n_blocks * self.block_size
        input_ids = input_ids[:, :effective_len]
-        # hidden_states here is the RAW target hidden states (before projection)
        hidden_states = hidden_states[:, :effective_len, :]
        loss_mask = loss_mask[:, :effective_len]
-        # Original attention mask is typically just 1s for valid tokens
        attention_mask = attention_mask[:, :effective_len]

-        # 2. Prepare Inputs
+        # Prepare inputs
        noise_input_ids = self.prepare_noise_input(input_ids)
        noise_embedding = self.embed_tokens(noise_input_ids)

-        # 3. Construct Parallel Training Position IDs
-        # We need Position IDs for K which has length 2*L (Context + Noise)
-        # Context part: 0..L-1
-        # Noise part: 0..L-1
-        # This ensures that Noise at pos i uses the same RoPE embedding as Context at pos i
+        # Position IDs: [ctx_pos, noise_pos] both 0..L-1
        pos_seq = torch.arange(effective_len, device=device)
-        # shape: [1, 2*L] -> [bsz, 2*L]
        position_ids = torch.cat([pos_seq, pos_seq], dim=0).unsqueeze(0).expand(bsz, -1)
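        # noise position i therefore shares the same RoPE embedding as context position i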

-        # 4. Construct Parallel Attention Mask
-        # The modeling code will internally concat K = [K_ctx, K_noise]
-        # So K has length 2*L. Q has length L (from Noise).
-        # We need a mask of shape [L, 2*L]
-        dflash_attn_mask = self._create_parallel_attention_mask(effective_len, device)
-        dflash_attn_mask = dflash_attn_mask.to(dtype=hidden_states.dtype)
-        # Expand to batch size: [bsz, 1, L, 2*L]
-        # Note: transformers usually expects [bsz, 1, Q, K]
-        dflash_attn_mask = (
-            dflash_attn_mask.unsqueeze(0).unsqueeze(0).expand(bsz, -1, -1, -1)
-        )
-
-        # 5. Parallel Forward Pass
-        # efficient: one single forward pass for the whole sequence
+        # Construct attention mask
+        if (
+            self.attention_backend == "flex_attention"
+            and FLEX_ATTENTION_AVAILABLE
+            and create_block_mask is not None
+        ):
+            num_heads = self.draft_model.config.num_attention_heads
+            dflash_attn_mask = self._get_or_create_block_mask(
+                bsz=bsz,
+                num_heads=num_heads,
+                q_len=effective_len,
+                kv_len=effective_len * 2,
+                device=device,
+            )
+        else:
+            dflash_attn_mask = self._create_parallel_attention_mask(
+                effective_len, device
+            )
+            dflash_attn_mask = dflash_attn_mask.to(dtype=hidden_states.dtype)
+            dflash_attn_mask = (
+                dflash_attn_mask.unsqueeze(0).unsqueeze(0).expand(bsz, -1, -1, -1)
+            )
+
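+        # dflash_attn_mask is now either a flex-attention BlockMask or a dense
+        # additive mask of shape [bsz, 1, L, 2L]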
+        # Forward pass
        hidden = self.draft_model(
-            position_ids=position_ids,  # [bsz, 2*L] (used for RoPE)
-            noise_embedding=noise_embedding,  # [bsz, L, H] (Query source)
-            target_hidden=hidden_states,  # [bsz, L, H] (Context source)
-            attention_mask=dflash_attn_mask,  # [bsz, 1, L, 2*L]
+            position_ids=position_ids,
+            noise_embedding=noise_embedding,
+            target_hidden=hidden_states,
+            attention_mask=dflash_attn_mask,
        )

-        # 6. Compute Loss
-        # Create mask for valid loss positions (skip block 0, skip block starts)
+        # Compute loss (skip block 0 and block starts)
        dflash_loss_mask_base = create_dflash_loss_mask(
            effective_len, self.block_size, device
        )
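        # keep only positions valid in both the data loss_mask and the DFlash block mask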
        combined_mask = loss_mask * dflash_loss_mask_base.unsqueeze(0)

-        # hidden[i] predicts input_ids[i] (based on DFlash design where noise[i] is input to predict target[i])
-        # However, check design:
-        # "hidden[i] predicts token[i] (directly corresponding), not token[i+1]"
-        # "noise_input[i] is MASK, we want to predict input_ids[i]"
-        # So logits at index i should be compared to labels at index i.
-
        logits = self.lm_head(hidden)

-        # Calculate Loss
-        # Flatten
        logits_flat = logits.reshape(-1, logits.size(-1))
        labels_flat = input_ids.reshape(-1)
        mask_flat = combined_mask.reshape(-1)

-        # Optimization: only compute CE on valid tokens
        active_indices = mask_flat > 0.5
        active_logits = logits_flat[active_indices]
        active_labels = labels_flat[active_indices]
@@ -138,99 +205,11 @@ def forward(

        return loss, accuracy

-    def _create_parallel_attention_mask(
-        self, seq_len: int, device: torch.device
-    ) -> torch.Tensor:
-        """
-        Creates the [L, 2L] mask for parallel training.
-        Rows: Query (Noise) indices 0..L-1
-        Cols: Key indices 0..2L-1 (First L are Context, Next L are Noise)
-
-        Logic for Query at index i (belonging to block B = i // block_size):
-        1. Can see Context (Cols 0..L-1):
-           - Can see context of PREVIOUS blocks.
-           - Range: [0, B * block_size)
-        2. Can see Noise (Cols L..2L-1):
-           - Can see noise of CURRENT block up to self.
-           - Range: [L + B * block_size, L + i]
-           - (Inclusive of i because causal mask usually allows seeing self)
-        """
-        # Block indices for each position [0, 0, ..., 1, 1, ...]
-        indices = torch.arange(seq_len, device=device)
-        block_ids = indices // self.block_size
-
-        # 1. Context Mask (L x L) - Left half of K
-        # Q[i] can see K_ctx[j] if Block(Q[i]) > Block(K_ctx[j])
-        # Actually, Block(Q[i]) can see Context of all previous blocks.
-        # It implies Block(K_ctx[j]) < Block(Q[i])
-        # Wait, strictly: Block B needs context from 0..(B*16).
-        # So it sees all K_ctx where index < B * 16.
-        # Which is equivalent to block_ids[j] < block_ids[i].
-
-        # Broadcast logic
-        # shape [L, 1]
-        q_block_ids = block_ids.unsqueeze(1)
-        # shape [1, L]
-        k_block_ids = block_ids.unsqueeze(0)
-
-        # Mask: 1 if K's block is strictly less than Q's block
-        # This gives access to all PREVIOUS blocks' context.
-        ctx_mask = k_block_ids < q_block_ids
-
-        # 2. Noise Mask (L x L) - Right half of K
-        # Standard Causal Mask WITHIN the same block.
-        # Q[i] can see K_noise[j] if:
-        # a) Same Block: Block(Q[i]) == Block(K_noise[j])
-        # b) Causal: j <= i
-        # Different blocks cannot see each other's noise.
-
-        same_block = q_block_ids == k_block_ids
-
-        noise_mask = same_block
-
-        # Combine [Ctx_Mask, Noise_Mask]
-        # Shape [L, 2L]
-        # We need float mask for attention: 0.0 for allow, -inf for mask
-        # Transformers usually handles boolean masks by converting them,
-        # but explicit MinValue is safer if passing to generic attention.
-        # However, most HF models accept boolean [batch, 1, Q, K].
-
-        full_mask_bool = torch.cat([ctx_mask, noise_mask], dim=1)
-
-        # Check standard HF format: usually 1 for keep, 0 for mask, or boolean.
-        # Qwen3 implementation likely uses typical attention_mask handling.
-        # Let's return boolean for now, the model wrapper usually handles conversion
-        # or we check Qwen3DFlashAttention source usage.
-        # Looking at Qwen3DFlashAttention: `attn_output = attn_fn(..., attention_mask, ...)`
-        # If using SDPA, it expects boolean or float mask.
-        # If we look at `modeling_qwen3.py` (standard), it usually employs `_prepare_4d_causal_attention_mask`.
-        # But here we pass it explicitly.
-        # To be safe with `eager_attention_forward` and `SDPA`, we typically want:
-        # 0.0 for unmasked, -inf for masked.
-
-        dtype = (
-            torch.bfloat16
-        )  # or get from device/model, but we return a tensor builder
-        # We will cast later or return boolean and let logic handle it?
-        # Safe bet: Return typical extended attention mask format: 0.0 for keep, min_dtype for remove.
-
-        # But wait, Qwen3DFlashAttention passes attention_mask directly to attn_fn.
-        # If attn_fn is SDPA, it handles boolean.
-        # Let's return a float mask: 0.0 for True, -inf for False.
-
-        full_mask = torch.zeros_like(full_mask_bool, dtype=torch.float32)
-        full_mask.masked_fill_(~full_mask_bool, torch.finfo(torch.float32).min)
-
-        return full_mask
-

def create_dflash_loss_mask(
    seq_len: int, block_size: int, device: torch.device
) -> torch.Tensor:
-    """
-    Create DFlash-specific loss mask.
-    Excludes Block 0 and first position of each block.
-    """
+    """Create DFlash loss mask: excludes block 0 and first position of each block."""
    positions = torch.arange(seq_len, device=device)
    block_ids = positions // block_size
