8 changes: 5 additions & 3 deletions docs/advance/mtp.md
@@ -2,7 +2,7 @@

**Author**: `https://github.com/meituan-search`

- Last updated: 02/15/2026
+ Last updated: 04/08/2026

# 1. Scope of Support

@@ -12,6 +12,8 @@ Currently, RL training can be performed on mimo-7B-RL, Qwen-next, and Deepseek s

- **Inference Engine**: Compatible with all engines, but the model must be in the corresponding engine's compatibility list;

- **vLLM Inference-Only Speculative Decoding**: In addition to MTP-based rollout acceleration, vLLM also supports inference-only speculative decoding with external `EAGLE` / `EAGLE3` draft models. This path is currently supported only for vLLM rollout; SGLang is not yet supported;

- **Dependency Versions**:

- mbridge: Apply the patches and review suggestions from PR: [#62](https://github.com/ISEEKYAN/mbridge/pull/62) (Already merged into the main branch);
@@ -31,7 +33,8 @@ The MTP training process can be flexibly controlled through the following config
| Load MTP Parameters Only | `enable=True` | VRAM usage will increase, but the exported parameters include the MTP module and can be directly used for online deployment |
| Full-Parameter MTP Training | `enable=True`<br>`enable_train=True`<br>`mtp_loss_scaling_factor=0.1` | MTP Loss will apply to all model parameters |
| MTP Parameter-Only Training | `enable=True`<br>`enable_train=True`<br>`detach_encoder=True` | Freeze the Encoder layer, update only MTP module parameters, MTP Loss applies only to MTP parameters |
- | MTP Accelerated Rollout | 1. vLLM configuration:<br>`enable=True`<br>`enable_rollout=True`<br>`method="mtp"`<br>`num_speculative_tokens=1`<br>2. SGLang configuration:<br>`enable=True`<br>`enable_rollout=True`<br>`speculative_algorithm="EAGLE"`<br>`speculative_num_steps=2`<br>`speculative_eagle_topk=2`<br>`speculative_num_draft_tokens=4` | Achieve inference acceleration during the Rollout phase based on MTP |
+ | MTP Accelerated Rollout | 1. vLLM configuration:<br>`enable=True`<br>`enable_rollout=True`<br>`method="mtp"`<br>`num_speculative_tokens=1`<br>2. SGLang configuration:<br>`enable=True`<br>`enable_rollout=True`<br>`speculative_algorithm="EAGLE"`<br>`speculative_num_steps=2`<br>`speculative_eagle_topk=2`<br>`speculative_num_draft_tokens=4` | Achieve inference acceleration during the Rollout phase based on trainable MTP parameters |
| vLLM Inference-Only EAGLE / EAGLE3 Rollout Acceleration | `actor_rollout_ref.rollout.speculative_decoding.enable=True`<br>`actor_rollout_ref.rollout.speculative_decoding.method="EAGLE"` or `"EAGLE3"`<br>`actor_rollout_ref.rollout.speculative_decoding.draft_model_path=/path/to/draft/model`<br>`actor_rollout_ref.rollout.speculative_decoding.num_draft_tokens=4`<br>`actor_rollout_ref.rollout.speculative_decoding.draft_tensor_parallel_size=1` or `tensor_model_parallel_size` | Achieve rollout acceleration on vLLM with an external draft model. This does not require trainable MTP parameters. SGLang does not support this path right now. |
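As an illustration, the inference-only vLLM path in the last row maps to Hydra-style overrides on the launch command. This is a sketch: the entrypoint module and the draft-model path are placeholders to adjust for your setup.

```bash
python3 -m verl.trainer.main_ppo \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.speculative_decoding.enable=True \
    actor_rollout_ref.rollout.speculative_decoding.method=eagle3 \
    actor_rollout_ref.rollout.speculative_decoding.num_draft_tokens=4 \
    actor_rollout_ref.rollout.speculative_decoding.draft_model_path=/path/to/draft/model \
    actor_rollout_ref.rollout.speculative_decoding.draft_tensor_parallel_size=1
```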

# 3. Experimental Results

@@ -109,4 +112,3 @@ The experiment was conducted using the following data:
The result: [wandb link](https://wandb.ai/hou-zg-meituan/mimo-7b-sft-mtp?nw=nwuserhouzg)

The presence of the MTP layer has a limited effect on the main loss. However, when the MTP layer is detached, the mtp_loss converges to a higher value.

53 changes: 41 additions & 12 deletions verl/experimental/agent_loop/agent_loop.py
@@ -30,7 +30,10 @@
from tensordict import TensorDict
from transformers import AutoProcessor, AutoTokenizer

- from verl.experimental.agent_loop.prometheus_utils import update_prometheus_config
+ from verl.experimental.agent_loop.prometheus_utils import (
+     read_spec_decoding_metrics_from_prometheus,
+     update_prometheus_config,
+ )
from verl.experimental.agent_loop.utils import resolve_config_path
from verl.experimental.teacher_loop import TeacherModelManager
from verl.protocol import DataProto
@@ -41,18 +44,9 @@
from verl.utils.dataset.rl_dataset import RLHFDataset, get_dataset_class
from verl.utils.model import compute_position_id_with_mask
from verl.utils.ray_utils import auto_await, get_event_loop
- from verl.utils.rollout_trace import (
-     RolloutTraceConfig,
-     rollout_trace_attr,
-     rollout_trace_op,
- )
+ from verl.utils.rollout_trace import RolloutTraceConfig, rollout_trace_attr, rollout_trace_op
from verl.utils.tokenizer import normalize_token_ids
- from verl.workers.config import (
-     DistillationConfig,
-     DistillationLossConfig,
-     HFModelConfig,
-     RolloutConfig,
- )
+ from verl.workers.config import DistillationConfig, DistillationLossConfig, HFModelConfig, RolloutConfig
from verl.workers.rollout.replica import DiffusionOutput, TokenOutput, get_rollout_replica_class

logger = logging.getLogger(__file__)
@@ -1153,6 +1147,14 @@ async def generate_sequences(self, prompts: DataProto) -> DataProto:
"""
if self.stream_teacher_with_rollout:
await self.teacher_model_manager.wake_up()

spec_before = None
if self.rollout_config.name == "vllm" and self.rollout_config.speculative_decoding.enable:
try:
spec_before = await read_spec_decoding_metrics_from_prometheus(self.server_addresses)
except Exception as e:
print(f"speculative decoding unavailable: {e}")

chunkes = prompts.chunk(len(self.agent_loop_workers))
outputs = await asyncio.gather(
*[
@@ -1169,6 +1171,33 @@
        timing = self._performance_metrics(metrics, output)

        output.meta_info = {"timing": timing, **outputs[0].meta_info}

        if spec_before is not None:
            try:
                spec_after = await read_spec_decoding_metrics_from_prometheus(self.server_addresses)
                spec_delta = {key: spec_after[key] - spec_before[key] for key in spec_before}
                acceptance_rate = (
                    spec_delta["num_accepted_tokens"] / spec_delta["num_draft_tokens"]
                    if spec_delta["num_draft_tokens"] > 0
                    else float("inf")
Comment on lines +1180 to +1182 (Contributor, severity: high)
The calculation for acceptance_rate when spec_delta["num_draft_tokens"] is 0 results in float("inf"). This is likely incorrect and can cause issues with metric aggregation (e.g., np.mean over inf will be inf). When no draft tokens are generated, the acceptance rate should be 0.0.
Suggested change:
    spec_delta["num_accepted_tokens"] / spec_delta["num_draft_tokens"]
    if spec_delta["num_draft_tokens"] > 0
    else 0.0
                )

                mean_acceptance_length = (
                    1.0 + (spec_delta["num_accepted_tokens"] / spec_delta["num_drafts"])
                    if spec_delta["num_drafts"] > 0
                    else 1.0
                )

                output.meta_info["speculative_decoding_metrics"] = {
                    "num_drafts": spec_delta["num_drafts"],
                    "num_draft_tokens": spec_delta["num_draft_tokens"],
                    "num_accepted_tokens": spec_delta["num_accepted_tokens"],
                    "avg_draft_acceptance_rate": acceptance_rate,
                    "mean_acceptance_length": mean_acceptance_length,
                }
            except Exception as e:
                print(f"speculative decoding unavailable: {e}")

        return output

    def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: DataProto) -> dict[str, float]:
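The counter bookkeeping added above reduces to simple deltas of cumulative Prometheus counters. A minimal, self-contained sketch of the same arithmetic (hypothetical helper name, not the verl API, and using the reviewer-suggested 0.0 convention when no drafts were proposed):

```python
def spec_decode_summary(before: dict[str, float], after: dict[str, float]) -> dict[str, float]:
    """Summarize vLLM speculative-decoding counters sampled before/after a rollout."""
    # Counters are cumulative, so per-rollout activity is the difference.
    delta = {key: after[key] - before[key] for key in before}
    # Fraction of proposed draft tokens that the target model accepted
    # (0.0 when no drafts were proposed).
    acceptance_rate = (
        delta["num_accepted_tokens"] / delta["num_draft_tokens"]
        if delta["num_draft_tokens"] > 0
        else 0.0
    )
    # Each draft run also yields one token from the target model itself,
    # hence the leading 1.0.
    mean_acceptance_length = (
        1.0 + delta["num_accepted_tokens"] / delta["num_drafts"]
        if delta["num_drafts"] > 0
        else 1.0
    )
    return {
        **delta,
        "avg_draft_acceptance_rate": acceptance_rate,
        "mean_acceptance_length": mean_acceptance_length,
    }
```

For example, going from 100 drafts / 400 draft tokens / 300 accepted tokens to 150 / 600 / 450 gives an acceptance rate of 0.75 and a mean acceptance length of 4.0.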
45 changes: 44 additions & 1 deletion verl/experimental/agent_loop/prometheus_utils.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import asyncio
import logging
import os

@@ -108,3 +108,46 @@ def reload_prometheus(port):

    except Exception as e:
        logger.error(f"Failed to update Prometheus configuration: {e}")


def _read_spec_decoding_metrics_from_prometheus_for_address(address: str) -> dict[str, float]:
    import requests
    from prometheus_client.parser import text_string_to_metric_families

    metric_name_to_key = {
        "vllm:spec_decode_num_drafts_total": "num_drafts",
        "vllm:spec_decode_num_draft_tokens_total": "num_draft_tokens",
        "vllm:spec_decode_num_accepted_tokens_total": "num_accepted_tokens",
    }
    totals = {key: 0.0 for key in metric_name_to_key.values()}
    session = requests.Session()
    session.trust_env = False

    metrics_text = session.get(f"http://{address}/metrics", timeout=5).text
    for family in text_string_to_metric_families(metrics_text):
        for sample in family.samples:
            key = metric_name_to_key.get(sample.name)
            if key is not None:
                totals[key] += float(sample.value)
    return totals


async def read_spec_decoding_metrics_from_prometheus(server_addresses: list[str]) -> dict[str, float]:
    totals = {
        "num_drafts": 0.0,
        "num_draft_tokens": 0.0,
        "num_accepted_tokens": 0.0,
    }

    results = await asyncio.gather(
        *[
            asyncio.to_thread(_read_spec_decoding_metrics_from_prometheus_for_address, address)
            for address in server_addresses
        ]
    )

    for metrics in results:
        for key, value in metrics.items():
            totals[key] += value

    return totals
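The helper above relies on `prometheus_client` to parse vLLM's `/metrics` exposition text. For intuition about that format, the same counters can be summed with a few lines of stdlib string handling. This is a simplified sketch that ignores histograms, timestamps, and label escaping, not a replacement for the real parser:

```python
def sum_spec_counters(metrics_text: str) -> dict[str, float]:
    """Sum the three vLLM speculative-decoding counters from Prometheus text format."""
    wanted = {
        "vllm:spec_decode_num_drafts_total": "num_drafts",
        "vllm:spec_decode_num_draft_tokens_total": "num_draft_tokens",
        "vllm:spec_decode_num_accepted_tokens_total": "num_accepted_tokens",
    }
    totals = {key: 0.0 for key in wanted.values()}
    for line in metrics_text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue  # skip blank lines and HELP/TYPE comment lines
        name, _, value = line.rpartition(" ")
        # Drop any label set, e.g. 'metric{engine="0"}' -> 'metric'
        name = name.split("{", 1)[0]
        if name in wanted:
            totals[wanted[name]] += float(value)
    return totals
```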
8 changes: 8 additions & 0 deletions verl/trainer/config/_generated_diffusion_trainer.yaml
@@ -348,6 +348,14 @@ actor_rollout_ref:
    quantization: null
    quantization_config_file: null
    mtp: ${oc.select:actor_rollout_ref.model.mtp, null}
    speculative_decoding:
      _target_: verl.workers.config.rollout.SpeculativeDecodingConfig
      enable: false
      method: eagle3
      num_steps: 1
      num_draft_tokens: 4
      draft_model_path: null
      draft_tensor_parallel_size: 1
    qat: ${oc.select:actor_rollout_ref.actor.fsdp_config.qat,${oc.select:actor_rollout_ref.actor.megatron.qat,null}}
    height: 512
    width: 512
8 changes: 8 additions & 0 deletions verl/trainer/config/_generated_ppo_megatron_trainer.yaml
@@ -368,6 +368,14 @@ actor_rollout_ref:
    quantization: null
    quantization_config_file: null
    mtp: ${oc.select:actor_rollout_ref.model.mtp, null}
    speculative_decoding:
      _target_: verl.workers.config.rollout.SpeculativeDecodingConfig
      enable: false
      method: eagle3
      num_steps: 1
      num_draft_tokens: 4
      draft_model_path: null
      draft_tensor_parallel_size: 1
    qat: ${oc.select:actor_rollout_ref.actor.fsdp_config.qat,${oc.select:actor_rollout_ref.actor.megatron.qat,null}}
  model:
    _target_: verl.workers.config.HFModelConfig
8 changes: 8 additions & 0 deletions verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml
@@ -335,6 +335,14 @@ actor_rollout_ref:
    quantization: null
    quantization_config_file: null
    mtp: ${oc.select:actor_rollout_ref.model.mtp, null}
    speculative_decoding:
      _target_: verl.workers.config.rollout.SpeculativeDecodingConfig
      enable: false
      method: eagle3
      num_steps: 1
      num_draft_tokens: 4
      draft_model_path: null
      draft_tensor_parallel_size: 1
    qat: ${oc.select:actor_rollout_ref.actor.fsdp_config.qat,${oc.select:actor_rollout_ref.actor.megatron.qat,null}}
  model:
    _target_: verl.workers.config.HFModelConfig
8 changes: 8 additions & 0 deletions verl/trainer/config/_generated_ppo_trainer.yaml
@@ -344,6 +344,14 @@ actor_rollout_ref:
    quantization: null
    quantization_config_file: null
    mtp: ${oc.select:actor_rollout_ref.model.mtp, null}
    speculative_decoding:
      _target_: verl.workers.config.rollout.SpeculativeDecodingConfig
      enable: false
      method: eagle3
      num_steps: 1
      num_draft_tokens: 4
      draft_model_path: null
      draft_tensor_parallel_size: 1
    qat: ${oc.select:actor_rollout_ref.actor.fsdp_config.qat,${oc.select:actor_rollout_ref.actor.megatron.qat,null}}
  model:
    _target_: verl.workers.config.HFModelConfig
8 changes: 8 additions & 0 deletions verl/trainer/config/_generated_ppo_veomni_trainer.yaml
@@ -314,6 +314,14 @@ actor_rollout_ref:
    quantization: null
    quantization_config_file: null
    mtp: ${oc.select:actor_rollout_ref.model.mtp, null}
    speculative_decoding:
      _target_: verl.workers.config.rollout.SpeculativeDecodingConfig
      enable: false
      method: eagle3
      num_steps: 1
      num_draft_tokens: 4
      draft_model_path: null
      draft_tensor_parallel_size: 1
    qat: ${oc.select:actor_rollout_ref.actor.fsdp_config.qat,${oc.select:actor_rollout_ref.actor.megatron.qat,null}}
  model:
    _target_: verl.workers.config.HFModelConfig
24 changes: 24 additions & 0 deletions verl/trainer/config/rollout/rollout.yaml
@@ -424,5 +424,29 @@ quantization_config_file: null
# MTP configuration, reuse model configuration
mtp: ${oc.select:actor_rollout_ref.model.mtp, null}

# Speculative decoding configuration for vLLM rollout using an external draft model.
speculative_decoding:

  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
  _target_: verl.workers.config.rollout.SpeculativeDecodingConfig

  # Whether to enable inference-only speculative decoding for vLLM rollout
  enable: False

  # Speculative decoding method supported by vLLM, e.g. eagle or eagle3
  method: eagle3

  # Number of speculative decoding steps
  num_steps: 1

  # Number of draft tokens proposed by the draft model
  num_draft_tokens: 4

  # Path to the draft model used for speculative decoding
  draft_model_path: null

  # Tensor parallel size for the draft model, should be 1 or match tensor_model_parallel_size
  draft_tensor_parallel_size: 1

# QAT configuration (inherited from actor's engine config)
qat: ${oc.select:actor_rollout_ref.actor.fsdp_config.qat,${oc.select:actor_rollout_ref.actor.megatron.qat,null}}
18 changes: 18 additions & 0 deletions verl/trainer/ppo/metric_utils.py
@@ -222,6 +222,24 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str,
metrics["tool_call_counts/max"] = tool_call_counts.max()
metrics["tool_call_counts/mean"] = tool_call_counts.mean()

    # speculative decoding
    if "speculative_decoding_metrics" in batch.meta_info:
        metrics["speculative_decoding/num_drafts"] = np.mean(
            batch.meta_info["speculative_decoding_metrics"]["num_drafts"]
        )
        metrics["speculative_decoding/num_draft_tokens"] = np.mean(
            batch.meta_info["speculative_decoding_metrics"]["num_draft_tokens"]
        )
        metrics["speculative_decoding/num_accepted_tokens"] = np.mean(
            batch.meta_info["speculative_decoding_metrics"]["num_accepted_tokens"]
        )
        metrics["speculative_decoding/avg_draft_acceptance_rate"] = np.mean(
            batch.meta_info["speculative_decoding_metrics"]["avg_draft_acceptance_rate"]
        )
        metrics["speculative_decoding/mean_acceptance_length"] = np.mean(
            batch.meta_info["speculative_decoding_metrics"]["mean_acceptance_length"]
        )

    return metrics


55 changes: 55 additions & 0 deletions verl/workers/config/rollout.py
@@ -33,6 +33,7 @@
"RolloutConfig",
"DiffusionRolloutConfig",
"CheckpointEngineConfig",
"SpeculativeDecodingConfig",
"SkipConfig",
]

@@ -165,6 +166,18 @@ class CheckpointEngineConfig(BaseConfig):
    custom_backend_module: Optional[str] = None


@dataclass
class SpeculativeDecodingConfig(BaseConfig):
    enable: bool = False

    method: str = "eagle3"
    num_steps: int = 3
    num_draft_tokens: int = 4
    draft_model_path: str | None = None

    draft_tensor_parallel_size: int = 1


@dataclass
class RolloutConfig(BaseConfig):
    _mutable_fields = {
@@ -285,6 +298,8 @@ class RolloutConfig(BaseConfig):

    mtp: MtpConfig = field(default_factory=MtpConfig)

    speculative_decoding: SpeculativeDecodingConfig = field(default_factory=SpeculativeDecodingConfig)

    qat: Optional[dict] = None

    def __post_init__(self):
@@ -334,6 +349,46 @@ def __post_init__(self):
f"Current rollout {self.name=} not implemented pipeline_model_parallel_size > 1 yet."
)

if self.name != "vllm" and self.speculative_decoding.enable:
raise NotImplementedError(
f"Rollout {self.name=} does not support speculative decoding "
f"{self.speculative_decoding.method=} for rollout acceleration yet"
)

if self.name == "vllm" and self.speculative_decoding.enable:
if self.speculative_decoding.method.lower() not in {"eagle", "eagle3"}:
warnings.warn(
"Speculative decoding methods other than 'eagle' and 'eagle3' are untested and may be buggy ",
stacklevel=2,
)

if not (
self.speculative_decoding.draft_tensor_parallel_size == self.tensor_model_parallel_size
or self.speculative_decoding.draft_tensor_parallel_size == 1
):
raise ValueError(
f"draft_tensor_parallel_size={self.speculative_decoding.draft_tensor_parallel_size} "
"cannot be other value than 1 or target model "
"tensor_parallel_size={self.tensor_model_parallel_size} "
)
Comment on lines +369 to +373 (Contributor, severity: high)
The f-string for the ValueError message is malformed. The variable self.tensor_model_parallel_size is outside the curly braces, so it will be printed literally instead of its value being interpolated. This will produce a confusing error message for users.
Suggested change:
    raise ValueError(
        f"draft_tensor_parallel_size={self.speculative_decoding.draft_tensor_parallel_size} "
        "cannot be other value than 1 or target model "
        f"tensor_parallel_size={self.tensor_model_parallel_size}"
    )


            if self.speculative_decoding.method.lower() in {"eagle", "eagle3"} and (
                self.enable_chunked_prefill or self.enable_prefix_caching or not self.enforce_eager
            ):
                warnings.warn(
                    "vLLM speculative decoding with EAGLE/EAGLE3 may regress throughput when "
                    "enable_chunked_prefill=True, enable_prefix_caching=True, or enforce_eager=False. "
                    "Overriding to enable_chunked_prefill=False, enable_prefix_caching=False, "
                    "and enforce_eager=True for now.",
                    stacklevel=2,
                )
                self.enable_chunked_prefill = False
                self.enable_prefix_caching = False
                self.enforce_eager = True

        if self.speculative_decoding.enable and self.mtp.enable_rollout:
            raise ValueError("Use either speculative_decoding or mtp, but not both simultaneously")


@dataclass
class DiffusionRolloutConfig(RolloutConfig):