diff --git a/docs/advance/mtp.md b/docs/advance/mtp.md
index b342b670d74..29456af8b9a 100644
--- a/docs/advance/mtp.md
+++ b/docs/advance/mtp.md
@@ -2,7 +2,7 @@
 
 **Author**: `https://github.com/meituan-search`
 
-Last updated: 02/15/2026
+Last updated: 04/08/2026
 
 # 1. Scope of Support
 
@@ -12,6 +12,8 @@ Currently, RL training can be performed on mimo-7B-RL, Qwen-next, and Deepseek s
 
 - **Inference Engine**: Compatible with all engines, but the model must be in the corresponding engine's compatibility list;
 
+- **vLLM Inference-Only Speculative Decoding**: In addition to MTP-based rollout acceleration, vLLM supports inference-only speculative decoding with external `EAGLE` / `EAGLE3` draft models; SGLang does not support this path yet;
+
 - **Dependency Versions**:
 
   - mbridge: Apply the patches and review suggestions from PR: [#62](https://github.com/ISEEKYAN/mbridge/pull/62) (Already merged into the main branch);
 
@@ -31,7 +33,22 @@ The MTP training process can be flexibly controlled through the following config
 | Load MTP Parameters Only | `enable=True` | VRAM usage will increase, but the exported parameters include the MTP module and can be directly used for online deployment |
 | Full-Parameter MTP Training | `enable=True`<br>`enable_train=True`<br>`mtp_loss_scaling_factor=0.1` | MTP Loss will apply to all model parameters |
 | MTP Parameter-Only Training | `enable=True`<br>`enable_train=True`<br>`detach_encoder=True` | Freeze the Encoder layer, update only MTP module parameters, MTP Loss applies only to MTP parameters |
-| MTP Accelerated Rollout | 1. vLLM configuration:<br>`enable=True`<br>`enable_rollout=True`<br>`method="mtp"`<br>`num_speculative_tokens=1`<br>2. SGLang configuration:<br>`enable=True`<br>`enable_rollout=True`<br>`speculative_algorithm="EAGLE"`<br>`speculative_num_steps=2`<br>`speculative_eagle_topk=2`<br>`speculative_num_draft_tokens=4` | Achieve inference acceleration during the Rollout phase based on MTP |
+| MTP Accelerated Rollout | 1. vLLM configuration:<br>`enable=True`<br>`enable_rollout=True`<br>`method="mtp"`<br>`num_speculative_tokens=1`<br>2. SGLang configuration:<br>`enable=True`<br>`enable_rollout=True`<br>`speculative_algorithm="EAGLE"`<br>`speculative_num_steps=2`<br>`speculative_eagle_topk=2`<br>`speculative_num_draft_tokens=4` | Achieve inference acceleration during the Rollout phase based on trainable MTP parameters |
+| vLLM Inference-Only EAGLE / EAGLE3 Rollout Acceleration | `actor_rollout_ref.rollout.speculative_decoding.enable=True`<br>`actor_rollout_ref.rollout.speculative_decoding.method="EAGLE"` or `"EAGLE3"`<br>`actor_rollout_ref.rollout.speculative_decoding.draft_model_path=/path/to/draft/model`<br>`actor_rollout_ref.rollout.speculative_decoding.num_draft_tokens=4`<br>`actor_rollout_ref.rollout.speculative_decoding.draft_tensor_parallel_size=1` or `tensor_model_parallel_size` | Achieve rollout acceleration on vLLM with an external draft model; trainable MTP parameters are not required. SGLang does not support this path yet. See the launch example below the table. |
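+
+When this path is enabled, each rollout reads vLLM's Prometheus counters (`vllm:spec_decode_num_drafts_total`, `vllm:spec_decode_num_draft_tokens_total`, `vllm:spec_decode_num_accepted_tokens_total`) before and after generation and logs the deltas under `speculative_decoding/*`, with `avg_draft_acceptance_rate = num_accepted_tokens / num_draft_tokens` and `mean_acceptance_length = 1 + num_accepted_tokens / num_drafts`.
+
+A minimal launch sketch (the entrypoint and the draft-model path are illustrative; adapt them to your recipe):
+
+```bash
+python3 -m verl.trainer.main_ppo \
+    actor_rollout_ref.rollout.name=vllm \
+    actor_rollout_ref.rollout.speculative_decoding.enable=True \
+    actor_rollout_ref.rollout.speculative_decoding.method=eagle3 \
+    actor_rollout_ref.rollout.speculative_decoding.draft_model_path=/path/to/draft/model \
+    actor_rollout_ref.rollout.speculative_decoding.num_draft_tokens=4 \
+    actor_rollout_ref.rollout.speculative_decoding.draft_tensor_parallel_size=1
+```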
 
 # 3. Experimental Results
 
@@ -109,4 +126,3 @@ The experiment was conducted using following data:
 The result: [wandb link](https://wandb.ai/hou-zg-meituan/mimo-7b-sft-mtp?nw=nwuserhouzg)
 
 The presence of mtp layer has limited effect on main loss. However, when MTP layer is detached, the mtp_loss converges to a higher value.
-
diff --git a/verl/experimental/agent_loop/agent_loop.py b/verl/experimental/agent_loop/agent_loop.py
index 1457ac6fccc..c16bf631a6a 100644
--- a/verl/experimental/agent_loop/agent_loop.py
+++ b/verl/experimental/agent_loop/agent_loop.py
@@ -30,7 +30,10 @@
 from tensordict import TensorDict
 from transformers import AutoProcessor, AutoTokenizer
 
-from verl.experimental.agent_loop.prometheus_utils import update_prometheus_config
+from verl.experimental.agent_loop.prometheus_utils import (
+    read_spec_decoding_metrics_from_prometheus,
+    update_prometheus_config,
+)
 from verl.experimental.agent_loop.utils import resolve_config_path
 from verl.experimental.teacher_loop import TeacherModelManager
 from verl.protocol import DataProto
@@ -41,18 +44,9 @@
 from verl.utils.dataset.rl_dataset import RLHFDataset, get_dataset_class
 from verl.utils.model import compute_position_id_with_mask
 from verl.utils.ray_utils import auto_await, get_event_loop
-from verl.utils.rollout_trace import (
-    RolloutTraceConfig,
-    rollout_trace_attr,
-    rollout_trace_op,
-)
+from verl.utils.rollout_trace import RolloutTraceConfig, rollout_trace_attr, rollout_trace_op
 from verl.utils.tokenizer import normalize_token_ids
-from verl.workers.config import (
-    DistillationConfig,
-    DistillationLossConfig,
-    HFModelConfig,
-    RolloutConfig,
-)
+from verl.workers.config import DistillationConfig, DistillationLossConfig, HFModelConfig, RolloutConfig
 from verl.workers.rollout.replica import DiffusionOutput, TokenOutput, get_rollout_replica_class
 
 logger = logging.getLogger(__file__)
@@ -1153,6 +1147,14 @@ async def generate_sequences(self, prompts: DataProto) -> DataProto:
         """
         if self.stream_teacher_with_rollout:
             await self.teacher_model_manager.wake_up()
+
+        spec_before = None
+        if self.rollout_config.name == "vllm" and self.rollout_config.speculative_decoding.enable:
+            try:
+                spec_before = await read_spec_decoding_metrics_from_prometheus(self.server_addresses)
+            except Exception as e:
+                logger.warning(f"speculative decoding metrics unavailable: {e}")
+
         chunkes = prompts.chunk(len(self.agent_loop_workers))
         outputs = await asyncio.gather(
             *[
@@ -1169,6 +1171,36 @@
         timing = self._performance_metrics(metrics, output)
 
         output.meta_info = {"timing": timing, **outputs[0].meta_info}
+
+        if spec_before is not None:
+            try:
+                spec_after = await read_spec_decoding_metrics_from_prometheus(self.server_addresses)
+                spec_delta = {key: spec_after[key] - spec_before[key] for key in spec_before}
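+                # acceptance rate = accepted draft tokens / proposed draft tokens for this batch;
+                # mean acceptance length = 1 + accepted tokens per draft, since each draft step
+                # always emits one verified token in addition to the accepted draft tokens.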
+                acceptance_rate = (
+                    spec_delta["num_accepted_tokens"] / spec_delta["num_draft_tokens"]
+                    if spec_delta["num_draft_tokens"] > 0
+                    else 0.0  # no draft tokens proposed; report 0.0 rather than inf so averaged metrics stay finite
+                )
+
+                mean_acceptance_length = (
+                    1.0 + (spec_delta["num_accepted_tokens"] / spec_delta["num_drafts"])
+                    if spec_delta["num_drafts"] > 0
+                    else 1.0
+                )
+
+                output.meta_info["speculative_decoding_metrics"] = {
+                    "num_drafts": spec_delta["num_drafts"],
+                    "num_draft_tokens": spec_delta["num_draft_tokens"],
+                    "num_accepted_tokens": spec_delta["num_accepted_tokens"],
+                    "avg_draft_acceptance_rate": acceptance_rate,
+                    "mean_acceptance_length": mean_acceptance_length,
+                }
+            except Exception as e:
+                logger.warning(f"speculative decoding metrics unavailable: {e}")
+
         return output
 
     def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: DataProto) -> dict[str, float]:
diff --git a/verl/experimental/agent_loop/prometheus_utils.py b/verl/experimental/agent_loop/prometheus_utils.py
index 0ce582df61e..167d2530791 100644
--- a/verl/experimental/agent_loop/prometheus_utils.py
+++ b/verl/experimental/agent_loop/prometheus_utils.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
+import asyncio
 import logging
 import os
 
@@ -108,3 +108,46 @@
 
     except Exception as e:
         logger.error(f"Failed to update Prometheus configuration: {e}")
+
+
+def _read_spec_decoding_metrics_from_prometheus_for_address(address: str) -> dict[str, float]:
+    import requests
+    from prometheus_client.parser import text_string_to_metric_families
+
+    metric_name_to_key = {
+        "vllm:spec_decode_num_drafts_total": "num_drafts",
+        "vllm:spec_decode_num_draft_tokens_total": "num_draft_tokens",
+        "vllm:spec_decode_num_accepted_tokens_total": "num_accepted_tokens",
+    }
+    totals = {key: 0.0 for key in metric_name_to_key.values()}
+    session = requests.Session()
+    session.trust_env = False
+
+    metrics_text = session.get(f"http://{address}/metrics", timeout=5).text
+    for family in text_string_to_metric_families(metrics_text):
+        for sample in family.samples:
+            key = metric_name_to_key.get(sample.name)
+            if key is not None:
+                totals[key] += float(sample.value)
+    return totals
+
+
+async def read_spec_decoding_metrics_from_prometheus(server_addresses: list[str]) -> dict[str, float]:
+    totals = {
+        "num_drafts": 0.0,
+        "num_draft_tokens": 0.0,
+        "num_accepted_tokens": 0.0,
+    }
+
+    results = await asyncio.gather(
+        *[
+            asyncio.to_thread(_read_spec_decoding_metrics_from_prometheus_for_address, address)
+            for address in server_addresses
+        ]
+    )
+
+    for metrics in results:
+        for key, value in metrics.items():
+            totals[key] += value
+
+    return totals
diff --git a/verl/trainer/config/_generated_diffusion_trainer.yaml b/verl/trainer/config/_generated_diffusion_trainer.yaml
index 3f9ef3e41c8..10640606970 100644
--- a/verl/trainer/config/_generated_diffusion_trainer.yaml
+++ b/verl/trainer/config/_generated_diffusion_trainer.yaml
@@ -348,6 +348,14 @@ actor_rollout_ref:
     quantization: null
     quantization_config_file: null
     mtp: ${oc.select:actor_rollout_ref.model.mtp, null}
+    speculative_decoding:
+      _target_: verl.workers.config.rollout.SpeculativeDecodingConfig
+      enable: false
+      method: eagle3
+      num_steps: 1
+      num_draft_tokens: 4
+      draft_model_path: null
+      draft_tensor_parallel_size: 1
     qat: ${oc.select:actor_rollout_ref.actor.fsdp_config.qat,${oc.select:actor_rollout_ref.actor.megatron.qat,null}}
     height: 512
     width: 512
diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
index e6fc322e021..aa74bb5cfd7 100644
--- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
+++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
@@ -368,6 +368,14 @@ actor_rollout_ref:
    quantization: null
    quantization_config_file: null
    mtp: ${oc.select:actor_rollout_ref.model.mtp, null}
+    speculative_decoding:
+      _target_: verl.workers.config.rollout.SpeculativeDecodingConfig
+      enable: false
+      method: eagle3
+      num_steps: 1
+      num_draft_tokens: 4
+      draft_model_path: null
+      draft_tensor_parallel_size: 1
    qat: ${oc.select:actor_rollout_ref.actor.fsdp_config.qat,${oc.select:actor_rollout_ref.actor.megatron.qat,null}}
  model:
    _target_: verl.workers.config.HFModelConfig
diff --git a/verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml b/verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml
index a37acd9abe8..a6b8dd9ecf1 100644
--- a/verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml
+++ b/verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml
@@ -335,6 +335,14 @@ actor_rollout_ref:
    quantization: null
    quantization_config_file: null
    mtp: ${oc.select:actor_rollout_ref.model.mtp, null}
+    speculative_decoding:
+      _target_: verl.workers.config.rollout.SpeculativeDecodingConfig
+      enable: false
+      method: eagle3
+      num_steps: 1
+      num_draft_tokens: 4
+      draft_model_path: null
+      draft_tensor_parallel_size: 1
    qat: ${oc.select:actor_rollout_ref.actor.fsdp_config.qat,${oc.select:actor_rollout_ref.actor.megatron.qat,null}}
  model:
    _target_: verl.workers.config.HFModelConfig
diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml
index 5009fbaaf85..9d2ed2858e7 100644
--- a/verl/trainer/config/_generated_ppo_trainer.yaml
+++ b/verl/trainer/config/_generated_ppo_trainer.yaml
@@ -344,6 +344,14 @@ actor_rollout_ref:
    quantization: null
    quantization_config_file: null
    mtp: ${oc.select:actor_rollout_ref.model.mtp, null}
+    speculative_decoding:
+      _target_: verl.workers.config.rollout.SpeculativeDecodingConfig
+      enable: false
+      method: eagle3
+      num_steps: 1
+      num_draft_tokens: 4
+      draft_model_path: null
+      draft_tensor_parallel_size: 1
    qat: ${oc.select:actor_rollout_ref.actor.fsdp_config.qat,${oc.select:actor_rollout_ref.actor.megatron.qat,null}}
  model:
    _target_: verl.workers.config.HFModelConfig
diff --git a/verl/trainer/config/_generated_ppo_veomni_trainer.yaml b/verl/trainer/config/_generated_ppo_veomni_trainer.yaml
index a44d2e2140e..a30d1c012b2 100644
--- a/verl/trainer/config/_generated_ppo_veomni_trainer.yaml
+++ b/verl/trainer/config/_generated_ppo_veomni_trainer.yaml
@@ -314,6 +314,14 @@ actor_rollout_ref:
    quantization: null
    quantization_config_file: null
    mtp: ${oc.select:actor_rollout_ref.model.mtp, null}
+    speculative_decoding:
+      _target_: verl.workers.config.rollout.SpeculativeDecodingConfig
+      enable: false
+      method: eagle3
+      num_steps: 1
+      num_draft_tokens: 4
+      draft_model_path: null
+      draft_tensor_parallel_size: 1
    qat: ${oc.select:actor_rollout_ref.actor.fsdp_config.qat,${oc.select:actor_rollout_ref.actor.megatron.qat,null}}
  model:
    _target_: verl.workers.config.HFModelConfig
diff --git a/verl/trainer/config/rollout/rollout.yaml b/verl/trainer/config/rollout/rollout.yaml
index 1c4b99c416c..c2c56c10e93 100644
--- a/verl/trainer/config/rollout/rollout.yaml
+++ b/verl/trainer/config/rollout/rollout.yaml
@@ -424,5 +424,29 @@ quantization_config_file: null
 # MTP configuration, reuse model configuration
 mtp: ${oc.select:actor_rollout_ref.model.mtp, null}
 
+# Speculative decoding configuration for vLLM rollout using an external draft model.
+speculative_decoding:
+
+  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
+  _target_: verl.workers.config.rollout.SpeculativeDecodingConfig
+
+  # Whether to enable inference-only speculative decoding for vLLM rollout
+  enable: False
+
+  # Speculative decoding method supported by vLLM, e.g. eagle or eagle3
+  method: eagle3
+
+  # Number of speculative decoding steps
+  num_steps: 1
+
+  # Number of draft tokens proposed by the draft model
+  num_draft_tokens: 4
+
+  # Path to the draft model used for speculative decoding
+  draft_model_path: null
+
+  # Tensor parallel size for the draft model; must be 1 or match tensor_model_parallel_size
+  draft_tensor_parallel_size: 1
+
 # QAT configuration (inherited from actor's engine config)
 qat: ${oc.select:actor_rollout_ref.actor.fsdp_config.qat,${oc.select:actor_rollout_ref.actor.megatron.qat,null}}
diff --git a/verl/trainer/ppo/metric_utils.py b/verl/trainer/ppo/metric_utils.py
index 4dd7d2d00a5..f0538a5af5a 100644
--- a/verl/trainer/ppo/metric_utils.py
+++ b/verl/trainer/ppo/metric_utils.py
@@ -222,6 +222,24 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str,
         metrics["tool_call_counts/max"] = tool_call_counts.max()
         metrics["tool_call_counts/mean"] = tool_call_counts.mean()
 
+    # speculative decoding
+    if "speculative_decoding_metrics" in batch.meta_info:
+        metrics["speculative_decoding/num_drafts"] = np.mean(
+            batch.meta_info["speculative_decoding_metrics"]["num_drafts"]
+        )
+        metrics["speculative_decoding/num_draft_tokens"] = np.mean(
+            batch.meta_info["speculative_decoding_metrics"]["num_draft_tokens"]
+        )
+        metrics["speculative_decoding/num_accepted_tokens"] = np.mean(
+            batch.meta_info["speculative_decoding_metrics"]["num_accepted_tokens"]
+        )
+        metrics["speculative_decoding/avg_draft_acceptance_rate"] = np.mean(
+            batch.meta_info["speculative_decoding_metrics"]["avg_draft_acceptance_rate"]
+        )
+        metrics["speculative_decoding/mean_acceptance_length"] = np.mean(
+            batch.meta_info["speculative_decoding_metrics"]["mean_acceptance_length"]
+        )
+
     return metrics
diff --git a/verl/workers/config/rollout.py b/verl/workers/config/rollout.py
index 4075c994af3..4f89cdc452b 100644
--- a/verl/workers/config/rollout.py
+++ b/verl/workers/config/rollout.py
@@ -33,6 +33,7 @@
     "RolloutConfig",
     "DiffusionRolloutConfig",
     "CheckpointEngineConfig",
+    "SpeculativeDecodingConfig",
     "SkipConfig",
 ]
 
@@ -165,6 +166,18 @@ class CheckpointEngineConfig(BaseConfig):
     custom_backend_module: Optional[str] = None
 
 
+@dataclass
+class SpeculativeDecodingConfig(BaseConfig):
+    enable: bool = False
+
+    method: str = "eagle3"
+    num_steps: int = 1
+    num_draft_tokens: int = 4
+    draft_model_path: str | None = None
+
+    draft_tensor_parallel_size: int = 1
+
+
 @dataclass
 class RolloutConfig(BaseConfig):
     _mutable_fields = {
@@ -285,6 +298,8 @@ class RolloutConfig(BaseConfig):
 
     mtp: MtpConfig = field(default_factory=MtpConfig)
 
+    speculative_decoding: SpeculativeDecodingConfig = field(default_factory=SpeculativeDecodingConfig)
+
     qat: Optional[dict] = None
 
     def __post_init__(self):
@@ -334,6 +349,46 @@ def __post_init__(self):
                 f"Current rollout {self.name=} not implemented pipeline_model_parallel_size > 1 yet."
             )
 
+        if self.name != "vllm" and self.speculative_decoding.enable:
+            raise NotImplementedError(
+                f"Rollout {self.name=} does not support speculative decoding "
+                f"{self.speculative_decoding.method=} for rollout acceleration yet"
+            )
+
+        if self.name == "vllm" and self.speculative_decoding.enable:
+            if self.speculative_decoding.method.lower() not in {"eagle", "eagle3"}:
+                warnings.warn(
+                    "Speculative decoding methods other than 'eagle' and 'eagle3' are untested and may be buggy",
+                    stacklevel=2,
+                )
+
+            if not (
+                self.speculative_decoding.draft_tensor_parallel_size == self.tensor_model_parallel_size
+                or self.speculative_decoding.draft_tensor_parallel_size == 1
+            ):
+                raise ValueError(
+                    f"draft_tensor_parallel_size={self.speculative_decoding.draft_tensor_parallel_size} "
+                    "must be either 1 or the target model's "
+                    f"tensor_model_parallel_size={self.tensor_model_parallel_size}"
+                )
+
+            if self.speculative_decoding.method.lower() in {"eagle", "eagle3"} and (
+                self.enable_chunked_prefill or self.enable_prefix_caching or not self.enforce_eager
+            ):
+                warnings.warn(
+                    "vLLM speculative decoding with EAGLE/EAGLE3 may regress throughput when "
+                    "enable_chunked_prefill=True, enable_prefix_caching=True, or enforce_eager=False. "
+                    "Overriding to enable_chunked_prefill=False, enable_prefix_caching=False, "
+                    "and enforce_eager=True for now.",
+                    stacklevel=2,
+                )
+                self.enable_chunked_prefill = False
+                self.enable_prefix_caching = False
+                self.enforce_eager = True
+
+        if self.speculative_decoding.enable and self.mtp.enable_rollout:
+            raise ValueError("Use either speculative_decoding or mtp, but not both simultaneously")
+
 
 @dataclass
 class DiffusionRolloutConfig(RolloutConfig):
diff --git a/verl/workers/rollout/vllm_rollout/utils.py b/verl/workers/rollout/vllm_rollout/utils.py
index b2cc699f6ab..12bf30b17ce 100644
--- a/verl/workers/rollout/vllm_rollout/utils.py
+++ b/verl/workers/rollout/vllm_rollout/utils.py
@@ -176,7 +176,13 @@ def monkey_patch_model(self, vocab_size: int):
         # patch weight loader to support MoE model
         patch_vllm_moe_model_weight_loader(self.model_runner.model)
 
-    def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False, use_shm: bool = False):
+    def update_weights_from_ipc(
+        self,
+        peft_config: dict = None,
+        base_sync_done=False,
+        use_shm: bool = False,
+        use_speculative_decoding: bool = False,
+    ):
         """Update the weights of the rollout model."""
         from vllm.platforms import current_platform
 
@@ -216,10 +222,28 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False
             )
             receiver.receive_weights(
                 on_bucket_received=lambda weights: self._update_weights(
-                    weights, peft_config=peft_config, base_sync_done=base_sync_done
+                    weights,
+                    peft_config=peft_config,
+                    base_sync_done=base_sync_done,
                 )
             )
 
+        if use_speculative_decoding:
+            # Reload draft weights because they are discarded after each model load.
+            from vllm.model_executor.model_loader import get_model_loader
+
+            loader = get_model_loader(self.model_runner.drafter.vllm_config.load_config)
+            self.model_runner.drafter.model.load_weights(
+                loader.get_all_weights(
+                    self.vllm_config.speculative_config.draft_model_config,
+                    self.model_runner.drafter.model,
+                )
+            )
+
+            # Rebuild RoPE caches because reloading weights clears the cos/sin cache.
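+            # rebuild_rope_caches() recomputes each rotary embedding's cos/sin cache in place
+            # (helper defined at the bottom of this file).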
+            rebuild_rope_caches(self.model_runner.drafter.model)
+
         if self._is_qat_model:
             # QAT (compressed-tensors): call process_weights_after_loading AFTER all buckets are received
             from verl.utils.qat import manual_process_weights_after_loading
@@ -239,6 +264,11 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False
             model_config = self.model_runner.vllm_config.model_config
             process_weights_after_loading(model, model_config, self.device)
 
+        if use_speculative_decoding:
+            drafter_model = self.model_runner.drafter.model
+            drafter_model_config = self.model_runner.drafter.vllm_config.model_config
+            process_weights_after_loading(drafter_model, drafter_model_config, self.device)
+
     def _update_weights(self, weights: list[tuple[str, torch.Tensor]], peft_config: dict, base_sync_done: bool):
         if peft_config and base_sync_done:
             weights = dict(weights)
@@ -292,7 +322,13 @@ def __new__(cls, **kwargs):
 
         return super().__new__(cls)
 
-    def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False, use_shm: bool = False):
+    def update_weights_from_ipc(
+        self,
+        peft_config: dict = None,
+        base_sync_done=False,
+        use_shm: bool = False,
+        use_speculative_decoding: bool = False,
+    ):
         """Update the weights of the rollout model."""
         from verl.workers.rollout.vllm_rollout.bucketed_weight_transfer import BucketedWeightReceiver
 
@@ -428,3 +464,13 @@ def extract_prompt_logprobs(output: RequestOutput, num_prompt_logprobs: Optional
 
     result_dict["prompt_ids"] = prompt_ids_ls
     result_dict["prompt_logprobs"] = prompt_logprobs_ls
+
+
+@torch.no_grad()
+def rebuild_rope_caches(root_module: torch.nn.Module):
+    for _, m in root_module.named_modules():
+        if hasattr(m, "rotary_emb"):
+            old = m.rotary_emb.cos_sin_cache
+            cache = m.rotary_emb._compute_cos_sin_cache()
+            cache = cache.to(device=old.device, dtype=old.dtype)
+            m.rotary_emb.cos_sin_cache = cache
diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py
index f7c1164e2d8..bfa8dd5d7a4 100644
--- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py
+++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py
@@ -290,6 +290,23 @@ async def launch_server(self, master_address: str = None, master_port: int = Non
             }
             args["speculative_config"] = speculative_config
 
+        # inference-only speculative decoding with an external draft model
+        if self.config.speculative_decoding.enable:
+            if self.config.speculative_decoding.draft_model_path is None:
+                raise ValueError(
+                    "speculative_decoding.draft_model_path should not be None when speculative decoding is enabled with vLLM"
+                )
+
+            speculative_config = {
+                "model": self.config.speculative_decoding.draft_model_path,
+                "max_model_len": self.config.max_model_len,
+                "num_speculative_tokens": self.config.speculative_decoding.num_draft_tokens,
+                "method": self.config.speculative_decoding.method.lower(),
+                "draft_tensor_parallel_size": self.config.speculative_decoding.draft_tensor_parallel_size,
+            }
+
+            args["speculative_config"] = speculative_config
+
         if self.config.data_parallel_size > 1:
             assert self.gpus_per_node % self.config.tensor_model_parallel_size == 0, (
                 "gpus_per_node should be divisible by tensor_model_parallel_size"
diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout.py b/verl/workers/rollout/vllm_rollout/vllm_rollout.py
index 90641b575a2..0a854df3c7a 100644
--- a/verl/workers/rollout/vllm_rollout/vllm_rollout.py
+++ b/verl/workers/rollout/vllm_rollout/vllm_rollout.py
@@ -98,6 +98,7 @@ def __init__(
         self.zmq_handle = f"ipc:///tmp/rl-colocate-zmq-{self.device_uuid}.sock"
         self.use_shm = not is_support_ipc()
+        self.use_speculative_decoding = config.speculative_decoding.enable
         if self.use_shm:
             logger.warning(
                 "IPC is not supported on your devices. Falling back to shared memory for weight transfer, "
@@ -160,7 +161,7 @@ async def update_weights(
         future = await self._execute_method(
             "update_weights_from_ipc",
             non_block=True,
-            kwargs={**kwargs, "use_shm": self.use_shm},
+            kwargs={**kwargs, "use_shm": self.use_shm, "use_speculative_decoding": self.use_speculative_decoding},
         )
 
         bucket_size_mb = self.config.checkpoint_engine.update_weights_bucket_megabytes