3 changes: 3 additions & 0 deletions examples/run_qwen3_8b_dflash_online.sh
@@ -6,6 +6,8 @@ export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels
export SPECFORGE_DATA_NUM_PROC=32
NUM_GPUS=${1:-1}

ATTENTION_BACKEND=${2:-flex_attention}

torchrun \
--standalone \
--nproc_per_node $NUM_GPUS \
@@ -19,6 +21,7 @@ torchrun \
--learning-rate 1e-4 \
--max-length 2048 \
--chat-template qwen \
--attention-backend $ATTENTION_BACKEND \
--log-interval 50 \
--save-interval 1000 \
--report-to wandb \
13 changes: 13 additions & 0 deletions scripts/train_dflash.py
@@ -57,6 +57,13 @@ def parse_args():
default=None,
help="MASK token ID. If not provided, auto-detect from tokenizer.",
)
model_group.add_argument(
"--attention-backend",
type=str,
default="flex_attention",
choices=["eager", "sdpa", "flex_attention"],
help="Attention backend for draft model.",
)

dataset_group = parser.add_argument_group("dataset")
dataset_group.add_argument("--train-data-path", type=str, required=True)
@@ -133,6 +140,10 @@ def build_models(args) -> Tuple[DFlashTargetModel, DFlashDraftModel]:
draft_config.num_target_layers = target_config.num_hidden_layers
print_on_rank0("Auto-generated draft config from target model")

# Set attention implementation based on backend
draft_config._attn_implementation = args.attention_backend
print_on_rank0(f"Using attention backend: {args.attention_backend}")

draft_model = DFlashDraftModel(draft_config).cuda().to(torch.bfloat16)

# Set capture layers for target model based on draft model config
@@ -344,6 +355,7 @@ def main():
target_embed_tokens=target_components.embed_tokens,
block_size=draft_model.block_size,
mask_token_id=mask_token_id,
attention_backend=args.attention_backend,
)

dflash_model = FSDP(
@@ -436,6 +448,7 @@ def main():
{
"loss": f"{loss.item():.4f}",
"acc": f"{accuracy.item():.4f}",
"iter_time": f"{elapsed:.2f}s",
}
)

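For context on the flag threaded through the launcher and train_dflash.py: the chosen value lands on the draft config as `_attn_implementation`, the attribute HF-style attention layers consult when picking a kernel. Below is a minimal sketch of that kind of dispatch, assuming nothing about this repository's actual attention module; `eager_attention` and `select_attention_fn` are illustrative names, not code from this PR.

import torch
import torch.nn.functional as F


def eager_attention(q, k, v, attn_mask=None):
    # q, k, v: [bsz, heads, seq, head_dim]; attn_mask is additive (0.0 keeps, -inf drops),
    # matching the float-mask convention used by _create_parallel_attention_mask below.
    scores = torch.matmul(q, k.transpose(-2, -1)) / (q.size(-1) ** 0.5)
    if attn_mask is not None:
        scores = scores + attn_mask
    return torch.matmul(F.softmax(scores, dim=-1), v)


def select_attention_fn(config):
    # Dispatch on the same three choices exposed by --attention-backend.
    backend = getattr(config, "_attn_implementation", "eager")
    if backend == "sdpa":
        return F.scaled_dot_product_attention  # accepts a boolean or additive attn_mask
    if backend == "flex_attention":
        # Requires PyTorch >= 2.5; expects a BlockMask via block_mask=, not a dense mask.
        from torch.nn.attention.flex_attention import flex_attention
        return flex_attention
    return eager_attention
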
249 changes: 114 additions & 135 deletions specforge/core/dflash.py
@@ -1,14 +1,23 @@
# coding=utf-8
"""DFlash Training Wrapper."""

from typing import Tuple
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from specforge.modeling.draft.dflash import DFlashDraftModel

try:
from torch.nn.attention.flex_attention import BlockMask, create_block_mask

FLEX_ATTENTION_AVAILABLE = True
except ImportError:
FLEX_ATTENTION_AVAILABLE = False
BlockMask = None
create_block_mask = None


class OnlineDFlashModel(nn.Module):
"""DFlash online training wrapper with block-wise CE loss."""
@@ -20,17 +29,24 @@ def __init__(
target_embed_tokens: nn.Module,
mask_token_id: int,
block_size: int = 16,
attention_backend: str = "flex_attention",
):
super().__init__()
self.draft_model = draft_model
self.lm_head = target_lm_head
self.embed_tokens = target_embed_tokens
self.block_size = block_size
self.mask_token_id = mask_token_id
self.attention_backend = attention_backend

# Cache for BlockMask
self._cached_block_mask: Optional[BlockMask] = None
self._cached_seq_len: Optional[int] = None
self._cached_bsz: Optional[int] = None
self._cached_num_heads: Optional[int] = None

def prepare_noise_input(self, input_ids: torch.Tensor) -> torch.Tensor:
"""Prepare noise input: first token of each block is real, rest are MASK."""
# input_ids: [bsz, seq_len]
seq_len = input_ids.shape[1]
device = input_ids.device

@@ -42,88 +58,139 @@ def prepare_noise_input(self, input_ids: torch.Tensor) -> torch.Tensor:

return noise_input_ids

def _get_or_create_block_mask(
self, bsz: int, num_heads: int, q_len: int, kv_len: int, device: torch.device
) -> "BlockMask":
"""Get cached BlockMask or create a new one."""
if (
self._cached_block_mask is not None
and self._cached_seq_len == q_len
and self._cached_bsz == bsz
and self._cached_num_heads == num_heads
):
return self._cached_block_mask

block_size = self.block_size

def dflash_mask_fn(b, h, q_idx, kv_idx):
L = q_len
is_ctx = kv_idx < L
q_block = q_idx // block_size
k_block_ctx = kv_idx // block_size
k_block_noise = (kv_idx - L) // block_size
ctx_visible = is_ctx & (k_block_ctx < q_block)
noise_visible = (~is_ctx) & (k_block_noise == q_block)
return ctx_visible | noise_visible

block_mask = create_block_mask(
dflash_mask_fn,
B=bsz,
H=num_heads,
Q_LEN=q_len,
KV_LEN=kv_len,
device=device,
)

self._cached_block_mask = block_mask
self._cached_seq_len = q_len
self._cached_bsz = bsz
self._cached_num_heads = num_heads

return block_mask

def _create_parallel_attention_mask(
self, seq_len: int, device: torch.device
) -> torch.Tensor:
"""
Create [L, 2L] attention mask for parallel training.
- Left half (ctx): Q can see K_ctx if K's block < Q's block
- Right half (noise): Q can see K_noise if same block (bidirectional)
"""
indices = torch.arange(seq_len, device=device)
block_ids = indices // self.block_size

q_block_ids = block_ids.unsqueeze(1)
k_block_ids = block_ids.unsqueeze(0)

ctx_mask = k_block_ids < q_block_ids
noise_mask = q_block_ids == k_block_ids

full_mask_bool = torch.cat([ctx_mask, noise_mask], dim=1)
full_mask = torch.zeros_like(full_mask_bool, dtype=torch.float32)
full_mask.masked_fill_(~full_mask_bool, torch.finfo(torch.float32).min)

return full_mask

def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
hidden_states: torch.Tensor,
loss_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Parallel Block-wise training forward pass.
Uses a global attention mask to process all blocks in parallel without
modifying the underlying modeling code.
"""
"""Parallel block-wise training forward pass."""
bsz, seq_len = input_ids.shape
device = input_ids.device

# Truncate to multiple of block_size
n_blocks = seq_len // self.block_size
effective_len = n_blocks * self.block_size
input_ids = input_ids[:, :effective_len]
# hidden_states here is the RAW target hidden states (before projection)
hidden_states = hidden_states[:, :effective_len, :]
loss_mask = loss_mask[:, :effective_len]
# Original attention mask is typically just 1s for valid tokens
attention_mask = attention_mask[:, :effective_len]

# 2. Prepare Inputs
# Prepare inputs
noise_input_ids = self.prepare_noise_input(input_ids)
noise_embedding = self.embed_tokens(noise_input_ids)

# 3. Construct Parallel Training Position IDs
# We need Position IDs for K which has length 2*L (Context + Noise)
# Context part: 0..L-1
# Noise part: 0..L-1
# This ensures that Noise at pos i uses the same RoPE embedding as Context at pos i
# Position IDs: [ctx_pos, noise_pos] both 0..L-1
pos_seq = torch.arange(effective_len, device=device)
# shape: [1, 2*L] -> [bsz, 2*L]
position_ids = torch.cat([pos_seq, pos_seq], dim=0).unsqueeze(0).expand(bsz, -1)

# 4. Construct Parallel Attention Mask
# The modeling code will internally concat K = [K_ctx, K_noise]
# So K has length 2*L. Q has length L (from Noise).
# We need a mask of shape [L, 2*L]
dflash_attn_mask = self._create_parallel_attention_mask(effective_len, device)
dflash_attn_mask = dflash_attn_mask.to(dtype=hidden_states.dtype)
# Expand to batch size: [bsz, 1, L, 2*L]
# Note: transformers usually expects [bsz, 1, Q, K]
dflash_attn_mask = (
dflash_attn_mask.unsqueeze(0).unsqueeze(0).expand(bsz, -1, -1, -1)
)

# 5. Parallel Forward Pass
# efficient: one single forward pass for the whole sequence
# Construct attention mask
if (
self.attention_backend == "flex_attention"
and FLEX_ATTENTION_AVAILABLE
and create_block_mask is not None
):
num_heads = self.draft_model.config.num_attention_heads
dflash_attn_mask = self._get_or_create_block_mask(
bsz=bsz,
num_heads=num_heads,
q_len=effective_len,
kv_len=effective_len * 2,
device=device,
)
else:
dflash_attn_mask = self._create_parallel_attention_mask(
effective_len, device
)
dflash_attn_mask = dflash_attn_mask.to(dtype=hidden_states.dtype)
dflash_attn_mask = (
dflash_attn_mask.unsqueeze(0).unsqueeze(0).expand(bsz, -1, -1, -1)
)

# Forward pass
hidden = self.draft_model(
position_ids=position_ids, # [bsz, 2*L] (used for RoPE)
noise_embedding=noise_embedding, # [bsz, L, H] (Query source)
target_hidden=hidden_states, # [bsz, L, H] (Context source)
attention_mask=dflash_attn_mask, # [bsz, 1, L, 2*L]
position_ids=position_ids,
noise_embedding=noise_embedding,
target_hidden=hidden_states,
attention_mask=dflash_attn_mask,
)

# 6. Compute Loss
# Create mask for valid loss positions (skip block 0, skip block starts)
# Compute loss (skip block 0 and block starts)
dflash_loss_mask_base = create_dflash_loss_mask(
effective_len, self.block_size, device
)
combined_mask = loss_mask * dflash_loss_mask_base.unsqueeze(0)

# hidden[i] predicts input_ids[i] (based on DFlash design where noise[i] is input to predict target[i])
# However, check design:
# "hidden[i] predicts token[i] (directly corresponding), not token[i+1]"
# "noise_input[i] is MASK, we want to predict input_ids[i]"
# So logits at index i should be compared to labels at index i.

logits = self.lm_head(hidden)

# Calculate Loss
# Flatten
logits_flat = logits.reshape(-1, logits.size(-1))
labels_flat = input_ids.reshape(-1)
mask_flat = combined_mask.reshape(-1)

# Optimization: only compute CE on valid tokens
active_indices = mask_flat > 0.5
active_logits = logits_flat[active_indices]
active_labels = labels_flat[active_indices]
@@ -138,99 +205,11 @@ def forward(

return loss, accuracy

def _create_parallel_attention_mask(
self, seq_len: int, device: torch.device
) -> torch.Tensor:
"""
Creates the [L, 2L] mask for parallel training.
Rows: Query (Noise) indices 0..L-1
Cols: Key indices 0..2L-1 (First L are Context, Next L are Noise)
Logic for Query at index i (belonging to block B = i // block_size):
1. Can see Context (Cols 0..L-1):
- Can see context of PREVIOUS blocks.
- Range: [0, B * block_size)
2. Can see Noise (Cols L..2L-1):
- Can see noise of CURRENT block up to self.
- Range: [L + B * block_size, L + i]
- (Inclusive of i because causal mask usually allows seeing self)
"""
# Block indices for each position [0, 0, ..., 1, 1, ...]
indices = torch.arange(seq_len, device=device)
block_ids = indices // self.block_size

# 1. Context Mask (L x L) - Left half of K
# Q[i] can see K_ctx[j] if Block(Q[i]) > Block(K_ctx[j])
# Actually, Block(Q[i]) can see Context of all previous blocks.
# It implies Block(K_ctx[j]) < Block(Q[i])
# Wait, strictly: Block B needs context from 0..(B*16).
# So it sees all K_ctx where index < B * 16.
# Which is equivalent to block_ids[j] < block_ids[i].

# Broadcast logic
# shape [L, 1]
q_block_ids = block_ids.unsqueeze(1)
# shape [1, L]
k_block_ids = block_ids.unsqueeze(0)

# Mask: 1 if K's block is strictly less than Q's block
# This gives access to all PREVIOUS blocks' context.
ctx_mask = k_block_ids < q_block_ids

# 2. Noise Mask (L x L) - Right half of K
# Standard Causal Mask WITHIN the same block.
# Q[i] can see K_noise[j] if:
# a) Same Block: Block(Q[i]) == Block(K_noise[j])
# b) Causal: j <= i
# Different blocks cannot see each other's noise.

same_block = q_block_ids == k_block_ids

noise_mask = same_block

# Combine [Ctx_Mask, Noise_Mask]
# Shape [L, 2L]
# We need float mask for attention: 0.0 for allow, -inf for mask
# Transformers usually handles boolean masks by converting them,
# but explicit MinValue is safer if passing to generic attention.
# However, most HF models accept boolean [batch, 1, Q, K].

full_mask_bool = torch.cat([ctx_mask, noise_mask], dim=1)

# Check standard HF format: usually 1 for keep, 0 for mask, or boolean.
# Qwen3 implementation likely uses typical attention_mask handling.
# Let's return boolean for now, the model wrapper usually handles conversion
# or we check Qwen3DFlashAttention source usage.
# Looking at Qwen3DFlashAttention: `attn_output = attn_fn(..., attention_mask, ...)`
# If using SDPA, it expects boolean or float mask.
# If we look at `modeling_qwen3.py` (standard), it usually employs `_prepare_4d_causal_attention_mask`.
# But here we pass it explicitly.
# To be safe with `eager_attention_forward` and `SDPA`, we typically want:
# 0.0 for unmasked, -inf for masked.

dtype = (
torch.bfloat16
) # or get from device/model, but we return a tensor builder
# We will cast later or return boolean and let logic handle it?
# Safe bet: Return typical extended attention mask format: 0.0 for keep, min_dtype for remove.

# But wait, Qwen3DFlashAttention passes attention_mask directly to attn_fn.
# If attn_fn is SDPA, it handles boolean.
# Let's return a float mask: 0.0 for True, -inf for False.

full_mask = torch.zeros_like(full_mask_bool, dtype=torch.float32)
full_mask.masked_fill_(~full_mask_bool, torch.finfo(torch.float32).min)

return full_mask


def create_dflash_loss_mask(
seq_len: int, block_size: int, device: torch.device
) -> torch.Tensor:
"""
Create DFlash-specific loss mask.
Excludes Block 0 and first position of each block.
"""
"""Create DFlash loss mask: excludes block 0 and first position of each block."""
positions = torch.arange(seq_len, device=device)
block_ids = positions // block_size

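To make the block-wise visibility and loss rules concrete, here is a small standalone reproduction for seq_len=4 and block_size=2. It is a sketch reconstructed from the docstrings above: the body of create_dflash_loss_mask is collapsed in this diff, so the loss-mask logic below is an assumption based on its docstring, and both function names are local to the example.

import torch


def parallel_attention_mask(seq_len: int, block_size: int) -> torch.Tensor:
    # [L, 2L] boolean mask: left half indexes context keys, right half noise keys.
    block_ids = torch.arange(seq_len) // block_size
    q_blocks = block_ids.unsqueeze(1)   # [L, 1]
    k_blocks = block_ids.unsqueeze(0)   # [1, L]
    ctx_mask = k_blocks < q_blocks      # context of strictly earlier blocks
    noise_mask = q_blocks == k_blocks   # noise of the same block (bidirectional)
    return torch.cat([ctx_mask, noise_mask], dim=1)


def loss_mask(seq_len: int, block_size: int) -> torch.Tensor:
    # Assumed from the docstring: drop block 0 and the first position of every block.
    pos = torch.arange(seq_len)
    return ((pos // block_size > 0) & (pos % block_size != 0)).float()


print(parallel_attention_mask(4, 2).long())
# tensor([[0, 0, 0, 0, 1, 1, 0, 0],
#         [0, 0, 0, 0, 1, 1, 0, 0],
#         [1, 1, 0, 0, 0, 0, 1, 1],
#         [1, 1, 0, 0, 0, 0, 1, 1]])
print(loss_mask(4, 2))
# tensor([0., 0., 0., 1.])

Rows are noise queries; each one sees the context of every earlier block plus the noise of its own block, and only non-initial positions of blocks after the first contribute to the cross-entropy loss, mirroring the combined_mask construction in forward.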