Skip to content

[megatron] feat: enable Megatron FSDP for SFT training#5854

Open
yxs wants to merge 2 commits intoverl-project:mainfrom
yxs:feat/megatron-fsdp-enabling
Open

[megatron] feat: enable Megatron FSDP for SFT training#5854
yxs wants to merge 2 commits intoverl-project:mainfrom
yxs:feat/megatron-fsdp-enabling

Conversation

@yxs
Copy link
Copy Markdown
Collaborator

@yxs yxs commented Apr 1, 2026

What does this PR do?

Enable Megatron-LM's native FullyShardedDataParallel (FSDP) in verl's Megatron engine for SFT training. This allows ZeRO-style parameter/gradient/optimizer state sharding across data-parallel ranks, reducing per-GPU memory usage for large model training.

Related issue: #5836 (Q2 Roadmap — Megatron FSDP enabling)

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ...
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, veomni, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data, cfg, reward, fully_async, one_step_off
    • If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

Tested on 8×H100 80GB with Qwen3-1.7B-Base, GSM8K SFT dataset, 1 epoch (77 steps).

DDP baseline vs FSDP ZeRO-3:

Step 1 loss Step 1 grad_norm Step 77 loss Step 77 grad_norm
DDP 0.7850 9.329 0.3965 1.699
FSDP ZeRO-3 0.7850 9.328 0.3963 1.697

API and Usage Example

torchrun --nproc_per_node=8 --nnodes=1 -m verl.trainer.sft_trainer \
      engine=megatron \
      engine.use_mbridge=True \
      engine.use_megatron_fsdp=True \
      engine.megatron_fsdp_zero_stage=3 \
      engine.tensor_model_parallel_size=1 \
      engine.pipeline_model_parallel_size=1 \
      engine.dtype=bfloat16 \
      model.path=<your_model_path> \
      data.train_files=<your_data.parquet> \
      data.train_batch_size=96 \
      data.micro_batch_size_per_gpu=2 \
      data.max_length=2048 \
      optim=megatron \
      optim.lr=2e-5 \
      trainer.total_epochs=1

New config fields in engine:

  • use_megatron_fsdp (bool, default False) — enable FSDP
  • megatron_fsdp_zero_stage (int, default 3) — 0/1/2/3 maps to no_shard/optim/optim_grads/optim_grads_params
  • megatron_fsdp_overlap_grad_reduce (bool, default True) — overlap grad reduce-scatter with backward
  • megatron_fsdp_overlap_param_gather (bool, default True) — overlap param all-gather with forward

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements support for Megatron FSDP (ZeRO-style sharding) by adding configuration fields, updating checkpointing to skip unsupported states, and introducing deferred model wrapping logic. Feedback points out a logic error in make_megatron_module where FSDP wrapping could be bypassed if wrap_with_ddp is disabled, recommending a refactor to ensure the configuration is correctly initialized when FSDP is enabled.

Comment on lines 328 to +356
ddp_config = None
if wrap_config.wrap_with_ddp:
ddp_config_dict = {
"use_distributed_optimizer": wrap_config.use_distributed_optimizer,
}
if override_ddp_config is not None:
ddp_config_dict.update(override_ddp_config)
ddp_config = ddp_config_dict

model = bridge.get_model(
post_model_creation_callbacks=post_model_creation_callbacks,
wrap_with_ddp=wrap_config.wrap_with_ddp,
fp16=tf_config.fp16,
bf16=tf_config.bf16,
ddp_config=ddp_config,
)
ddp_config = _build_ddp_config_dict(wrap_config, override_ddp_config)

use_fsdp = hasattr(wrap_config, "use_megatron_fsdp") and wrap_config.use_megatron_fsdp
if use_fsdp and not HAVE_MEGATRON_FSDP:
raise ImportError(
"engine.use_megatron_fsdp=True requires megatron-fsdp package. "
"Install from Megatron-LM dev branch with FSDP support."
)
if use_fsdp and wrap_config.wrap_with_ddp:
# FSDP wrapping deferred to after weight loading (mbridge can't parse FSDP structure)
model = bridge.get_model(
post_model_creation_callbacks=post_model_creation_callbacks,
wrap_with_ddp=False,
fp16=tf_config.fp16,
bf16=tf_config.bf16,
ddp_config=None,
)
pending_fsdp_config = ddp_config
else:
model = bridge.get_model(
post_model_creation_callbacks=post_model_creation_callbacks,
wrap_with_ddp=wrap_config.wrap_with_ddp,
fp16=tf_config.fp16,
bf16=tf_config.bf16,
ddp_config=ddp_config,
)
pending_fsdp_config = None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There's a potential issue where FSDP is silently not applied in certain configurations. Specifically, if use_megatron_fsdp is true but wrap_with_ddp is false (e.g., in forward_only mode), the current logic will not build the FSDP configuration, and the model will not be wrapped with FSDP.

