Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions verl/trainer/config/_generated_ppo_megatron_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ actor_rollout_ref:
dist_ckpt_optim_fully_reshardable: false
distrib_optim_fully_reshardable_mem_efficient: false
seed: 42
use_megatron_fsdp: false
megatron_fsdp_zero_stage: 3
megatron_fsdp_overlap_grad_reduce: true
megatron_fsdp_overlap_param_gather: true
override_ddp_config: {}
override_transformer_config:
recompute_granularity: null
Expand Down Expand Up @@ -221,6 +225,10 @@ actor_rollout_ref:
dist_ckpt_optim_fully_reshardable: false
distrib_optim_fully_reshardable_mem_efficient: false
seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42}
use_megatron_fsdp: false
megatron_fsdp_zero_stage: 3
megatron_fsdp_overlap_grad_reduce: true
megatron_fsdp_overlap_param_gather: true
override_ddp_config: {}
override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}}
override_mcore_model_config: {}
Expand Down Expand Up @@ -512,6 +520,10 @@ critic:
dist_ckpt_optim_fully_reshardable: false
distrib_optim_fully_reshardable_mem_efficient: false
seed: 42
use_megatron_fsdp: false
megatron_fsdp_zero_stage: 3
megatron_fsdp_overlap_grad_reduce: true
megatron_fsdp_overlap_param_gather: true
override_ddp_config: {}
override_transformer_config:
recompute_granularity: null
Expand Down
9 changes: 9 additions & 0 deletions verl/trainer/config/engine/megatron.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,15 @@ distrib_optim_fully_reshardable_mem_efficient: False
# oc.select: default val for ref.megatron.seed
seed: 42

# Megatron FSDP (ZeRO-style sharding, replaces DDP when enabled)
use_megatron_fsdp: False
# ZeRO stage: 0=no_shard, 1=optim, 2=optim_grads, 3=optim_grads_params
megatron_fsdp_zero_stage: 3
# Overlap gradient reduce with backward computation
megatron_fsdp_overlap_grad_reduce: True
# Overlap parameter gather with forward computation
megatron_fsdp_overlap_param_gather: True

# Allow to override Distributed Data Parallel (DDP) config
override_ddp_config: {}

