Overview
This RFC proposes implementing training support for P-EAGLE (Parallel EAGLE), a parallel speculative decoding method that extends EAGLE-3 with multi-token prediction, potentially offering 2-3x speedups over sequential EAGLE-3 drafting.
Background
P-EAGLE inherits EAGLE-3's architecture but introduces parallel prediction. The approach uses the same lightweight decoder architecture as EAGLE-3 but generates multiple token predictions in parallel through Conditional-On-Distribution (COD) sampling, rather than sequential test-time training steps.
"Training a parallel token prediction model requires extending each sequence of length n to accommodate K parallel prediction depths, where depth k predicts the token k positions ahead of the current position."
Paper: P-EAGLE: Parallel EAGLE Speculative Decoding
Key Technical Differences from EAGLE-3
Parallel Prediction Groups: P-EAGLE organizes predictions into para_num parallel groups (typically matching EAGLE-3's ttt_steps), replacing the sequential test-time training steps in which each step depends on previous predictions through autoregressive generation.
Conditional Drop-token (COD) Sampling: COD reduces the number of positions at each prediction depth (group) through geometric decay: depth k retains roughly a down_sample_ratio^k fraction of positions, floored at down_sample_ratio_min. Sampling is restricted to positions where loss_mask == 1, reducing the total number of supervised positions across depths.
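The geometric-decay sampling described above can be sketched as follows. This is an illustrative helper, not the actual implementation; the function name `cod_sample_indices` and its exact rounding behavior are assumptions, while `down_sample_ratio`, `down_sample_ratio_min`, and `loss_mask` follow the parameters named in this RFC.

```python
import torch

def cod_sample_indices(loss_mask, num_depths, down_sample_ratio, down_sample_ratio_min):
    """Sketch of COD sampling: at each depth, keep a geometrically
    decaying random subset of the positions where loss_mask == 1."""
    # candidate positions are only those with supervision signal
    base_idx = loss_mask.nonzero(as_tuple=True)[0]
    indices_per_depth = []
    for k in range(num_depths):
        # keep ratio decays as down_sample_ratio**k, bounded below by the floor
        ratio = max(down_sample_ratio ** k, down_sample_ratio_min)
        n_keep = max(1, int(len(base_idx) * ratio))
        perm = torch.randperm(len(base_idx))[:n_keep]
        # return indices in ascending order so position ids stay monotone
        indices_per_depth.append(base_idx[perm].sort().values)
    return indices_per_depth
```

With 100 supervised positions and a ratio of 0.5, depths 0..3 would keep 100, 50, 25, and 12 positions respectively.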
Learnable Padding Parameters: MTP (Multi-Token Prediction) positions lack predicted tokens and hidden vectors from previous steps. P-EAGLE addresses this with two learnable parameters: a shared hidden state h_shared that substitutes for missing hidden vectors, and a mask token embedding that substitutes for unknown previous tokens. The shared hidden state (implemented as mask_hidden with shape $[1, 1, 3\times \text{hidden size}]$) is broadcast across the MTP positions.
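A minimal sketch of the two learnable padding parameters, assuming the mask_hidden shape listed later in this RFC; the class name `PEagleMTPPadding`, the `pad_group` helper, and the wiring are illustrative, not the real model code:

```python
import torch
import torch.nn as nn

class PEagleMTPPadding(nn.Module):
    """Illustrative module holding the two learnable padding parameters."""

    def __init__(self, hidden_size: int, vocab_size: int, ptd_token_id: int):
        super().__init__()
        # shared hidden state standing in for missing hidden vectors;
        # 3 * hidden_size matches the concatenated multi-layer features
        self.mask_hidden = nn.Parameter(torch.zeros(1, 1, 3 * hidden_size))
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        # reserved token id whose embedding substitutes for unknown previous tokens
        self.ptd_token_id = ptd_token_id

    def pad_group(self, n_positions: int):
        """Return (token_embeds, hidden_states) for n MTP padding positions."""
        ids = torch.full((1, n_positions), self.ptd_token_id, dtype=torch.long)
        tok = self.embed_tokens(ids)                       # [1, n, hidden]
        hid = self.mask_hidden.expand(1, n_positions, -1)  # [1, n, 3*hidden]
        return tok, hid
```

Because `expand` creates a broadcast view, every MTP position shares the same learnable mask_hidden vector, and gradients from all positions accumulate into it.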
Implementation Requirements
Algorithmic Changes (relative to EAGLE-3)
- COD sampling logic for parallel group generation, dynamically downsampling positions with an exponential decay ratio
- Parallel group loss computation across variable-length sequences
- Sequence splitting for memory-efficient gradient accumulation
- Loss mask updates to handle COD sampling
- Attention mask construction for non-causal $[L, 2L]$ (position-length) patterns within groups
- Data collation with mask_hidden padding for variable-length groups
- Position ID tracking across sampled indices for each depth
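To make the attention-mask item above concrete, here is one plausible dense-mask construction. The layout is an assumption (base sequence first, each depth's sampled positions appended after it, group positions attending only to the base prefix up to their anchor and to themselves); the real implementation would express the same predicate as a flex-attention mask_mod rather than materializing a dense matrix.

```python
import torch

def parallel_group_mask(base_len, anchors_per_depth):
    """Sketch of a non-causal attention mask for parallel prediction groups.

    anchors_per_depth[d] lists, for each sampled position at depth d, the
    base-sequence index it is anchored to (its COD-sampled position)."""
    offsets = [base_len]
    for anchors in anchors_per_depth:
        offsets.append(offsets[-1] + len(anchors))
    total = offsets[-1]
    mask = torch.zeros(total, total, dtype=torch.bool)
    # ordinary causal mask over the base sequence
    mask[:base_len, :base_len] = torch.tril(
        torch.ones(base_len, base_len, dtype=torch.bool)
    )
    for d, anchors in enumerate(anchors_per_depth):
        start = offsets[d]
        for j, anchor in enumerate(anchors):
            q = start + j
            mask[q, : anchor + 1] = True  # attend to the base prefix up to the anchor
            mask[q, q] = True             # and to itself; never to other groups
    return mask
```

Base tokens never attend to group positions, so the base forward pass is unchanged; group positions cannot see each other, which is what makes the depths independent and parallelizable.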
Model and Configuration
- P-EAGLE model registration with SpeculatorModel registry
- P-EAGLE model definition with mask_hidden parameter (shape: $[1, 1, 3\times \text{hidden size}]$)
- Configuration class supporting para_depths, down_sample_ratio, down_sample_ratio_min, ptd_token_id parameters
- Reuse of existing EAGLE-3 decoder layers and architecture
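The configuration class could look roughly like this. Field names follow the parameters listed above; the defaults and types are placeholders, not values from the paper or the eventual implementation (para_depths in particular may end up a list of per-depth sizes rather than a count):

```python
from dataclasses import dataclass

@dataclass
class PEagleConfig:
    """Illustrative P-EAGLE configuration; defaults are placeholders."""
    para_depths: int = 4                # number of parallel prediction groups
    down_sample_ratio: float = 0.5      # geometric decay of kept positions per depth
    down_sample_ratio_min: float = 0.1  # floor on the decayed keep ratio
    ptd_token_id: int = 0               # reserved token id for MTP padding positions
```

In practice this would subclass or compose the existing EAGLE-3 config so the decoder-layer settings are reused unchanged.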
Training Infrastructure
- Group-specific hidden state and target extraction at sampled indices
- Per-depth and aggregate top-k accuracy metrics
- Training loop modifications for single-pass parallel group processing
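The per-depth and aggregate top-k accuracy item above can be sketched as an illustrative helper (names assumed, not the repo's metric code). Each depth may cover a different number of positions because of COD downsampling, so the aggregate is weighted by position count rather than averaged over depths:

```python
import torch

def topk_accuracy_per_depth(logits_per_depth, targets_per_depth, k=3):
    """Per-depth and count-weighted aggregate top-k accuracy (sketch).

    logits_per_depth[d]: [n_d, vocab] draft logits at depth d
    targets_per_depth[d]: [n_d] target token ids at depth d"""
    per_depth, correct_total, count_total = [], 0, 0
    for logits, targets in zip(logits_per_depth, targets_per_depth):
        topk = logits.topk(k, dim=-1).indices           # [n_d, k]
        hit = (topk == targets.unsqueeze(-1)).any(-1)   # [n_d] bool
        per_depth.append(hit.float().mean().item())
        correct_total += hit.sum().item()
        count_total += hit.numel()
    aggregate = correct_total / max(count_total, 1)
    return per_depth, aggregate
```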
Work Breakdown
Phase 1
- P-EAGLE model and configuration class implementation, reusing Eagle-3 components
- Add basic COD sampling implementation (possibly loss-mask-aware)
- Add mask embeddings to the P-EAGLE model
- Loss mask update for mask tokens (using position ids)
- Attention mask construction for parallel patterns, using flex attention
- Add cross entropy loss function as an option in training script
- Unfreeze embedding linear layer
- Test that attention masks line up with the author’s implementation
Phase 2
- Dependency-aware sequence splitting for gradient accumulation
- Training loop modifications for parallel group processing (grad accumulation support)
- Documentation and training examples
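One way the dependency-aware sequence splitting from Phase 2 could work, sketched under an assumption not stated in this RFC: since parallel groups depend only on the shared base sequence (not on each other), whole groups can be packed into position-budgeted chunks and processed in separate forward passes with gradient accumulation. The helper name and chunking policy here are hypothetical.

```python
def split_groups_for_accumulation(groups, max_positions):
    """Pack whole parallel groups into chunks that fit a position budget.

    groups: list of group sizes (number of COD-sampled positions per depth)
    Returns a list of chunks; each chunk is a list of group sizes that can be
    run in one forward pass, with gradients accumulated across chunks."""
    chunks, current, current_size = [], [], 0
    for g in groups:
        # start a new chunk when adding this group would exceed the budget,
        # but never split a group (its positions attend to a shared prefix)
        if current and current_size + g > max_positions:
            chunks.append(current)
            current, current_size = [], 0
        current.append(g)
        current_size += g
    if current:
        chunks.append(current)
    return chunks
```

Keeping groups intact means each chunk's attention mask is a simple restriction of the full parallel-group mask, so no cross-chunk dependencies are lost.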