6 changes: 3 additions & 3 deletions .github/workflows/e2e_ppo_trainer_megatron_vllm.yml
@@ -124,7 +124,7 @@ jobs:
- name: Install the current repository
run: |
pip3 install -r requirements-test.txt
pip3 install --no-deps -e .
pip3 install --no-deps --force-reinstall .
pip3 install mbridge
pip3 install math-verify
- name: Prepare GSM8K dataset
@@ -145,8 +145,8 @@
- name: clean up and install Megatron-Bridge
run: |
rm -rf checkpoints
pip3 install git+https://github.com/NVIDIA-NeMo/Megatron-Bridge.git@953aabf --no-deps --no-build-isolation
pip3 install git+https://github.com/NVIDIA/Megatron-LM.git@2d398b4 --no-deps --no-build-isolation
pip3 install git+https://github.com/NVIDIA-NeMo/Megatron-Bridge.git@550924c --no-deps --no-build-isolation
pip3 install git+https://github.com/NVIDIA/Megatron-LM.git@5455f0a --no-deps --no-build-isolation
pip3 install "nvidia-modelopt[torch]>=0.37.0" transformers==4.57.1
- name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron, use Megatron-Bridge LoRA e2e to pre-load and save (Deepseek)
run: |
6 changes: 3 additions & 3 deletions .github/workflows/e2e_ppo_trainer_megatron_vllm_2.yml
@@ -123,9 +123,9 @@ jobs:
- name: Install the current repository
run: |
pip3 install -r requirements-test.txt
pip3 install --no-deps -e .
pip3 install git+https://github.com/NVIDIA-NeMo/Megatron-Bridge.git@953aabf --no-deps --no-build-isolation
pip3 install git+https://github.com/NVIDIA/Megatron-LM.git@2d398b4 --no-deps --no-build-isolation
pip3 install --no-deps --force-reinstall .
pip3 install git+https://github.com/NVIDIA-NeMo/Megatron-Bridge.git@550924c --no-deps --no-build-isolation
pip3 install git+https://github.com/NVIDIA/Megatron-LM.git@5455f0a --no-deps --no-build-isolation
pip3 install "nvidia-modelopt[torch]>=0.37.0" transformers==4.57.1
- name: Prepare GSM8K dataset
run: |
2 changes: 1 addition & 1 deletion docs/advance/ppo_lora.rst
@@ -62,7 +62,7 @@ Megatron Backend Usage Guide

You need to install and enable Megatron-Bridge for Megatron LoRA support.

Make sure you use a Megatron-Bridge version later than 0.2.0; we recommend `this commit <https://github.com/NVIDIA-NeMo/Megatron-Bridge/commit/953aabf75c0500180dc14a6a76cf9e7e7c4baec7>`_ or later for proper support, and use the following settings to enable Megatron-Bridge:
Make sure you use a Megatron-Bridge version later than 0.2.0; we recommend `this commit <https://github.com/NVIDIA-NeMo/Megatron-Bridge/commit/550924c04368a175ef261a72230204410f455260>`_ or later for proper support, and use the following settings to enable Megatron-Bridge:

- ``actor_rollout_ref.actor.megatron.use_mbridge=True``
- ``actor_rollout_ref.actor.megatron.vanilla_mbridge=False``
2 changes: 1 addition & 1 deletion examples/grpo_trainer/run_qwen2-7b_math_megatron_lora.sh
@@ -3,7 +3,7 @@ set -xeuo pipefail

# Need to install Megatron-Bridge
# NOTE: Make sure you use Megatron-Bridge later than 0.2.0
# (Recommend https://github.com/NVIDIA-NeMo/Megatron-Bridge/commit/953aabf75c0500180dc14a6a76cf9e7e7c4baec7 or later)
# (Recommend https://github.com/NVIDIA-NeMo/Megatron-Bridge/commit/550924c04368a175ef261a72230204410f455260 or later)
# for proper MoE LoRA support.

# For Megatron communication/computation overlapping
2 changes: 1 addition & 1 deletion examples/grpo_trainer/run_qwen3moe-30b_megatron_lora.sh
@@ -3,7 +3,7 @@ set -xeuo pipefail

# Need to install Megatron-Bridge
# NOTE: Make sure you use Megatron-Bridge later than 0.2.0
# (Recommend https://github.com/NVIDIA-NeMo/Megatron-Bridge/commit/953aabf75c0500180dc14a6a76cf9e7e7c4baec7 or later)
# (Recommend https://github.com/NVIDIA-NeMo/Megatron-Bridge/commit/550924c04368a175ef261a72230204410f455260 or later)
# for proper MoE LoRA support.

# For Megatron communication/computation overlapping
5 changes: 4 additions & 1 deletion verl/experimental/fully_async_policy/fully_async_trainer.py
@@ -79,8 +79,11 @@ def __init__(
self.ray_worker_group_cls = ray_worker_group_cls
self.device_name = device_name if device_name else self.config.trainer.device

lora_rank = config.actor_rollout_ref.model.get("lora", {}).get("rank", 0)
if lora_rank <= 0:
lora_rank = config.actor_rollout_ref.model.get("lora_rank", 0)
# if ref_in_actor is True, the reference policy will be actor without lora applied
self.ref_in_actor = config.actor_rollout_ref.model.get("lora_rank", 0) > 0
self.ref_in_actor = lora_rank > 0

# define in-reward KL control
# kl loss control currently not suppoorted
15 changes: 6 additions & 9 deletions verl/experimental/one_step_off_policy/ray_trainer.py
@@ -37,11 +37,7 @@
from verl.single_controller.ray.base import create_colocated_worker_cls
from verl.trainer.ppo import core_algos
from verl.trainer.ppo.core_algos import agg_loss
from verl.trainer.ppo.metric_utils import (
compute_data_metrics,
compute_throughout_metrics,
compute_timing_metrics,
)
from verl.trainer.ppo.metric_utils import compute_data_metrics, compute_throughout_metrics, compute_timing_metrics
from verl.trainer.ppo.ray_trainer import (
RayPPOTrainer,
ResourcePoolManager,
@@ -54,9 +50,7 @@
from verl.utils import omega_conf_to_dataclass
from verl.utils.checkpoint.checkpoint_manager import should_save_ckpt_esi
from verl.utils.debug import marked_timer
from verl.utils.metric import (
reduce_metrics,
)
from verl.utils.metric import reduce_metrics
from verl.utils.tracking import ValidationGenerationsLogger


@@ -119,8 +113,11 @@ def __init__(
self.device_name = device_name if device_name else self.config.trainer.device
self.validation_generations_logger = ValidationGenerationsLogger()

lora_rank = config.actor_rollout_ref.model.get("lora", {}).get("rank", 0)
if lora_rank <= 0:
lora_rank = config.actor_rollout_ref.model.get("lora_rank", 0)
# if ref_in_actor is True, the reference policy will be actor without lora applied
self.ref_in_actor = config.actor_rollout_ref.model.get("lora_rank", 0) > 0
self.ref_in_actor = lora_rank > 0

# define in-reward KL control
# kl loss control currently not suppoorted
36 changes: 9 additions & 27 deletions verl/experimental/transfer_queue/ray_trainer.py
@@ -48,11 +48,7 @@

from verl import DataProto
from verl.experimental.dataset.sampler import AbstractCurriculumSampler
from verl.single_controller.ray import (
RayClassWithInitArgs,
RayResourcePool,
RayWorkerGroup,
)
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.single_controller.ray.base import create_colocated_worker_cls
from verl.trainer.config import AlgoConfig
from verl.trainer.ppo import core_algos
@@ -64,33 +60,16 @@
process_validation_metrics,
)
from verl.trainer.ppo.reward import compute_reward, compute_reward_async
from verl.trainer.ppo.utils import (
Role,
WorkerType,
need_critic,
need_reference_policy,
need_reward_model,
)
from verl.utils.checkpoint.checkpoint_manager import (
find_latest_ckpt_path,
should_save_ckpt_esi,
)
from verl.trainer.ppo.utils import Role, WorkerType, need_critic, need_reference_policy, need_reward_model
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, should_save_ckpt_esi
from verl.utils.config import omega_conf_to_dataclass
from verl.utils.debug import marked_timer
from verl.utils.metric import reduce_metrics
from verl.utils.rollout_skip import RolloutSkip
from verl.utils.seqlen_balancing import (
calculate_workload,
get_seqlen_balanced_partitions,
log_seqlen_unbalance,
)
from verl.utils.seqlen_balancing import calculate_workload, get_seqlen_balanced_partitions, log_seqlen_unbalance
from verl.utils.torch_functional import masked_mean
from verl.utils.tracking import ValidationGenerationsLogger
from verl.utils.transferqueue_utils import (
create_transferqueue_client,
get_transferqueue_client,
tqbridge,
)
from verl.utils.transferqueue_utils import create_transferqueue_client, get_transferqueue_client, tqbridge


@dataclass
@@ -400,8 +379,11 @@ def __init__(
experiment_name=self.config.trainer.experiment_name,
)

lora_rank = config.actor_rollout_ref.model.get("lora", {}).get("rank", 0)
if lora_rank <= 0:
lora_rank = config.actor_rollout_ref.model.get("lora_rank", 0)
# if ref_in_actor is True, the reference policy will be actor without lora applied
self.ref_in_actor = config.actor_rollout_ref.model.get("lora_rank", 0) > 0
self.ref_in_actor = lora_rank > 0

# define in-reward KL control
# kl loss control currently not suppoorted
8 changes: 4 additions & 4 deletions verl/trainer/ppo/ray_trainer.py
@@ -353,10 +353,10 @@ def __init__(
)

# if ref_in_actor is True, the reference policy will be actor without lora applied
self.ref_in_actor = (
config.actor_rollout_ref.model.get("lora_rank", 0) > 0
or config.actor_rollout_ref.model.get("lora_adapter_path") is not None
)
lora_rank = config.actor_rollout_ref.model.get("lora", {}).get("rank", 0)
if lora_rank <= 0:
lora_rank = config.actor_rollout_ref.model.get("lora_rank", 0)
self.ref_in_actor = lora_rank > 0 or config.actor_rollout_ref.model.get("lora_adapter_path") is not None

# define in-reward KL control
# kl loss control currently not suppoorted
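All of the touched trainers now resolve the LoRA rank the same way: read the nested `model.lora.rank` field first, and fall back to the legacy flat `model.lora_rank` key only when the nested value is absent or non-positive, so `ref_in_actor` is set correctly for both config styles. Below is a minimal sketch of that lookup written against a plain dict instead of the actual OmegaConf config object; the helper name and the sample configs are illustrative and not part of this diff.

```python
# Illustrative helper (not from verl): the lookup order the trainers use to resolve the LoRA rank.
def resolve_lora_rank(model_cfg: dict) -> int:
    """Prefer the nested `lora.rank` field; fall back to the legacy flat `lora_rank`."""
    rank = model_cfg.get("lora", {}).get("rank", 0)
    if rank <= 0:
        rank = model_cfg.get("lora_rank", 0)
    return rank


# Both config styles yield the same answer; with neither key set the rank stays 0,
# so ref_in_actor only turns on via an explicit lora_adapter_path.
assert resolve_lora_rank({"lora": {"rank": 16}}) == 16
assert resolve_lora_rank({"lora_rank": 8}) == 8
assert resolve_lora_rank({}) == 0
```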
23 changes: 18 additions & 5 deletions verl/workers/megatron_workers.py
@@ -32,6 +32,8 @@
except ImportError:
repatch = None

from contextlib import nullcontext

from megatron.core import parallel_state as mpu

from verl import DataProto
@@ -819,6 +821,10 @@ def generate_sequences(self, prompts: DataProto):
@GPUMemoryLogger(role="compute_ref_log_prob", logger=logger)
@DistProfiler.annotate(color="olive", role="ref_compute_log_prob")
def compute_ref_log_prob(self, data: DataProto):
if self.peft_cls is not None:
# if is lora, actor without lora applied is the ref
data.meta_info["is_lora"] = True
return self.compute_log_prob(data)
assert self._is_ref
if self._ref_is_offload_param:
load_megatron_model_to_gpu(self.ref_module, load_grad=False)
@@ -845,10 +851,13 @@ def compute_log_prob(self, data: DataProto):
if self._is_offload_param:
load_megatron_model_to_gpu(self.actor_module, load_grad=False)
log_gpu_memory_usage("After load actor params and grad during compute_log_prob", logger=logger)
is_lora = data.meta_info.pop("is_lora", False)
adapter_ctx = self.peft_cls.disable_adapter(self.actor_module) if is_lora else nullcontext()
# we should always recompute old_log_probs when it is HybridEngine
data.meta_info["micro_batch_size"] = self.config.rollout.log_prob_micro_batch_size_per_gpu
data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu
data.meta_info["use_dynamic_bsz"] = self.config.rollout.log_prob_use_dynamic_bsz
config_source = self.config.ref if is_lora else self.config.rollout
data.meta_info["micro_batch_size"] = config_source.log_prob_micro_batch_size_per_gpu
data.meta_info["max_token_len"] = config_source.log_prob_max_token_len_per_gpu
data.meta_info["use_dynamic_bsz"] = config_source.log_prob_use_dynamic_bsz
data.meta_info["temperature"] = self.config.rollout.temperature

if self.enable_routing_replay and self.config.actor.router_replay.mode == "R2":
@@ -857,9 +866,13 @@
if self.enable_routing_replay and self.config.actor.router_replay.mode == "R3":
RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD)

output, entropys, layers_topk_idx = self.actor.compute_log_prob(data=data, calculate_entropy=True)
with adapter_ctx:
output, entropys, layers_topk_idx = self.actor.compute_log_prob(data=data, calculate_entropy=not is_lora)
tensors = {"ref_log_prob": output} if is_lora else {"old_log_probs": output}
if not is_lora:
tensors["entropys"] = entropys
output = DataProto.from_dict(
tensors={"old_log_probs": output, "entropys": entropys},
tensors=tensors,
meta_info={"temperature": self.config.rollout.temperature},
)
if self.config.actor.router_replay.mode == "R2":
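In megatron_workers.py, reference log-probs for LoRA runs are now served by the actor itself: `compute_ref_log_prob` tags the batch with `is_lora` and delegates to `compute_log_prob`, which temporarily disables the adapter (the base model without LoRA weights is the reference policy), reads the micro-batch settings from the `ref` config instead of the `rollout` config, skips the entropy computation, and returns the result under `ref_log_prob` rather than `old_log_probs`. A simplified, self-contained sketch of that control flow follows; `DummyPeft` and `FakeWorker` are stand-ins invented for illustration, not verl classes.

```python
# Simplified sketch of the LoRA reference log-prob path added above.
# DummyPeft / FakeWorker are illustrative stand-ins, not verl classes.
from contextlib import contextmanager, nullcontext


class DummyPeft:
    """Stands in for the PEFT helper; disable_adapter wraps a forward pass with LoRA turned off."""

    @staticmethod
    @contextmanager
    def disable_adapter(module):
        module.adapter_enabled = False  # adapters off -> base model == reference policy
        try:
            yield
        finally:
            module.adapter_enabled = True


class FakeWorker:
    def __init__(self):
        self.peft_cls = DummyPeft
        self.actor_module = type("Module", (), {"adapter_enabled": True})()

    def compute_ref_log_prob(self, meta_info: dict) -> dict:
        if self.peft_cls is not None:
            meta_info["is_lora"] = True  # actor with LoRA disabled serves as the ref
            return self.compute_log_prob(meta_info)
        raise NotImplementedError("non-LoRA path uses a separate reference worker")

    def compute_log_prob(self, meta_info: dict) -> dict:
        is_lora = meta_info.pop("is_lora", False)
        ctx = self.peft_cls.disable_adapter(self.actor_module) if is_lora else nullcontext()
        with ctx:
            log_prob = 0.0  # placeholder for the real Megatron forward pass
        # ref log-probs go under their own key and skip entropy computation
        return {"ref_log_prob": log_prob} if is_lora else {"old_log_probs": log_prob, "entropys": 0.0}


print(FakeWorker().compute_ref_log_prob({}))  # {'ref_log_prob': 0.0}
```

Because a LoRA actor is the frozen base model plus low-rank adapters, disabling the adapters recovers the reference policy without holding a second copy of the weights, which is why this path needs no separate reference worker.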