diff --git a/examples/run_qwen3_8b_dflash_online.sh b/examples/run_qwen3_8b_dflash_online.sh index 944fad403..3457a8e8a 100755 --- a/examples/run_qwen3_8b_dflash_online.sh +++ b/examples/run_qwen3_8b_dflash_online.sh @@ -6,6 +6,8 @@ export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels export SPECFORGE_DATA_NUM_PROC=32 NUM_GPUS=${1:-1} +ATTENTION_BACKEND=${2:-flex_attention} + torchrun \ --standalone \ --nproc_per_node $NUM_GPUS \ @@ -19,6 +21,7 @@ torchrun \ --learning-rate 1e-4 \ --max-length 2048 \ --chat-template qwen \ + --attention-backend $ATTENTION_BACKEND \ --log-interval 50 \ --save-interval 1000 \ --report-to wandb \ diff --git a/scripts/train_dflash.py b/scripts/train_dflash.py index 968e63cfb..bdd79d731 100755 --- a/scripts/train_dflash.py +++ b/scripts/train_dflash.py @@ -57,6 +57,13 @@ def parse_args(): default=None, help="MASK token ID. If not provided, auto-detect from tokenizer.", ) + model_group.add_argument( + "--attention-backend", + type=str, + default="flex_attention", + choices=["eager", "sdpa", "flex_attention"], + help="Attention backend for draft model.", + ) dataset_group = parser.add_argument_group("dataset") dataset_group.add_argument("--train-data-path", type=str, required=True) @@ -133,6 +140,10 @@ def build_models(args) -> Tuple[DFlashTargetModel, DFlashDraftModel]: draft_config.num_target_layers = target_config.num_hidden_layers print_on_rank0("Auto-generated draft config from target model") + # Set attention implementation based on backend + draft_config._attn_implementation = args.attention_backend + print_on_rank0(f"Using attention backend: {args.attention_backend}") + draft_model = DFlashDraftModel(draft_config).cuda().to(torch.bfloat16) # Set capture layers for target model based on draft model config @@ -344,6 +355,7 @@ def main(): target_embed_tokens=target_components.embed_tokens, block_size=draft_model.block_size, mask_token_id=mask_token_id, + attention_backend=args.attention_backend, ) dflash_model = FSDP( @@ -436,6 +448,7 @@ def main(): { "loss": f"{loss.item():.4f}", "acc": f"{accuracy.item():.4f}", + "iter_time": f"{elapsed:.2f}s", } ) diff --git a/specforge/core/dflash.py b/specforge/core/dflash.py index 002c4dc62..dd5d80147 100644 --- a/specforge/core/dflash.py +++ b/specforge/core/dflash.py @@ -1,7 +1,7 @@ # coding=utf-8 """DFlash Training Wrapper.""" -from typing import Tuple +from typing import Optional, Tuple import torch import torch.nn as nn @@ -9,6 +9,15 @@ from specforge.modeling.draft.dflash import DFlashDraftModel +try: + from torch.nn.attention.flex_attention import BlockMask, create_block_mask + + FLEX_ATTENTION_AVAILABLE = True +except ImportError: + FLEX_ATTENTION_AVAILABLE = False + BlockMask = None + create_block_mask = None + class OnlineDFlashModel(nn.Module): """DFlash online training wrapper with block-wise CE loss.""" @@ -20,6 +29,7 @@ def __init__( target_embed_tokens: nn.Module, mask_token_id: int, block_size: int = 16, + attention_backend: str = "flex_attention", ): super().__init__() self.draft_model = draft_model @@ -27,10 +37,16 @@ def __init__( self.embed_tokens = target_embed_tokens self.block_size = block_size self.mask_token_id = mask_token_id + self.attention_backend = attention_backend + + # Cache for BlockMask + self._cached_block_mask: Optional[BlockMask] = None + self._cached_seq_len: Optional[int] = None + self._cached_bsz: Optional[int] = None + self._cached_num_heads: Optional[int] = None def prepare_noise_input(self, input_ids: torch.Tensor) -> torch.Tensor: """Prepare noise input: first 
token of each block is real, rest are MASK.""" - # input_ids: [bsz, seq_len] seq_len = input_ids.shape[1] device = input_ids.device @@ -42,6 +58,69 @@ def prepare_noise_input(self, input_ids: torch.Tensor) -> torch.Tensor: return noise_input_ids + def _get_or_create_block_mask( + self, bsz: int, num_heads: int, q_len: int, kv_len: int, device: torch.device + ) -> "BlockMask": + """Get cached BlockMask or create a new one.""" + if ( + self._cached_block_mask is not None + and self._cached_seq_len == q_len + and self._cached_bsz == bsz + and self._cached_num_heads == num_heads + ): + return self._cached_block_mask + + block_size = self.block_size + + def dflash_mask_fn(b, h, q_idx, kv_idx): + L = q_len + is_ctx = kv_idx < L + q_block = q_idx // block_size + k_block_ctx = kv_idx // block_size + k_block_noise = (kv_idx - L) // block_size + ctx_visible = is_ctx & (k_block_ctx < q_block) + noise_visible = (~is_ctx) & (k_block_noise == q_block) + return ctx_visible | noise_visible + + block_mask = create_block_mask( + dflash_mask_fn, + B=bsz, + H=num_heads, + Q_LEN=q_len, + KV_LEN=kv_len, + device=device, + ) + + self._cached_block_mask = block_mask + self._cached_seq_len = q_len + self._cached_bsz = bsz + self._cached_num_heads = num_heads + + return block_mask + + def _create_parallel_attention_mask( + self, seq_len: int, device: torch.device + ) -> torch.Tensor: + """ + Create [L, 2L] attention mask for parallel training. + - Left half (ctx): Q can see K_ctx if K's block < Q's block + - Right half (noise): Q can see K_noise if same block (bidirectional) + """ + indices = torch.arange(seq_len, device=device) + block_ids = indices // self.block_size + + q_block_ids = block_ids.unsqueeze(1) + k_block_ids = block_ids.unsqueeze(0) + + ctx_mask = k_block_ids < q_block_ids + noise_mask = q_block_ids == k_block_ids + + full_mask_bool = torch.cat([ctx_mask, noise_mask], dim=1) + full_mask = torch.zeros_like(full_mask_bool, dtype=torch.float32) + full_mask.masked_fill_(~full_mask_bool, torch.finfo(torch.float32).min) + + return full_mask + def forward( self, input_ids: torch.Tensor, @@ -49,12 +128,7 @@ def forward( hidden_states: torch.Tensor, loss_mask: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Parallel Block-wise training forward pass. - - Uses a global attention mask to process all blocks in parallel without - modifying the underlying modeling code. - """ + """Parallel block-wise training forward pass.""" bsz, seq_len = input_ids.shape device = input_ids.device @@ -62,68 +136,61 @@ def forward( n_blocks = seq_len // self.block_size effective_len = n_blocks * self.block_size input_ids = input_ids[:, :effective_len] - # hidden_states here is the RAW target hidden states (before projection) hidden_states = hidden_states[:, :effective_len, :] loss_mask = loss_mask[:, :effective_len] - # Original attention mask is typically just 1s for valid tokens attention_mask = attention_mask[:, :effective_len] - # 2. Prepare Inputs + # Prepare inputs noise_input_ids = self.prepare_noise_input(input_ids) noise_embedding = self.embed_tokens(noise_input_ids) - # 3. 
Construct Parallel Training Position IDs - # We need Position IDs for K which has length 2*L (Context + Noise) - # Context part: 0..L-1 - # Noise part: 0..L-1 - # This ensures that Noise at pos i uses the same RoPE embedding as Context at pos i + # Position IDs: [ctx_pos, noise_pos] both 0..L-1 pos_seq = torch.arange(effective_len, device=device) - # shape: [1, 2*L] -> [bsz, 2*L] position_ids = torch.cat([pos_seq, pos_seq], dim=0).unsqueeze(0).expand(bsz, -1) - # 4. Construct Parallel Attention Mask - # The modeling code will internally concat K = [K_ctx, K_noise] - # So K has length 2*L. Q has length L (from Noise). - # We need a mask of shape [L, 2*L] - dflash_attn_mask = self._create_parallel_attention_mask(effective_len, device) - dflash_attn_mask = dflash_attn_mask.to(dtype=hidden_states.dtype) - # Expand to batch size: [bsz, 1, L, 2*L] - # Note: transformers usually expects [bsz, 1, Q, K] - dflash_attn_mask = ( - dflash_attn_mask.unsqueeze(0).unsqueeze(0).expand(bsz, -1, -1, -1) - ) - - # 5. Parallel Forward Pass - # efficient: one single forward pass for the whole sequence + # Construct attention mask + if ( + self.attention_backend == "flex_attention" + and FLEX_ATTENTION_AVAILABLE + and create_block_mask is not None + ): + num_heads = self.draft_model.config.num_attention_heads + dflash_attn_mask = self._get_or_create_block_mask( + bsz=bsz, + num_heads=num_heads, + q_len=effective_len, + kv_len=effective_len * 2, + device=device, + ) + else: + dflash_attn_mask = self._create_parallel_attention_mask( + effective_len, device + ) + dflash_attn_mask = dflash_attn_mask.to(dtype=hidden_states.dtype) + dflash_attn_mask = ( + dflash_attn_mask.unsqueeze(0).unsqueeze(0).expand(bsz, -1, -1, -1) + ) + + # Forward pass hidden = self.draft_model( - position_ids=position_ids, # [bsz, 2*L] (used for RoPE) - noise_embedding=noise_embedding, # [bsz, L, H] (Query source) - target_hidden=hidden_states, # [bsz, L, H] (Context source) - attention_mask=dflash_attn_mask, # [bsz, 1, L, 2*L] + position_ids=position_ids, + noise_embedding=noise_embedding, + target_hidden=hidden_states, + attention_mask=dflash_attn_mask, ) - # 6. Compute Loss - # Create mask for valid loss positions (skip block 0, skip block starts) + # Compute loss (skip block 0 and block starts) dflash_loss_mask_base = create_dflash_loss_mask( effective_len, self.block_size, device ) combined_mask = loss_mask * dflash_loss_mask_base.unsqueeze(0) - # hidden[i] predicts input_ids[i] (based on DFlash design where noise[i] is input to predict target[i]) - # However, check design: - # "hidden[i] predicts token[i] (directly corresponding), not token[i+1]" - # "noise_input[i] is MASK, we want to predict input_ids[i]" - # So logits at index i should be compared to labels at index i. - logits = self.lm_head(hidden) - # Calculate Loss - # Flatten logits_flat = logits.reshape(-1, logits.size(-1)) labels_flat = input_ids.reshape(-1) mask_flat = combined_mask.reshape(-1) - # Optimization: only compute CE on valid tokens active_indices = mask_flat > 0.5 active_logits = logits_flat[active_indices] active_labels = labels_flat[active_indices] @@ -138,99 +205,11 @@ def forward( return loss, accuracy - def _create_parallel_attention_mask( - self, seq_len: int, device: torch.device - ) -> torch.Tensor: - """ - Creates the [L, 2L] mask for parallel training. - Rows: Query (Noise) indices 0..L-1 - Cols: Key indices 0..2L-1 (First L are Context, Next L are Noise) - - Logic for Query at index i (belonging to block B = i // block_size): - 1. 
Can see Context (Cols 0..L-1): - - Can see context of PREVIOUS blocks. - - Range: [0, B * block_size) - 2. Can see Noise (Cols L..2L-1): - - Can see noise of CURRENT block up to self. - - Range: [L + B * block_size, L + i] - - (Inclusive of i because causal mask usually allows seeing self) - """ - # Block indices for each position [0, 0, ..., 1, 1, ...] - indices = torch.arange(seq_len, device=device) - block_ids = indices // self.block_size - - # 1. Context Mask (L x L) - Left half of K - # Q[i] can see K_ctx[j] if Block(Q[i]) > Block(K_ctx[j]) - # Actually, Block(Q[i]) can see Context of all previous blocks. - # It implies Block(K_ctx[j]) < Block(Q[i]) - # Wait, strictly: Block B needs context from 0..(B*16). - # So it sees all K_ctx where index < B * 16. - # Which is equivalent to block_ids[j] < block_ids[i]. - - # Broadcast logic - # shape [L, 1] - q_block_ids = block_ids.unsqueeze(1) - # shape [1, L] - k_block_ids = block_ids.unsqueeze(0) - - # Mask: 1 if K's block is strictly less than Q's block - # This gives access to all PREVIOUS blocks' context. - ctx_mask = k_block_ids < q_block_ids - - # 2. Noise Mask (L x L) - Right half of K - # Standard Causal Mask WITHIN the same block. - # Q[i] can see K_noise[j] if: - # a) Same Block: Block(Q[i]) == Block(K_noise[j]) - # b) Causal: j <= i - # Different blocks cannot see each other's noise. - - same_block = q_block_ids == k_block_ids - - noise_mask = same_block - - # Combine [Ctx_Mask, Noise_Mask] - # Shape [L, 2L] - # We need float mask for attention: 0.0 for allow, -inf for mask - # Transformers usually handles boolean masks by converting them, - # but explicit MinValue is safer if passing to generic attention. - # However, most HF models accept boolean [batch, 1, Q, K]. - - full_mask_bool = torch.cat([ctx_mask, noise_mask], dim=1) - - # Check standard HF format: usually 1 for keep, 0 for mask, or boolean. - # Qwen3 implementation likely uses typical attention_mask handling. - # Let's return boolean for now, the model wrapper usually handles conversion - # or we check Qwen3DFlashAttention source usage. - # Looking at Qwen3DFlashAttention: `attn_output = attn_fn(..., attention_mask, ...)` - # If using SDPA, it expects boolean or float mask. - # If we look at `modeling_qwen3.py` (standard), it usually employs `_prepare_4d_causal_attention_mask`. - # But here we pass it explicitly. - # To be safe with `eager_attention_forward` and `SDPA`, we typically want: - # 0.0 for unmasked, -inf for masked. - - dtype = ( - torch.bfloat16 - ) # or get from device/model, but we return a tensor builder - # We will cast later or return boolean and let logic handle it? - # Safe bet: Return typical extended attention mask format: 0.0 for keep, min_dtype for remove. - - # But wait, Qwen3DFlashAttention passes attention_mask directly to attn_fn. - # If attn_fn is SDPA, it handles boolean. - # Let's return a float mask: 0.0 for True, -inf for False. - - full_mask = torch.zeros_like(full_mask_bool, dtype=torch.float32) - full_mask.masked_fill_(~full_mask_bool, torch.finfo(torch.float32).min) - - return full_mask - def create_dflash_loss_mask( seq_len: int, block_size: int, device: torch.device ) -> torch.Tensor: - """ - Create DFlash-specific loss mask. - Excludes Block 0 and first position of each block. - """ + """Create DFlash loss mask: excludes block 0 and first position of each block.""" positions = torch.arange(seq_len, device=device) block_ids = positions // block_size
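
Reviewer note (not part of the patch): a minimal standalone sketch that checks the flex_attention mask_mod logic introduced in _get_or_create_block_mask against the dense [L, 2L] boolean mask built by the new _create_parallel_attention_mask. It assumes only PyTorch; the toy values of L and block_size are illustrative, and the mask function is copied from the patch with q_len bound to a module-level L. It also reflects the fallback behavior added here: --attention-backend defaults to flex_attention (second positional argument of the example script), and the wrapper drops back to the dense float mask when flex attention is unavailable.

import torch

block_size = 4
L = 16  # toy sequence length, a multiple of block_size

def dflash_mask_fn(b, h, q_idx, kv_idx):
    # K is the concatenation [K_ctx (cols 0..L-1), K_noise (cols L..2L-1)].
    is_ctx = kv_idx < L
    q_block = q_idx // block_size
    k_block_ctx = kv_idx // block_size
    k_block_noise = (kv_idx - L) // block_size
    ctx_visible = is_ctx & (k_block_ctx < q_block)          # context of earlier blocks
    noise_visible = (~is_ctx) & (k_block_noise == q_block)  # noise of the same block
    return ctx_visible | noise_visible

# Dense reference, mirroring _create_parallel_attention_mask in boolean form.
blk = torch.arange(L) // block_size
dense = torch.cat([blk[None, :] < blk[:, None], blk[None, :] == blk[:, None]], dim=1)

# Evaluate the mask_mod on a broadcast index grid and compare.
q_idx = torch.arange(L)[:, None]       # [L, 1]
kv_idx = torch.arange(2 * L)[None, :]  # [1, 2L]
flex = dflash_mask_fn(0, 0, q_idx, kv_idx)
assert torch.equal(dense, flex)

# On a build that ships flex attention (torch.nn.attention.flex_attention,
# PyTorch >= 2.5), the same function can be compiled into a BlockMask, as the
# wrapper does:
# from torch.nn.attention.flex_attention import create_block_mask
# block_mask = create_block_mask(dflash_mask_fn, B=None, H=None,
#                                Q_LEN=L, KV_LEN=2 * L, device="cpu")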
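
A second tiny illustration of the loss-masking rule: the body of create_dflash_loss_mask is truncated in this diff, so this only reproduces the rule stated in its docstring (block 0 and the first position of every block carry no loss, since the first token of each block is a real token rather than a MASK prediction). Names and sizes below are illustrative, not the function's actual implementation.

import torch

seq_len, block_size = 8, 4
positions = torch.arange(seq_len)
block_ids = positions // block_size

# Keep positions outside block 0 that are not the first slot of their block.
keep = (block_ids > 0) & (positions % block_size != 0)
print(keep.int().tolist())  # [0, 0, 0, 0, 0, 1, 1, 1]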