diff --git a/docs/advance/mtp.md b/docs/advance/mtp.md
index b342b670d74..29456af8b9a 100644
--- a/docs/advance/mtp.md
+++ b/docs/advance/mtp.md
@@ -2,7 +2,7 @@
**Author**: `https://github.com/meituan-search`
-Last updated: 02/15/2026
+Last updated: 04/08/2026
# 1. Scope of Support
@@ -12,6 +12,8 @@ Currently, RL training can be performed on mimo-7B-RL, Qwen-next, and Deepseek s
- **Inference Engine**: Compatible with all engines, but the model must be in the corresponding engine's compatibility list;
+- **vLLM Inference-Only Speculative Decoding**: In addition to MTP-based rollout acceleration, vLLM also supports inference-only speculative decoding with external `EAGLE` / `EAGLE3` draft models. This path is currently available only for the vLLM rollout; SGLang is not yet supported (see the launch example below the configuration table in Section 2);
+
- **Dependency Versions**:
- mbridge: Apply the patches and review suggestions from PR: [#62](https://github.com/ISEEKYAN/mbridge/pull/62) (Already merged into the main branch);
@@ -31,7 +33,8 @@ The MTP training process can be flexibly controlled through the following config
| Load MTP Parameters Only | `enable=True` | VRAM usage will increase, but the exported parameters include the MTP module and can be directly used for online deployment |
| Full-Parameter MTP Training | `enable=True`<br>`enable_train=True`<br>`mtp_loss_scaling_factor=0.1` | MTP Loss will apply to all model parameters |
| MTP Parameter-Only Training | `enable=True`<br>`enable_train=True`<br>`detach_encoder=True` | Freeze the Encoder layer, update only MTP module parameters, MTP Loss applies only to MTP parameters |
-| MTP Accelerated Rollout | 1. vLLM configuration:<br>`enable=True`<br>`enable_rollout=True`<br>`method="mtp"`<br>`num_speculative_tokens=1`<br>2. SGLang configuration:<br>`enable=True`<br>`enable_rollout=True`<br>`speculative_algorithm="EAGLE"`<br>`speculative_num_steps=2`<br>`speculative_eagle_topk=2`<br>`speculative_num_draft_tokens=4` | Achieve inference acceleration during the Rollout phase based on MTP |
+| MTP Accelerated Rollout | 1. vLLM configuration:<br>`enable=True`<br>`enable_rollout=True`<br>`method="mtp"`<br>`num_speculative_tokens=1`<br>2. SGLang configuration:<br>`enable=True`<br>`enable_rollout=True`<br>`speculative_algorithm="EAGLE"`<br>`speculative_num_steps=2`<br>`speculative_eagle_topk=2`<br>`speculative_num_draft_tokens=4` | Achieve inference acceleration during the Rollout phase based on trainable MTP parameters |
+| vLLM Inference-Only EAGLE / EAGLE3 Rollout Acceleration | `actor_rollout_ref.rollout.speculative_decoding.enable=True`<br>`actor_rollout_ref.rollout.speculative_decoding.method="EAGLE"` or `"EAGLE3"`<br>`actor_rollout_ref.rollout.speculative_decoding.draft_model_path=/path/to/draft/model`<br>`actor_rollout_ref.rollout.speculative_decoding.num_draft_tokens=4`<br>`actor_rollout_ref.rollout.speculative_decoding.draft_tensor_parallel_size=1` or `tensor_model_parallel_size` | Achieve rollout acceleration on vLLM with an external draft model; trainable MTP parameters are not required. SGLang does not support this path yet. |
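+
+The new table row maps directly to Hydra-style command-line overrides. Below is a minimal launch sketch (assuming the standard `verl.trainer.main_ppo` entry point; model, data, and all other training flags are omitted):
+
+```bash
+python3 -m verl.trainer.main_ppo \
+    actor_rollout_ref.rollout.name=vllm \
+    actor_rollout_ref.rollout.speculative_decoding.enable=True \
+    actor_rollout_ref.rollout.speculative_decoding.method=eagle3 \
+    actor_rollout_ref.rollout.speculative_decoding.draft_model_path=/path/to/draft/model \
+    actor_rollout_ref.rollout.speculative_decoding.num_draft_tokens=4 \
+    actor_rollout_ref.rollout.speculative_decoding.draft_tensor_parallel_size=1
+```
+
+When enabled, per-batch draft statistics are scraped from each vLLM server's Prometheus `/metrics` endpoint and logged as `speculative_decoding/num_drafts`, `speculative_decoding/num_draft_tokens`, `speculative_decoding/num_accepted_tokens`, `speculative_decoding/avg_draft_acceptance_rate`, and `speculative_decoding/mean_acceptance_length`.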
# 3. Experimental Results
@@ -109,4 +112,3 @@ The experiment was conducted using following data:
The result: [wandb link](https://wandb.ai/hou-zg-meituan/mimo-7b-sft-mtp?nw=nwuserhouzg)
The presence of the MTP layer has a limited effect on the main loss. However, when the MTP layer is detached, the MTP loss converges to a higher value.
-
diff --git a/verl/experimental/agent_loop/agent_loop.py b/verl/experimental/agent_loop/agent_loop.py
index 1457ac6fccc..c16bf631a6a 100644
--- a/verl/experimental/agent_loop/agent_loop.py
+++ b/verl/experimental/agent_loop/agent_loop.py
@@ -30,7 +30,10 @@
from tensordict import TensorDict
from transformers import AutoProcessor, AutoTokenizer
-from verl.experimental.agent_loop.prometheus_utils import update_prometheus_config
+from verl.experimental.agent_loop.prometheus_utils import (
+ read_spec_decoding_metrics_from_prometheus,
+ update_prometheus_config,
+)
from verl.experimental.agent_loop.utils import resolve_config_path
from verl.experimental.teacher_loop import TeacherModelManager
from verl.protocol import DataProto
@@ -41,18 +44,9 @@
from verl.utils.dataset.rl_dataset import RLHFDataset, get_dataset_class
from verl.utils.model import compute_position_id_with_mask
from verl.utils.ray_utils import auto_await, get_event_loop
-from verl.utils.rollout_trace import (
- RolloutTraceConfig,
- rollout_trace_attr,
- rollout_trace_op,
-)
+from verl.utils.rollout_trace import RolloutTraceConfig, rollout_trace_attr, rollout_trace_op
from verl.utils.tokenizer import normalize_token_ids
-from verl.workers.config import (
- DistillationConfig,
- DistillationLossConfig,
- HFModelConfig,
- RolloutConfig,
-)
+from verl.workers.config import DistillationConfig, DistillationLossConfig, HFModelConfig, RolloutConfig
from verl.workers.rollout.replica import DiffusionOutput, TokenOutput, get_rollout_replica_class
logger = logging.getLogger(__file__)
@@ -1153,6 +1147,14 @@ async def generate_sequences(self, prompts: DataProto) -> DataProto:
"""
if self.stream_teacher_with_rollout:
await self.teacher_model_manager.wake_up()
+
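+        # Snapshot vLLM's cumulative spec-decode counters so a per-batch delta can be computed after generation.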
+ spec_before = None
+ if self.rollout_config.name == "vllm" and self.rollout_config.speculative_decoding.enable:
+ try:
+ spec_before = await read_spec_decoding_metrics_from_prometheus(self.server_addresses)
+            except Exception as e:
+                logger.warning(f"Failed to read speculative decoding metrics from Prometheus: {e}")
+
chunkes = prompts.chunk(len(self.agent_loop_workers))
outputs = await asyncio.gather(
*[
@@ -1169,6 +1171,33 @@ async def generate_sequences(self, prompts: DataProto) -> DataProto:
timing = self._performance_metrics(metrics, output)
output.meta_info = {"timing": timing, **outputs[0].meta_info}
+
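+        # Prometheus counters are cumulative; subtract the pre-rollout snapshot to get this batch's counts.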
+ if spec_before is not None:
+ try:
+ spec_after = await read_spec_decoding_metrics_from_prometheus(self.server_addresses)
+ spec_delta = {key: spec_after[key] - spec_before[key] for key in spec_before}
+                # Fraction of drafted tokens accepted by the target model; 0.0 when nothing was drafted.
+                acceptance_rate = (
+                    spec_delta["num_accepted_tokens"] / spec_delta["num_draft_tokens"]
+                    if spec_delta["num_draft_tokens"] > 0
+                    else 0.0
+                )
+
+                # Average tokens produced per target forward pass: 1 verified token plus accepted drafts.
+                mean_acceptance_length = (
+                    1.0 + (spec_delta["num_accepted_tokens"] / spec_delta["num_drafts"])
+                    if spec_delta["num_drafts"] > 0
+                    else 1.0
+                )
+
+ output.meta_info["speculative_decoding_metrics"] = {
+ "num_drafts": spec_delta["num_drafts"],
+ "num_draft_tokens": spec_delta["num_draft_tokens"],
+ "num_accepted_tokens": spec_delta["num_accepted_tokens"],
+ "avg_draft_acceptance_rate": acceptance_rate,
+ "mean_acceptance_length": mean_acceptance_length,
+ }
+            except Exception as e:
+                logger.warning(f"Failed to read speculative decoding metrics from Prometheus: {e}")
+
return output
def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: DataProto) -> dict[str, float]:
diff --git a/verl/experimental/agent_loop/prometheus_utils.py b/verl/experimental/agent_loop/prometheus_utils.py
index 0ce582df61e..167d2530791 100644
--- a/verl/experimental/agent_loop/prometheus_utils.py
+++ b/verl/experimental/agent_loop/prometheus_utils.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import asyncio
import logging
import os
@@ -108,3 +108,46 @@ def reload_prometheus(port):
except Exception as e:
logger.error(f"Failed to update Prometheus configuration: {e}")
+
+
+def _read_spec_decoding_metrics_from_prometheus_for_address(address: str) -> dict[str, float]:
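+    """Scrape a single vLLM server's /metrics endpoint and sum its cumulative spec-decode counters."""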
+ import requests
+ from prometheus_client.parser import text_string_to_metric_families
+
+ metric_name_to_key = {
+ "vllm:spec_decode_num_drafts_total": "num_drafts",
+ "vllm:spec_decode_num_draft_tokens_total": "num_draft_tokens",
+ "vllm:spec_decode_num_accepted_tokens_total": "num_accepted_tokens",
+ }
+ totals = {key: 0.0 for key in metric_name_to_key.values()}
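+    # Use a session that ignores proxy env vars so the scrape always hits the local vLLM server directly.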
+ session = requests.Session()
+ session.trust_env = False
+
+ metrics_text = session.get(f"http://{address}/metrics", timeout=5).text
+ for family in text_string_to_metric_families(metrics_text):
+ for sample in family.samples:
+ key = metric_name_to_key.get(sample.name)
+ if key is not None:
+ totals[key] += float(sample.value)
+ return totals
+
+
+async def read_spec_decoding_metrics_from_prometheus(server_addresses: list[str]) -> dict[str, float]:
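+    """Aggregate cumulative speculative decoding counters across all rollout servers' /metrics endpoints."""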
+ totals = {
+ "num_drafts": 0.0,
+ "num_draft_tokens": 0.0,
+ "num_accepted_tokens": 0.0,
+ }
+
+ results = await asyncio.gather(
+ *[
+ asyncio.to_thread(_read_spec_decoding_metrics_from_prometheus_for_address, address)
+            for address in server_addresses
+ ]
+ )
+
+ for metrics in results:
+ for key, value in metrics.items():
+ totals[key] += value
+
+ return totals
diff --git a/verl/trainer/config/_generated_diffusion_trainer.yaml b/verl/trainer/config/_generated_diffusion_trainer.yaml
index 3f9ef3e41c8..10640606970 100644
--- a/verl/trainer/config/_generated_diffusion_trainer.yaml
+++ b/verl/trainer/config/_generated_diffusion_trainer.yaml
@@ -348,6 +348,14 @@ actor_rollout_ref:
quantization: null
quantization_config_file: null
mtp: ${oc.select:actor_rollout_ref.model.mtp, null}
+ speculative_decoding:
+ _target_: verl.workers.config.rollout.SpeculativeDecodingConfig
+ enable: false
+ method: eagle3
+ num_steps: 1
+ num_draft_tokens: 4
+ draft_model_path: null
+ draft_tensor_parallel_size: 1
qat: ${oc.select:actor_rollout_ref.actor.fsdp_config.qat,${oc.select:actor_rollout_ref.actor.megatron.qat,null}}
height: 512
width: 512
diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
index e6fc322e021..aa74bb5cfd7 100644
--- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
+++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
@@ -368,6 +368,14 @@ actor_rollout_ref:
quantization: null
quantization_config_file: null
mtp: ${oc.select:actor_rollout_ref.model.mtp, null}
+ speculative_decoding:
+ _target_: verl.workers.config.rollout.SpeculativeDecodingConfig
+ enable: false
+ method: eagle3
+ num_steps: 1
+ num_draft_tokens: 4
+ draft_model_path: null
+ draft_tensor_parallel_size: 1
qat: ${oc.select:actor_rollout_ref.actor.fsdp_config.qat,${oc.select:actor_rollout_ref.actor.megatron.qat,null}}
model:
_target_: verl.workers.config.HFModelConfig
diff --git a/verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml b/verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml
index a37acd9abe8..a6b8dd9ecf1 100644
--- a/verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml
+++ b/verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml
@@ -335,6 +335,14 @@ actor_rollout_ref:
quantization: null
quantization_config_file: null
mtp: ${oc.select:actor_rollout_ref.model.mtp, null}
+ speculative_decoding:
+ _target_: verl.workers.config.rollout.SpeculativeDecodingConfig
+ enable: false
+ method: eagle3
+ num_steps: 1
+ num_draft_tokens: 4
+ draft_model_path: null
+ draft_tensor_parallel_size: 1
qat: ${oc.select:actor_rollout_ref.actor.fsdp_config.qat,${oc.select:actor_rollout_ref.actor.megatron.qat,null}}
model:
_target_: verl.workers.config.HFModelConfig
diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml
index 5009fbaaf85..9d2ed2858e7 100644
--- a/verl/trainer/config/_generated_ppo_trainer.yaml
+++ b/verl/trainer/config/_generated_ppo_trainer.yaml
@@ -344,6 +344,14 @@ actor_rollout_ref:
quantization: null
quantization_config_file: null
mtp: ${oc.select:actor_rollout_ref.model.mtp, null}
+ speculative_decoding:
+ _target_: verl.workers.config.rollout.SpeculativeDecodingConfig
+ enable: false
+ method: eagle3
+ num_steps: 1
+ num_draft_tokens: 4
+ draft_model_path: null
+ draft_tensor_parallel_size: 1
qat: ${oc.select:actor_rollout_ref.actor.fsdp_config.qat,${oc.select:actor_rollout_ref.actor.megatron.qat,null}}
model:
_target_: verl.workers.config.HFModelConfig
diff --git a/verl/trainer/config/_generated_ppo_veomni_trainer.yaml b/verl/trainer/config/_generated_ppo_veomni_trainer.yaml
index a44d2e2140e..a30d1c012b2 100644
--- a/verl/trainer/config/_generated_ppo_veomni_trainer.yaml
+++ b/verl/trainer/config/_generated_ppo_veomni_trainer.yaml
@@ -314,6 +314,14 @@ actor_rollout_ref:
quantization: null
quantization_config_file: null
mtp: ${oc.select:actor_rollout_ref.model.mtp, null}
+ speculative_decoding:
+ _target_: verl.workers.config.rollout.SpeculativeDecodingConfig
+ enable: false
+ method: eagle3
+ num_steps: 1
+ num_draft_tokens: 4
+ draft_model_path: null
+ draft_tensor_parallel_size: 1
qat: ${oc.select:actor_rollout_ref.actor.fsdp_config.qat,${oc.select:actor_rollout_ref.actor.megatron.qat,null}}
model:
_target_: verl.workers.config.HFModelConfig
diff --git a/verl/trainer/config/rollout/rollout.yaml b/verl/trainer/config/rollout/rollout.yaml
index 1c4b99c416c..c2c56c10e93 100644
--- a/verl/trainer/config/rollout/rollout.yaml
+++ b/verl/trainer/config/rollout/rollout.yaml
@@ -424,5 +424,29 @@ quantization_config_file: null
# MTP configuration, reuse model configuration
mtp: ${oc.select:actor_rollout_ref.model.mtp, null}
+# Speculative decoding configuration for vLLM rollout using an external draft model.
+speculative_decoding:
+
+ # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
+ _target_: verl.workers.config.rollout.SpeculativeDecodingConfig
+
+ # Whether to enable inference-only speculative decoding for vLLM rollout
+enable: false
+
+ # Speculative decoding method supported by vLLM, e.g. eagle or eagle3
+ method: eagle3
+
+ # Number of speculative decoding steps
+ num_steps: 1
+
+ # Number of draft tokens proposed by the draft model
+ num_draft_tokens: 4
+
+ # Path to the draft model used for speculative decoding
+ draft_model_path: null
+
+ # Tensor parallel size for the draft model, should be 1 or match tensor_model_parallel_size
+ draft_tensor_parallel_size: 1
+
# QAT configuration (inherited from actor's engine config)
qat: ${oc.select:actor_rollout_ref.actor.fsdp_config.qat,${oc.select:actor_rollout_ref.actor.megatron.qat,null}}
diff --git a/verl/trainer/ppo/metric_utils.py b/verl/trainer/ppo/metric_utils.py
index 4dd7d2d00a5..f0538a5af5a 100644
--- a/verl/trainer/ppo/metric_utils.py
+++ b/verl/trainer/ppo/metric_utils.py
@@ -222,6 +222,24 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str,
metrics["tool_call_counts/max"] = tool_call_counts.max()
metrics["tool_call_counts/mean"] = tool_call_counts.mean()
+    # speculative decoding
+    if "speculative_decoding_metrics" in batch.meta_info:
+        spec_metrics = batch.meta_info["speculative_decoding_metrics"]
+        for key in (
+            "num_drafts",
+            "num_draft_tokens",
+            "num_accepted_tokens",
+            "avg_draft_acceptance_rate",
+            "mean_acceptance_length",
+        ):
+            metrics[f"speculative_decoding/{key}"] = np.mean(spec_metrics[key])
+
return metrics
diff --git a/verl/workers/config/rollout.py b/verl/workers/config/rollout.py
index 4075c994af3..4f89cdc452b 100644
--- a/verl/workers/config/rollout.py
+++ b/verl/workers/config/rollout.py
@@ -33,6 +33,7 @@
"RolloutConfig",
"DiffusionRolloutConfig",
"CheckpointEngineConfig",
+ "SpeculativeDecodingConfig",
"SkipConfig",
]
@@ -165,6 +166,18 @@ class CheckpointEngineConfig(BaseConfig):
custom_backend_module: Optional[str] = None
+@dataclass
+class SpeculativeDecodingConfig(BaseConfig):
+    """Inference-only speculative decoding for vLLM rollout with an external draft model."""
+
+    # Whether to enable inference-only speculative decoding for vLLM rollout.
+    enable: bool = False
+
+    # Draft method supported by vLLM, e.g. "eagle" or "eagle3".
+    method: str = "eagle3"
+    # Number of speculative decoding steps; default matches rollout.yaml.
+    num_steps: int = 1
+    # Number of draft tokens proposed by the draft model.
+    num_draft_tokens: int = 4
+    # Path to the draft model; must be set when enable=True.
+    draft_model_path: str | None = None
+
+    # Must be 1 or equal to the target model's tensor_model_parallel_size.
+    draft_tensor_parallel_size: int = 1
+
+
@dataclass
class RolloutConfig(BaseConfig):
_mutable_fields = {
@@ -285,6 +298,8 @@ class RolloutConfig(BaseConfig):
mtp: MtpConfig = field(default_factory=MtpConfig)
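+    # Inference-only speculative decoding with an external draft model (vLLM rollout only).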
+ speculative_decoding: SpeculativeDecodingConfig = field(default_factory=SpeculativeDecodingConfig)
+
qat: Optional[dict] = None
def __post_init__(self):
@@ -334,6 +349,46 @@ def __post_init__(self):
f"Current rollout {self.name=} not implemented pipeline_model_parallel_size > 1 yet."
)
+ if self.name != "vllm" and self.speculative_decoding.enable:
+ raise NotImplementedError(
+ f"Rollout {self.name=} does not support speculative decoding "
+ f"{self.speculative_decoding.method=} for rollout acceleration yet"
+ )
+
+ if self.name == "vllm" and self.speculative_decoding.enable:
+ if self.speculative_decoding.method.lower() not in {"eagle", "eagle3"}:
+ warnings.warn(
+ "Speculative decoding methods other than 'eagle' and 'eagle3' are untested and may be buggy ",
+ stacklevel=2,
+ )
+
+            if self.speculative_decoding.draft_tensor_parallel_size not in (
+                1,
+                self.tensor_model_parallel_size,
+            ):
+                raise ValueError(
+                    f"draft_tensor_parallel_size={self.speculative_decoding.draft_tensor_parallel_size} "
+                    f"must be either 1 or the target model's "
+                    f"tensor_model_parallel_size={self.tensor_model_parallel_size}"
+                )
+
+ if self.speculative_decoding.method.lower() in {"eagle", "eagle3"} and (
+ self.enable_chunked_prefill or self.enable_prefix_caching or not self.enforce_eager
+ ):
+ warnings.warn(
+ "vLLM speculative decoding with EAGLE/EAGLE3 may regress throughput when "
+ "enable_chunked_prefill=True, enable_prefix_caching=True, or enforce_eager=False. "
+ "Overriding to enable_chunked_prefill=False, enable_prefix_caching=False, "
+ "and enforce_eager=True for now.",
+ stacklevel=2,
+ )
+ self.enable_chunked_prefill = False
+ self.enable_prefix_caching = False
+ self.enforce_eager = True
+
+ if self.speculative_decoding.enable and self.mtp.enable_rollout:
+ raise ValueError("Use either speculative_decoding or mtp, but not both simultaneously")
+
@dataclass
class DiffusionRolloutConfig(RolloutConfig):
diff --git a/verl/workers/rollout/vllm_rollout/utils.py b/verl/workers/rollout/vllm_rollout/utils.py
index b2cc699f6ab..12bf30b17ce 100644
--- a/verl/workers/rollout/vllm_rollout/utils.py
+++ b/verl/workers/rollout/vllm_rollout/utils.py
@@ -176,7 +176,13 @@ def monkey_patch_model(self, vocab_size: int):
# patch weight loader to support MoE model
patch_vllm_moe_model_weight_loader(self.model_runner.model)
- def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False, use_shm: bool = False):
+ def update_weights_from_ipc(
+ self,
+ peft_config: dict = None,
+ base_sync_done=False,
+ use_shm: bool = False,
+ use_speculative_decoding: bool = False,
+ ):
"""Update the weights of the rollout model."""
from vllm.platforms import current_platform
@@ -216,10 +222,27 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False
)
receiver.receive_weights(
on_bucket_received=lambda weights: self._update_weights(
- weights, peft_config=peft_config, base_sync_done=base_sync_done
+ weights,
+ peft_config=peft_config,
+ base_sync_done=base_sync_done,
)
)
+ if use_speculative_decoding:
+ # Reload draft weights because they are discarded after each model load.
+ from vllm.model_executor.model_loader import get_model_loader
+
+ loader = get_model_loader(self.model_runner.drafter.vllm_config.load_config)
+ self.model_runner.drafter.model.load_weights(
+ loader.get_all_weights(
+ self.vllm_config.speculative_config.draft_model_config,
+ self.model_runner.drafter.model,
+ )
+ )
+
+ # Rebuild RoPE caches because reloading weights clears the cos/sin cache.
+ rebuild_rope_caches(self.model_runner.drafter.model)
+
if self._is_qat_model:
# QAT (compressed-tensors): call process_weights_after_loading AFTER all buckets are received
from verl.utils.qat import manual_process_weights_after_loading
@@ -239,6 +262,11 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False
model_config = self.model_runner.vllm_config.model_config
process_weights_after_loading(model, model_config, self.device)
+ if use_speculative_decoding:
+ drafter_model = self.model_runner.drafter.model
+ drafter_model_config = self.model_runner.drafter.vllm_config.model_config
+ process_weights_after_loading(drafter_model, drafter_model_config, self.device)
+
def _update_weights(self, weights: list[tuple[str, torch.Tensor]], peft_config: dict, base_sync_done: bool):
if peft_config and base_sync_done:
weights = dict(weights)
@@ -292,7 +320,13 @@ def __new__(cls, **kwargs):
return super().__new__(cls)
- def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False, use_shm: bool = False):
+ def update_weights_from_ipc(
+ self,
+ peft_config: dict = None,
+ base_sync_done=False,
+ use_shm: bool = False,
+ use_speculative_decoding: bool = False,
+ ):
"""Update the weights of the rollout model."""
from verl.workers.rollout.vllm_rollout.bucketed_weight_transfer import BucketedWeightReceiver
@@ -428,3 +462,13 @@ def extract_prompt_logprobs(output: RequestOutput, num_prompt_logprobs: Optional
result_dict["prompt_ids"] = prompt_ids_ls
result_dict["prompt_logprobs"] = prompt_logprobs_ls
+
+
+@torch.no_grad()
+def rebuild_rope_caches(root_module: torch.nn.Module):
+    """Recompute rotary-embedding cos/sin caches in place after draft weights are reloaded."""
+    for _, m in root_module.named_modules():
+        if hasattr(m, "rotary_emb") and hasattr(m.rotary_emb, "cos_sin_cache"):
+            old = m.rotary_emb.cos_sin_cache
+            cache = m.rotary_emb._compute_cos_sin_cache()
+            m.rotary_emb.cos_sin_cache = cache.to(device=old.device, dtype=old.dtype)
diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py
index f7c1164e2d8..bfa8dd5d7a4 100644
--- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py
+++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py
@@ -290,6 +290,23 @@ async def launch_server(self, master_address: str = None, master_port: int = Non
}
args["speculative_config"] = speculative_config
+        # Inference-only speculative decoding with an external draft model. Config
+        # validation guarantees this is mutually exclusive with the MTP-based
+        # speculative_config above, so overwriting args["speculative_config"] is safe.
+        if self.config.speculative_decoding.enable:
+            if self.config.speculative_decoding.draft_model_path is None:
+                raise ValueError(
+                    "speculative_decoding.draft_model_path must not be None when "
+                    "speculative decoding is enabled for vLLM rollout"
+                )
+
+ speculative_config = {
+ "model": self.config.speculative_decoding.draft_model_path,
+ "max_model_len": self.config.max_model_len,
+ "num_speculative_tokens": self.config.speculative_decoding.num_draft_tokens,
+ "method": self.config.speculative_decoding.method.lower(),
+ "draft_tensor_parallel_size": self.config.speculative_decoding.draft_tensor_parallel_size,
+ }
+
+ args["speculative_config"] = speculative_config
+
if self.config.data_parallel_size > 1:
assert self.gpus_per_node % self.config.tensor_model_parallel_size == 0, (
"gpus_per_node should be divisible by tensor_model_parallel_size"
diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout.py b/verl/workers/rollout/vllm_rollout/vllm_rollout.py
index 90641b575a2..0a854df3c7a 100644
--- a/verl/workers/rollout/vllm_rollout/vllm_rollout.py
+++ b/verl/workers/rollout/vllm_rollout/vllm_rollout.py
@@ -98,6 +98,7 @@ def __init__(
self.zmq_handle = f"ipc:///tmp/rl-colocate-zmq-{self.device_uuid}.sock"
self.use_shm = not is_support_ipc()
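+        # When speculative decoding is on, the drafter's weights must be restored after every weight sync.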
+ self.use_speculative_decoding = config.speculative_decoding.enable
if self.use_shm:
logger.warning(
"IPC is not supported on your devices. Falling back to shared memory for weight transfer, "
@@ -160,7 +161,7 @@ async def update_weights(
future = await self._execute_method(
"update_weights_from_ipc",
non_block=True,
- kwargs={**kwargs, "use_shm": self.use_shm},
+ kwargs={**kwargs, "use_shm": self.use_shm, "use_speculative_decoding": self.use_speculative_decoding},
)
bucket_size_mb = self.config.checkpoint_engine.update_weights_bucket_megabytes