Expand Down
20 changes: 15 additions & 5 deletions verl/utils/checkpoint/megatron_checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def generate_state_dict(
key = f"model{vpp_rank}" if len(self.model) > 1 else "model"
else:
key = "model"
if hasattr(model, "module"):
while hasattr(model, "module") and not hasattr(model, "sharded_state_dict"):
model = model.module

# GPTModel's sharded_state_dict function when having mtp requires metadata['dp_cp_group']
Expand All @@ -279,8 +279,9 @@ def generate_state_dict(
kwargs = {"metadata": model_metadata}
state_dict[key] = model.sharded_state_dict(**kwargs)

# Optimizer State Dict
if generate_optimizer:
# Optimizer State Dict (skip for FSDP — upstream sharding not yet supported)
is_fsdp = getattr(getattr(self.model[0], "ddp_config", None), "use_megatron_fsdp", False)
if generate_optimizer and not is_fsdp:
torch.distributed.barrier()
sharded_state_dict_kwargs = {"is_loading": is_loading}
if base_metadata is not None:
Expand Down Expand Up @@ -479,7 +480,15 @@ def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_afte
self.bridge.load_hf_weights(self.model, hf_model_path)
log_with_rank(f"Loaded HF model checkpoint from {hf_model_path} with bridge", rank=self.rank, logger=logger)

if self.should_load_optimizer:
is_fsdp = getattr(getattr(self.model[0], "ddp_config", None), "use_megatron_fsdp", False)
if self.should_load_optimizer and is_fsdp:
log_with_rank(
"Skipping optimizer state loading for Megatron FSDP (not yet supported). "
"Training will resume with fresh optimizer state.",
rank=self.rank,
logger=logger,
)
elif self.should_load_optimizer:
assert "optimizer" in state_dict, (
f"Optimizer state dict not found in {state_dict.keys()}. Please check the checkpoint file {local_path}."
)
Expand Down Expand Up @@ -582,7 +591,8 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
assert async_save_request is None, "Async save request should be None when not using async save."
torch.distributed.barrier()

if self.should_save_model:
is_fsdp = getattr(getattr(self.model[0], "ddp_config", None), "use_megatron_fsdp", False)
if self.should_save_model and not is_fsdp:
if self.use_hf_checkpoint:
# Use mbridge to save HF model checkpoint
log_with_rank(f"Saving HF model checkpoint to {local_path} with bridge", rank=self.rank, logger=logger)
Expand Down
127 changes: 101 additions & 26 deletions verl/utils/megatron_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@
from megatron.core import ModelParallelConfig, mpu, parallel_state, tensor_parallel
from megatron.core.distributed import DistributedDataParallel as DDP
from megatron.core.distributed import DistributedDataParallelConfig

try:
from megatron.core.distributed.fsdp.mcore_fsdp_adapter import FullyShardedDataParallel as MegatronFSDP
from megatron.core.distributed.fsdp.src.megatron_fsdp import MegatronFSDP as MegatronFSDPInner

HAVE_MEGATRON_FSDP = True
except ImportError:
MegatronFSDP = None
MegatronFSDPInner = None
HAVE_MEGATRON_FSDP = False
from megatron.core.enums import ModelType
from megatron.core.optimizer import ChainedOptimizer
from megatron.core.parallel_state import get_global_memory_buffer
Expand Down Expand Up @@ -209,6 +219,31 @@ class McoreModuleWrapperConfig:
share_embeddings_and_output_weights: bool = False
wrap_with_ddp: bool = True
use_distributed_optimizer: bool = True
# Megatron FSDP settings
use_megatron_fsdp: bool = False
megatron_fsdp_zero_stage: int = 3
megatron_fsdp_overlap_grad_reduce: bool = True
megatron_fsdp_overlap_param_gather: bool = True


def _build_ddp_config_dict(wrap_config: McoreModuleWrapperConfig, override_ddp_config: dict[str, Any] = None) -> dict:
    """Assemble the kwargs dict used to construct a DistributedDataParallelConfig.

    Starts from the wrapper config's distributed-optimizer flag, layers on the
    Megatron-FSDP settings when FSDP is enabled, and finally applies any caller
    overrides (overrides win on key conflicts).
    """
    config = {"use_distributed_optimizer": wrap_config.use_distributed_optimizer}

    if getattr(wrap_config, "use_megatron_fsdp", False):
        # Translate the integer ZeRO stage into Megatron's sharding-strategy name.
        stage_names = {0: "no_shard", 1: "optim", 2: "optim_grads", 3: "optim_grads_params"}
        stage = getattr(wrap_config, "megatron_fsdp_zero_stage", 3)
        if stage not in stage_names:
            raise ValueError(
                f"Invalid megatron_fsdp_zero_stage={stage}. Must be one of {list(stage_names.keys())}."
            )
        config["use_megatron_fsdp"] = True
        config["data_parallel_sharding_strategy"] = stage_names[stage]
        config["overlap_grad_reduce"] = getattr(wrap_config, "megatron_fsdp_overlap_grad_reduce", True)
        config["overlap_param_gather"] = getattr(wrap_config, "megatron_fsdp_overlap_param_gather", True)

    if override_ddp_config is not None:
        config.update(override_ddp_config)
    return config


def make_megatron_module(
Expand All @@ -224,6 +259,7 @@ def make_megatron_module(
):
from verl.models.mcore.config_converter import get_hf_rope_theta

pending_fsdp_config = None
hf_config.rope_theta = get_hf_rope_theta(hf_config)

if override_model_config is None:
Expand Down Expand Up @@ -298,18 +334,12 @@ def peft_pre_wrap_hook(model):
for callback in post_model_creation_callbacks:
provider.register_pre_wrap_hook(callback)

# Create DDP config if needed
# Create DDP/FSDP config if needed
ddp_config = None
if wrap_config.wrap_with_ddp:
from megatron.bridge.training.config import DistributedDataParallelConfig

ddp_config_dict = {
"use_distributed_optimizer": wrap_config.use_distributed_optimizer,
}
# Apply any DDP config overrides
if override_ddp_config is not None:
ddp_config_dict.update(override_ddp_config)

ddp_config_dict = _build_ddp_config_dict(wrap_config, override_ddp_config)
ddp_config = DistributedDataParallelConfig(**ddp_config_dict)
ddp_config.finalize()

Expand All @@ -328,20 +358,33 @@ def peft_pre_wrap_hook(model):
# Build ddp_config dict with use_distributed_optimizer, same as provider path
ddp_config = None
if wrap_config.wrap_with_ddp:
ddp_config_dict = {
"use_distributed_optimizer": wrap_config.use_distributed_optimizer,
}
if override_ddp_config is not None:
ddp_config_dict.update(override_ddp_config)
ddp_config = ddp_config_dict

model = bridge.get_model(
post_model_creation_callbacks=post_model_creation_callbacks,
wrap_with_ddp=wrap_config.wrap_with_ddp,
fp16=tf_config.fp16,
bf16=tf_config.bf16,
ddp_config=ddp_config,
)
ddp_config = _build_ddp_config_dict(wrap_config, override_ddp_config)

use_fsdp = hasattr(wrap_config, "use_megatron_fsdp") and wrap_config.use_megatron_fsdp
if use_fsdp and not HAVE_MEGATRON_FSDP:
raise ImportError(
"engine.use_megatron_fsdp=True requires megatron-fsdp package. "
"Install from Megatron-LM dev branch with FSDP support."
)
if use_fsdp and wrap_config.wrap_with_ddp:
# FSDP wrapping deferred to after weight loading (mbridge can't parse FSDP structure)
model = bridge.get_model(
post_model_creation_callbacks=post_model_creation_callbacks,
wrap_with_ddp=False,
fp16=tf_config.fp16,
bf16=tf_config.bf16,
ddp_config=None,
)
pending_fsdp_config = ddp_config
else:
model = bridge.get_model(
post_model_creation_callbacks=post_model_creation_callbacks,
wrap_with_ddp=wrap_config.wrap_with_ddp,
fp16=tf_config.fp16,
bf16=tf_config.bf16,
ddp_config=ddp_config,
)
pending_fsdp_config = None
Comment on lines 359 to +387
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There's a potential issue where FSDP is silently not applied in certain configurations. Specifically, if use_megatron_fsdp is true but wrap_with_ddp is false (e.g., in forward_only mode), the current logic will not build the FSDP configuration, and the model will not be wrapped with FSDP.

This can lead to unexpected behavior where FSDP is enabled in the config but not actually used.

To fix this, the logic should be adjusted to build the ddp_config if either DDP or FSDP is enabled, and then decide on wrapping based on whether FSDP is being used. This ensures FSDP is correctly applied.

            use_fsdp = hasattr(wrap_config, "use_megatron_fsdp") and wrap_config.use_megatron_fsdp
            ddp_config = None
            if wrap_config.wrap_with_ddp or use_fsdp:
                ddp_config = _build_ddp_config_dict(wrap_config, override_ddp_config)

            if use_fsdp and not HAVE_MEGATRON_FSDP:
                raise ImportError(
                    "engine.use_megatron_fsdp=True requires megatron-fsdp package. "
                    "Install from Megatron-LM dev branch with FSDP support."
                )

            if use_fsdp:
                # FSDP wrapping deferred to after weight loading (mbridge can't parse FSDP structure)
                model = bridge.get_model(
                    post_model_creation_callbacks=post_model_creation_callbacks,
                    wrap_with_ddp=False,
                    fp16=tf_config.fp16,
                    bf16=tf_config.bf16,
                    ddp_config=None,
                )
                pending_fsdp_config = ddp_config
            else:
                model = bridge.get_model(
                    post_model_creation_callbacks=post_model_creation_callbacks,
                    wrap_with_ddp=wrap_config.wrap_with_ddp,
                    fp16=tf_config.fp16,
                    bf16=tf_config.bf16,
                    ddp_config=ddp_config,
                )
                pending_fsdp_config = None

Copy link
Copy Markdown
Collaborator Author

@yxs yxs Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We intentionally skip FSDP when not training: wrap_with_ddp=False applies only to ref models and forward-only inference, which have no backward pass or optimizer state to shard.


if isinstance(tf_config, MLATransformerConfig):
# Keep the same behavior as hf_to_mcore_config_dpskv3
Expand All @@ -366,16 +409,40 @@ def megatron_model_provider(pre_process, post_process, vp_stage=None):
parallel_model.to(get_device_name())
return parallel_model

if getattr(wrap_config, "use_megatron_fsdp", False):
raise NotImplementedError(
"Megatron FSDP is only supported with mbridge (engine.use_mbridge=True). "
"Set engine.use_mbridge=True or disable FSDP with engine.use_megatron_fsdp=False."
)

model = get_model(
megatron_model_provider,
wrap_with_ddp=wrap_config.wrap_with_ddp,
use_distributed_optimizer=wrap_config.use_distributed_optimizer,
override_ddp_config=override_ddp_config,
)
return model, tf_config
return model, tf_config, pending_fsdp_config


def wrap_model_with_fsdp(model, fsdp_ddp_config_dict):
    """Wrap each model chunk with the Megatron FSDP adapter.

    Args:
        model: A single model chunk or a list of chunks.
        fsdp_ddp_config_dict: kwargs used to build the DistributedDataParallelConfig.

    Returns:
        A list of FSDP-wrapped model chunks (always a list, even for a single input).

    Raises:
        ImportError: If the Megatron FSDP adapter could not be imported.
    """
    if not HAVE_MEGATRON_FSDP:
        raise ImportError(
            "Megatron FSDP requires megatron-fsdp package. "
            "Install via: pip install megatron-core[fsdp] or install from Megatron-LM dev branch."
        )
    from megatron.core.distributed import DistributedDataParallelConfig as DDPConfig

    ddp_config = DDPConfig(**fsdp_ddp_config_dict)
    if hasattr(ddp_config, "finalize"):
        ddp_config.finalize()

    chunks = model if isinstance(model, list) else [model]
    # Model config is read from the first chunk, matching the original lookup order.
    model_config = get_model_config(chunks[0])
    return [MegatronFSDP(config=model_config, ddp_config=ddp_config, module=chunk) for chunk in chunks]


ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, Float16Module)
_fsdp_classes = tuple(c for c in (MegatronFSDP, MegatronFSDPInner) if c is not None)
ALL_MODULE_WRAPPER_CLASSNAMES = (DDP,) + _fsdp_classes + (Float16Module,)


def unwrap_model(model, module_instances=ALL_MODULE_WRAPPER_CLASSNAMES):
Expand Down Expand Up @@ -1383,13 +1450,21 @@ def get_megatron_module_device(models: list[Any]) -> str:
return "cpu"

model_chunk = models[0]
if not model_chunk.buffers:
# FSDP wrapper: buffers is a method (nn.Module.buffers()), not a list attribute
buffers = getattr(model_chunk, "buffers", None)
if buffers is None or callable(buffers):
try:
return next(model_chunk.module.parameters()).device.type
except (StopIteration, AttributeError):
return "cpu"

if not buffers:
try:
return next(model_chunk.module.parameters()).device.type
except StopIteration:
return "cpu"

buffer = model_chunk.buffers[0]
buffer = buffers[0]
if buffer.param_data.storage().size() == 0:
return "cpu"
else:
Expand Down
5 changes: 5 additions & 0 deletions verl/workers/config/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,11 @@ class McoreEngineConfig(EngineConfig):
dist_checkpointing_prefix: str = ""
dist_ckpt_optim_fully_reshardable: bool = False
distrib_optim_fully_reshardable_mem_efficient: bool = False
# Megatron FSDP (ZeRO-style sharding)
use_megatron_fsdp: bool = False
megatron_fsdp_zero_stage: int = 3 # 0=no_shard, 1=optim, 2=optim_grads, 3=optim_grads_params
megatron_fsdp_overlap_grad_reduce: bool = True
megatron_fsdp_overlap_param_gather: bool = True
override_ddp_config: dict[str, Any] = field(default_factory=dict)
override_transformer_config: dict[str, Any] = field(default_factory=dict)
override_mcore_model_config: dict[str, Any] = field(default_factory=dict)
Expand Down
13 changes: 11 additions & 2 deletions verl/workers/engine/megatron/transformer_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,11 +262,15 @@ def _build_megatron_module(self):
is_value_model=self.is_value_model,
wrap_with_ddp=wrap_with_ddp,
use_distributed_optimizer=self.engine_config.use_distributed_optimizer,
use_megatron_fsdp=self.engine_config.use_megatron_fsdp,
megatron_fsdp_zero_stage=self.engine_config.megatron_fsdp_zero_stage,
megatron_fsdp_overlap_grad_reduce=self.engine_config.megatron_fsdp_overlap_grad_reduce,
megatron_fsdp_overlap_param_gather=self.engine_config.megatron_fsdp_overlap_param_gather,
)
if self.is_value_model:
self.model_config.hf_config.tie_word_embeddings = False

module, updated_tf_config = make_megatron_module(
module, updated_tf_config, pending_fsdp_config = make_megatron_module(
wrap_config=wrap_config,
tf_config=self.tf_config,
hf_config=self.model_config.hf_config,
Expand All @@ -280,7 +284,7 @@ def _build_megatron_module(self):
self.tf_config = updated_tf_config
print(f"module: {len(module)}")

if self.engine_config.use_dist_checkpointing:
if self.engine_config.use_dist_checkpointing and self.engine_config.dist_checkpointing_path:
load_mcore_dist_weights(
module, self.engine_config.dist_checkpointing_path, is_value_model=self.is_value_model
)
Expand All @@ -295,6 +299,11 @@ def _build_megatron_module(self):
module, self.model_config.local_path, allowed_mismatched_params=allowed_mismatched_params
)

if pending_fsdp_config is not None:
from verl.utils.megatron_utils import wrap_model_with_fsdp

module = wrap_model_with_fsdp(module, pending_fsdp_config)

if torch.distributed.get_rank() == 0:
print_model_size(module[0])

Expand Down
Loading