Commit 32ee970

fix(dflash): remove incorrect padding shift and fix cross-block data leak
- Remove Eagle3-inherited padding (left-shift) calls in the SGLang backend that misaligned input_ids/hidden_states with loss_mask. DFlash uses same-position prediction and does not need this shift.
- Fix a cross-block data leak in random-anchor mode by changing context visibility from block-id comparison to original-position comparison, preventing overlapping blocks from leaking future hidden states.
1 parent f4a4a5b commit 32ee970

File tree

2 files changed: +23 -10 lines

- specforge/core/dflash.py
- specforge/modeling/target/dflash_target_model.py

specforge/core/dflash.py

Lines changed: 23 additions & 5 deletions
@@ -151,6 +151,8 @@ def _get_or_create_block_mask(
         kv_len: int,
         device: torch.device,
         block_ids: Optional[torch.Tensor] = None,
+        orig_positions: Optional[torch.Tensor] = None,
+        token_anchor_pos: Optional[torch.Tensor] = None,
     ) -> "BlockMask":
         """Get cached BlockMask or create a new one."""
         if block_ids is None:
@@ -165,17 +167,21 @@ def _get_or_create_block_mask(

         if block_ids is not None:
             _block_ids = block_ids
+            _orig_pos = orig_positions
+            _anchor_pos = token_anchor_pos

         def dflash_mask_fn(b, h, q_idx, kv_idx):
             L = q_len
             is_ctx = kv_idx < L
             q_b = _block_ids[b, q_idx]
-            k_ctx = _block_ids[b, kv_idx.clamp(max=L - 1)]
+            k_ctx_id = _block_ids[b, kv_idx.clamp(max=L - 1)]
             k_noise = _block_ids[b, (kv_idx - L).clamp(min=0, max=L - 1)]
             q_valid = q_b >= 0
-            k_ctx_valid = k_ctx >= 0
+            k_ctx_valid = k_ctx_id >= 0
             k_noise_valid = k_noise >= 0
-            ctx_visible = is_ctx & q_valid & k_ctx_valid & (k_ctx < q_b)
+            kv_orig = _orig_pos[b, kv_idx.clamp(max=L - 1)]
+            q_anchor = _anchor_pos[b, q_idx]
+            ctx_visible = is_ctx & q_valid & k_ctx_valid & (kv_orig < q_anchor)
             noise_visible = (~is_ctx) & q_valid & k_noise_valid & (k_noise == q_b)
             return ctx_visible | noise_visible

@@ -213,6 +219,8 @@ def _create_parallel_attention_mask(
         seq_len: int,
         device: torch.device,
         block_ids: Optional[torch.Tensor] = None,
+        orig_positions: Optional[torch.Tensor] = None,
+        token_anchor_pos: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Create [bsz, L, 2L] attention mask for parallel training."""
         if block_ids is None:
@@ -226,11 +234,13 @@ def _create_parallel_attention_mask(
             full_mask.masked_fill_(~full_mask_bool, torch.finfo(torch.float32).min)
             return full_mask.unsqueeze(0).expand(bsz, -1, -1)

+        q_anchor = token_anchor_pos.unsqueeze(2)
+        k_orig = orig_positions.unsqueeze(1)
         q_ids = block_ids.unsqueeze(2)
         k_ids = block_ids.unsqueeze(1)
         q_valid = q_ids >= 0
         k_valid = k_ids >= 0
-        ctx_mask = q_valid & k_valid & (k_ids < q_ids)
+        ctx_mask = q_valid & k_valid & (k_orig < q_anchor)
         noise_mask = q_valid & k_valid & (k_ids == q_ids)
         full_mask_bool = torch.cat([ctx_mask, noise_mask], dim=2)
         full_mask = torch.zeros_like(full_mask_bool, dtype=torch.float32)
@@ -248,6 +258,8 @@ def forward(
         bsz, seq_len = input_ids.shape
         device = input_ids.device
         block_ids = None
+        orig_positions = None
+        token_anchor_pos = None

         if self.random_anchor and self.training:
             anchor_positions, block_keep_mask = self._sample_anchor_positions(
@@ -264,6 +276,10 @@ def forward(
             )
             effective_len = input_ids.shape[1]
             base_positions = block_positions
+            orig_positions = block_positions
+            token_anchor_pos = anchor_positions.repeat_interleave(
+                self.block_size, dim=1
+            )
         else:
             n_blocks = seq_len // self.block_size
             effective_len = n_blocks * self.block_size
@@ -291,10 +307,12 @@ def forward(
                 kv_len=effective_len * 2,
                 device=device,
                 block_ids=block_ids,
+                orig_positions=orig_positions,
+                token_anchor_pos=token_anchor_pos,
             )
         else:
             dflash_attn_mask = self._create_parallel_attention_mask(
-                bsz, effective_len, device, block_ids
+                bsz, effective_len, device, block_ids, orig_positions, token_anchor_pos
             )
         dflash_attn_mask = dflash_attn_mask.to(dtype=hidden_states.dtype)
         dflash_attn_mask = dflash_attn_mask.unsqueeze(1)
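The leak fix is easiest to see with toy tensors. Below is a hypothetical, self-contained sketch of the two context rules, not SpecForge's actual code (the names, block layout, and numbers are illustrative assumptions, and the real mask also handles padding validity and the noise branch): with random anchors, a block with a smaller id can cover later original positions than another block, so the old `k_ids < q_ids` test leaks future hidden states, while `k_orig < q_anchor` does not.

```python
# Toy reproduction of the old vs. new context-visibility rule (assumed
# shapes/names, not SpecForge's API). Two size-2 blocks over one sequence:
# block 0 was anchored at original position 4 (covers positions 4, 5) and
# block 1 at position 2 (covers positions 2, 3).
import torch

block_ids = torch.tensor([0, 0, 1, 1])   # block id of each token
orig_pos = torch.tensor([4, 5, 2, 3])    # original sequence position of each token
anchor_pos = torch.tensor([4, 4, 2, 2])  # anchor position of each token's block

q_ids, k_ids = block_ids.unsqueeze(1), block_ids.unsqueeze(0)
q_anchor, k_orig = anchor_pos.unsqueeze(1), orig_pos.unsqueeze(0)

leaky = k_ids < q_ids        # old rule: any lower block id counts as context
fixed = k_orig < q_anchor    # new rule: only earlier original positions

# Queries in block 1 (anchor 2) attend to block 0's tokens at positions 4 and 5:
print(leaky[2])  # tensor([ True,  True, False, False])  <- future leak
print(fixed[2])  # tensor([False, False, False, False])  <- nothing at/after anchor 2
```

Block 0's queries (anchor 4) still see block 1's tokens under the new rule, since original positions 2 and 3 genuinely precede anchor 4; visibility now tracks sequence order rather than block-sampling order.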

specforge/modeling/target/dflash_target_model.py

Lines changed: 0 additions & 5 deletions
@@ -18,7 +18,6 @@
 from transformers import AutoModelForCausalLM

 from specforge.distributed import get_tp_group
-from specforge.utils import padding

 from .sglang_backend import SGLangRunner

@@ -235,10 +234,6 @@ def generate_dflash_data(
     attention_mask = torch.cat([d[1] for d in data_cache], dim=0)
     loss_mask = torch.cat([d[2] for d in data_cache], dim=0)

-    # Padding might be needed if batching varied lengths (but usually fixed length training)
-    hidden_states = padding(hidden_states, left=False)
-    input_ids = padding(input_ids, left=False)
-
     return DFlashTargetOutput(
         hidden_states=hidden_states,
         input_ids=input_ids,
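For the removed shift, here is a minimal sketch of the misalignment, assuming (per the commit message) that the removed `padding(..., left=False)` calls amounted to Eagle3's one-token left shift with right padding; the tensors and helper are synthetic stand-ins, not SpecForge code. Eagle3 predicts token t+1 from the features at position t, so the shift is correct there; DFlash predicts the token at the same position, so shifting moves the data one slot away from the positions that loss_mask marks.

```python
# Synthetic illustration of the off-by-one (the shift semantics are an
# assumption taken from the commit message, not the real padding() helper).
import torch

input_ids = torch.tensor([[101, 102, 103, 104]])
loss_mask = torch.tensor([[0, 1, 1, 0]])  # supervise positions 1 and 2

def left_shift_pad(x: torch.Tensor) -> torch.Tensor:
    """Drop the first token and zero-pad on the right (Eagle3-style shift)."""
    return torch.cat([x[:, 1:], torch.zeros_like(x[:, :1])], dim=1)

# Same-position prediction: the mask should select tokens 102 and 103 ...
print(input_ids * loss_mask)                  # tensor([[  0, 102, 103,   0]])
# ... but after the inherited shift it selects 103 and 104 instead.
print(left_shift_pad(input_ids) * loss_mask)  # tensor([[  0, 103, 104,   0]])
```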