This can lead to unexpected behavior where FSDP is enabled in the config but not actually used.

To fix this, the logic should be adjusted to build the ddp_config if either DDP or FSDP is enabled, and then decide on wrapping based on whether FSDP is being used. This ensures FSDP is correctly applied.

            use_fsdp = hasattr(wrap_config, "use_megatron_fsdp") and wrap_config.use_megatron_fsdp
            ddp_config = None
            if wrap_config.wrap_with_ddp or use_fsdp:
                ddp_config = _build_ddp_config_dict(wrap_config, override_ddp_config)

            if use_fsdp and not HAVE_MEGATRON_FSDP:
                raise ImportError(
                    "engine.use_megatron_fsdp=True requires megatron-fsdp package. "
                    "Install from Megatron-LM dev branch with FSDP support."
                )

            if use_fsdp:
                # FSDP wrapping deferred to after weight loading (mbridge can't parse FSDP structure)
                model = bridge.get_model(
                    post_model_creation_callbacks=post_model_creation_callbacks,
                    wrap_with_ddp=False,
                    fp16=tf_config.fp16,
                    bf16=tf_config.bf16,
                    ddp_config=None,
                )
                pending_fsdp_config = ddp_config
            else:
                model = bridge.get_model(
                    post_model_creation_callbacks=post_model_creation_callbacks,
                    wrap_with_ddp=wrap_config.wrap_with_ddp,
                    fp16=tf_config.fp16,
                    bf16=tf_config.bf16,
                    ddp_config=ddp_config,
                )
                pending_fsdp_config = None

Copy link
Copy Markdown
Collaborator Author

@yxs yxs Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

skip FSDP when not training, wrap_with_ddp=False only applies to ref models and forward-only inference, which don't have backward/optimizer.

share_embeddings_and_output_weights=self.share_embeddings_and_output_weights,
wrap_with_ddp=True,
use_distributed_optimizer=self.config.actor.megatron.use_distributed_optimizer,
use_distributed_optimizer=megatron_config.use_distributed_optimizer,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

megatron_workers.py has been deprecated, please do not modify it.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reverted

@wuxibin89 wuxibin89 requested a review from HollowMan6 April 2, 2026 00:42
yxs and others added 2 commits April 3, 2026 17:32
Enable Megatron-LM's native FullyShardedDataParallel in verl's Megatron
engine, allowing ZeRO-3 parameter/gradient/optimizer state sharding via
engine.use_megatron_fsdp=True. Uses deferred FSDP wrapping to maintain
compatibility with mbridge weight loading.

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Sync auto-generated config with new FSDP fields in megatron.yaml.

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@yxs yxs force-pushed the feat/megatron-fsdp-enabling branch from 4552820 to 3a6b47f Compare April 3, 2026 23:34
@wuxibin89
Copy link
Copy Markdown
Collaborator

@yxs CI failed, please fix it

@yxs
Copy link
Copy Markdown
Collaborator Author

yxs commented Apr 7, 2026

@wuxibin89 All 6 CI failures are unrelated to this PR.

  • megatron-deepseek, megatron-qwen3, megatron-moe-expert-parallel (5 failures): ValueError: too many values to unpack (expected 2), all triggered in the deprecated legacy worker path (use_legacy_worker_impl=enable). Our PR changed the return signature in megatron_utils.py but did not update the deprecated megatron_workers.py per your instruction.
  • e2e_one_step_off_policy_megatron_ascend: Ascend NPU environment issue RuntimeError: NPU out of memory. Tried to allocate 1.00 GiB, not related to our changes.
  • e2e_grpo_trainer_fsdp-qwen2, megatron-vlm, trtllm_unit_tests: Cancelled

@wuxibin89
Copy link
Copy Markdown
Collaborator

wuxibin89 commented Apr 8, 2026

Hold until ci with new engine work pass(use_legacy_worker_impl=disable) #5909

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants