[RFC] Add P-EAGLE support in training #292

@shanjiaz

Description

Overview

This RFC proposes implementing training support for P-EAGLE (Parallel EAGLE), a parallel speculative decoding method that extends EAGLE-3 with multi-token prediction, potentially offering 2-3x speedups over sequential EAGLE-3 drafting.

Background

P-EAGLE inherits EAGLE-3's architecture but introduces parallel prediction. The approach uses the same lightweight decoder architecture as EAGLE-3 but generates multiple token predictions in parallel through COD sampling, rather than through sequential test-time training steps.

"Training a parallel token prediction model requires extending each sequence of length n to accommodate K parallel prediction depths, where depth k predicts the token $k+1$ positions ahead. Without optimization, this creates $n\times K$ total positions with $\mathcal{O}((nK)^2)$ attention complexity, causing out-of-memory failures at long sequences." P-EAGLE addresses this through COD sampling and sequence partitioning to make training tractable.
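A quick back-of-the-envelope calculation makes the savings concrete. The numbers below ($n$, $K$, $r$) are illustrative placeholders, not values from the paper:

```python
# Illustrative numbers only: sequence length n, parallel depths K, retention r.
n, K, r = 4096, 8, 0.5

naive_positions = n * K                               # all K depths at full length
cod_positions = sum(int(n * r**k) for k in range(K))  # geometric decay per depth

print(naive_positions, cod_positions)  # 32768 8160
```

Because attention memory is quadratic in the number of positions, the roughly 4x reduction in positions here translates to roughly 16x less attention memory.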

Paper: P-EAGLE: Parallel EAGLE Speculative Decoding

Key Technical Differences from EAGLE-3

Parallel Prediction Groups: P-EAGLE organizes predictions into para_num parallel groups (typically $K=8$) created via COD sampling. Each depth $k$ predicts the token $k+1$ positions ahead: depth $0$ predicts token $i+1$, depth $1$ predicts $i+2$, and so on up to depth $7$ predicting $i+8$. All depths are processed in a single forward pass, whereas EAGLE-3 uses sequential ttt_steps in which each step depends on the previous predictions through autoregressive generation.
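As a minimal sanity check of the indexing convention above (depth $k$ at position $i$ targets the token $k+1$ positions ahead), a hypothetical helper:

```python
def target_position(i: int, k: int) -> int:
    """Depth k at position i predicts the token k+1 positions ahead."""
    return i + k + 1

# depth 0 predicts the next token; depth 7 predicts 8 tokens ahead
print(target_position(10, 0), target_position(10, 7))  # 11 18
```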

Conditional Drop-token (COD) Sampling: COD reduces the number of positions at each prediction depth (group) through geometric decay: depth $0$ retains all $n$ positions, depth $1$ randomly retains $n \times r$ positions, depth $2$ retains $n \times r^2$, and so on, where $r \in (0,1)$ is the retention rate (down_sample_ratio). This reduces the total position count from $n \times K$ to $n \times (1 + r + r^2 + \cdots + r^{K-1})$, shrinking attention memory from $\mathcal{O}((nK)^2)$ to $\mathcal{O}\big((n \sum_i r^i)^2\big)$. P-EAGLE further filters sampled positions to include only valid training positions where loss_mask == 1.
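A minimal sketch of the sampling step, under stated assumptions: the function name is hypothetical, and this version filters to valid (loss_mask == 1) positions before sampling, whereas the text describes filtering after sampling; the retained count then decays geometrically over the valid positions.

```python
import random

def cod_sample(loss_mask, K, r, seed=0):
    """Sketch of COD sampling: one position list per depth, with geometric
    decay in how many positions each depth retains. Hypothetical helper."""
    rng = random.Random(seed)
    valid = [i for i, m in enumerate(loss_mask) if m == 1]
    groups = []
    for k in range(K):
        keep = max(1, int(len(valid) * r**k))  # geometric decay, at least 1
        groups.append(sorted(rng.sample(valid, keep)))
    return groups
```

For a fully unmasked sequence of length 16 with $K=4$ and $r=0.5$, the per-depth group sizes come out as 16, 8, 4, 2.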

Learnable Padding Parameters: MTP (Multi-Token Prediction) positions lack predicted tokens and hidden vectors from previous steps. P-EAGLE addresses this with two learnable parameters: a shared hidden state h_shared that substitutes for missing hidden vectors, and a mask token embedding that substitutes for unknown previous tokens. The shared hidden state (implemented as mask_hidden with shape $[1, 1, 3 \times \text{hidden\_size}]$) is trained via backpropagation to represent unsampled positions in parallel groups, enabling all tokens to be generated in a single forward pass.
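A minimal PyTorch sketch of the two learnable parameters. Only the $[1, 1, 3\times \text{hidden\_size}]$ shape of mask_hidden comes from this RFC; the module name, the second parameter's name, and the broadcasting helper are assumptions.

```python
import torch
import torch.nn as nn

class ParallelPadding(nn.Module):
    """Sketch of P-EAGLE's learnable padding parameters (names assumed)."""

    def __init__(self, hidden_size: int):
        super().__init__()
        # stands in for the missing hidden vectors at MTP positions;
        # the [1, 1, 3 * hidden_size] shape matches mask_hidden as described
        self.mask_hidden = nn.Parameter(torch.zeros(1, 1, 3 * hidden_size))
        # stands in for the embedding of the unknown previous token
        self.mask_token_embed = nn.Parameter(torch.zeros(1, 1, hidden_size))

    def pad_hidden(self, batch: int, num_positions: int) -> torch.Tensor:
        # broadcast the shared state over every unsampled position
        return self.mask_hidden.expand(batch, num_positions, -1)
```

Because both tensors are `nn.Parameter`s, gradients flow into them through every padded position, which is what lets the shared state learn a useful stand-in representation.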

Implementation Requirements

Algorithmic Changes (relative to EAGLE-3)

  • COD sampling logic for parallel group generation, dynamically downsampling with exponential decay ratio
  • Parallel group loss computation across variable-length sequences
  • Sequence splitting for memory-efficient gradient accumulation
  • Loss mask updates to handle COD sampling
  • Attention mask construction for the non-causal $[L, 2L]$ (position-length) attention patterns within groups
  • Data collation with mask_hidden padding for variable-length groups
  • Position ID tracking across sampled indices for each depth
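To illustrate the mask-construction item, here is a sketch of one plausible per-group pattern. This is an assumption for illustration, not the paper's exact $[L, 2L]$ construction: each query for a sampled position attends causally over the base sequence up to its own position, plus its own group slot, and group slots do not attend to each other.

```python
def group_attention_mask(base_len, group_positions):
    """Sketch of a per-group attention mask (boolean, True = may attend).
    Assumed pattern, not the exact layout from the paper."""
    L = len(group_positions)
    mask = [[False] * (base_len + L) for _ in range(L)]
    for q, pos in enumerate(group_positions):
        for kv in range(base_len):
            mask[q][kv] = kv <= pos      # causal over the base sequence
        mask[q][base_len + q] = True     # self-attention within the group slot
    return mask
```

In practice a pattern like this would likely be expressed as a FlexAttention mask function rather than a dense boolean matrix, so memory stays proportional to the block structure rather than the full $[L, 2L]$ grid.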

Model and Configuration

  • P-EAGLE model registration with SpeculatorModel registry
  • P-EAGLE model definition with mask_hidden parameter (shape: $[1, 1, 3 \times \text{hidden\_size}]$)
  • Configuration class supporting para_depths, down_sample_ratio, down_sample_ratio_min, ptd_token_id parameters
  • Reuse of existing EAGLE-3 decoder layers and architecture
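The configuration surface could look like the sketch below. The field names are the ones listed in this RFC; the types and default values are illustrative guesses, not taken from the paper or an existing codebase.

```python
from dataclasses import dataclass

@dataclass
class PEagleConfig:
    """Sketch of a P-EAGLE configuration class (defaults are illustrative)."""
    para_depths: int = 8                 # number of parallel prediction depths K
    down_sample_ratio: float = 0.5       # COD retention rate r
    down_sample_ratio_min: float = 0.1   # assumed floor on per-depth retention
    ptd_token_id: int = 0                # assumed token id for parallel padding
```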

Training Infrastructure

  • Group-specific hidden state and target extraction at sampled indices
  • Per-depth and aggregate top-k accuracy metrics
  • Training loop modifications for single-pass parallel group processing
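The metrics item above can be sketched framework-free. The function below computes per-depth and aggregate top-k accuracy over plain nested lists for clarity; a real implementation would operate on logit tensors, and the function name is hypothetical.

```python
def topk_accuracy(logits_by_depth, targets_by_depth, k=3):
    """Per-depth and aggregate top-k accuracy (sketch).
    logits_by_depth[d] is a list of logit rows for depth d;
    targets_by_depth[d] is the matching list of target token ids."""
    per_depth, correct, total = [], 0, 0
    for logits, targets in zip(logits_by_depth, targets_by_depth):
        hits = 0
        for row, tgt in zip(logits, targets):
            topk = sorted(range(len(row)), key=row.__getitem__, reverse=True)[:k]
            hits += tgt in topk
        per_depth.append(hits / len(targets))
        correct += hits
        total += len(targets)
    return per_depth, correct / total
```

Reporting accuracy per depth matters here because deeper depths predict further ahead and are expected to be less accurate, which directly bounds the achievable speculative speedup.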

Work Breakdown

Phase 1

  • P-EAGLE model and configuration class implementation, reusing EAGLE-3 components
  • Add basic COD sampling implementation (possibly loss-mask aware)
  • Add mask embeddings to the P-EAGLE model
  • Loss mask updates for mask tokens (using position IDs)
  • Attention mask construction for parallel patterns, using FlexAttention
  • Add cross-entropy loss as an option in the training script
  • Unfreeze the embedding linear layer
  • Test that attention masks line up with the authors' implementation

Phase 2

  • Dependency-aware sequence splitting for gradient accumulation
  • Training loop modifications for parallel group processing (grad accumulation support)
  • Documentation and training examples
