
Commit b9346fd

Fix gamma/lam issue

1 parent 0d5bad2 · commit b9346fd

2 files changed: +8, -13 lines


trinity/common/verl_config.py

Lines changed: 5 additions & 10 deletions
@@ -182,10 +182,8 @@ class KL_Ctrl:
 
 @dataclass
 class Algorithm:
-    gamma: float = 1.0
-    lam: float = 1.0
     adv_estimator: str = "gae"
-    # TODO (yanxi): remove the above advantage-related parameters?
+    # TODO (yanxi): might remove adv_estimator completely, use AlgorithmConfig.advantage_fn_type instead
     norm_adv_by_std_in_grpo: bool = True
     use_kl_in_reward: bool = False
     kl_penalty: str = "kl"
@@ -316,20 +314,17 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901
         self.actor_rollout_ref.actor.clip_ratio = config.trainer.actor_clip_ratio
 
         # Algorithm related config
-        if config.algorithm.gamma is not None:
-            self.algorithm.gamma = config.algorithm.gamma
-        if config.algorithm.lam is not None:
-            self.algorithm.lam = config.algorithm.lam
         self.actor_rollout_ref.actor.algorithm_type = config.algorithm.algorithm_type
         if config.algorithm.algorithm_type == AlgorithmType.PPO:
             logger.info("Setting `adv_estimator` to 'gae' for PPO")
             self.algorithm.adv_estimator = AdvantageEstimator.GAE.value
         elif config.algorithm.algorithm_type in (AlgorithmType.GRPO, AlgorithmType.OPMD):
             logger.info("Setting `adv_estimator` to 'grpo' for GRPO/OPMD")
             self.algorithm.adv_estimator = AdvantageEstimator.GRPO.value
-        # TODO (yanxi): it seems that adv_estimator only affects whether use_critic is set to
-        # True or False in RayPPOTrainer.__init__() (and hence in VerlPPOTrainerWrapper);
-        # need to double check whether this is indeed the case.
+        # TODO (yanxi): it seems that adv_estimator now only affects whether use_critic is set to
+        # True or False in RayPPOTrainer.__init__() (and hence in VerlPPOTrainerWrapper).
+        # Need to double check whether this is indeed the case,
+        # and see if adv_estimator can be removed completely.
 
         if self.actor_rollout_ref.actor.algorithm_type.is_dpo():  # for DPO
             if not self.actor_rollout_ref.actor.use_kl_loss:
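
For reference, the synchronization that remains after this commit is essentially a mapping from algorithm_type to verl's adv_estimator string. The sketch below is a condensed illustration of that mapping, not the project's actual code: the lookup table, the resolve_adv_estimator helper, and the stand-in enums are hypothetical, mirroring only what the diff above shows.

    from enum import Enum

    # Stand-in enums for illustration only; the real AlgorithmType and
    # AdvantageEstimator live elsewhere in the trinity / verl code bases.
    class AlgorithmType(Enum):
        PPO = "ppo"
        GRPO = "grpo"
        OPMD = "opmd"
        DPO = "dpo"

    class AdvantageEstimator(Enum):
        GAE = "gae"
        GRPO = "grpo"

    # Hypothetical condensation of the adv_estimator selection performed in
    # synchronize_config() above.
    _ADV_ESTIMATOR_BY_ALGO = {
        AlgorithmType.PPO: AdvantageEstimator.GAE.value,    # "gae"; per the TODO, this drives use_critic
        AlgorithmType.GRPO: AdvantageEstimator.GRPO.value,  # "grpo", critic-free
        AlgorithmType.OPMD: AdvantageEstimator.GRPO.value,  # OPMD reuses the GRPO estimator
    }

    def resolve_adv_estimator(algorithm_type: AlgorithmType) -> str:
        # Any other algorithm type falls back to the Algorithm dataclass default, "gae".
        return _ADV_ESTIMATOR_BY_ALGO.get(algorithm_type, "gae")

For example, resolve_adv_estimator(AlgorithmType.OPMD) returns "grpo", matching the logger message in the elif branch above.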

trinity/trainer/verl/core_algos.py

Lines changed: 3 additions & 3 deletions
@@ -139,8 +139,8 @@ def compute_gae_advantage_return(
     token_level_rewards: torch.Tensor,
     values: torch.Tensor,
     eos_mask: torch.Tensor,
-    gamma: torch.Tensor,
-    lam: torch.Tensor,
+    gamma: float,
+    lam: float,
 ):
     """Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py
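
With gamma and lam now typed as plain Python floats, nothing tensor-valued is required of them in the GAE recursion. Below is a minimal, self-contained sketch of that recursion for reference; it is not the repository's compute_gae_advantage_return (which also applies eos_mask and whitens the advantages), and the name gae_sketch is made up for illustration.

    import torch

    def gae_sketch(rewards: torch.Tensor, values: torch.Tensor,
                   gamma: float = 1.0, lam: float = 1.0):
        # rewards, values: shape (batch, seq_len); gamma and lam are scalar floats.
        advantages = torch.zeros_like(rewards)
        last_adv = torch.zeros(rewards.shape[0], dtype=rewards.dtype)
        next_value = torch.zeros(rewards.shape[0], dtype=rewards.dtype)
        for t in reversed(range(rewards.shape[1])):
            delta = rewards[:, t] + gamma * next_value - values[:, t]  # TD residual
            last_adv = delta + gamma * lam * last_adv                  # GAE recursion
            advantages[:, t] = last_adv
            next_value = values[:, t]
        returns = advantages + values
        return advantages, returns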
@@ -283,7 +283,7 @@ def compute_rloo_outcome_advantage(
 
 
 def compute_reinforce_plus_plus_outcome_advantage(
-    token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, gamma: torch.Tensor
+    token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, gamma: float
 ):
     """
     Compute advantage for REINFORCE++.
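
The same point holds for REINFORCE++: gamma only enters as a scalar discount when accumulating token-level returns. A minimal sketch of that accumulation, ignoring the eos_mask handling and whitening the real function performs (the function name is hypothetical):

    import torch

    def discounted_returns_sketch(token_level_rewards: torch.Tensor, gamma: float) -> torch.Tensor:
        # token_level_rewards: shape (batch, seq_len); gamma is a scalar float.
        returns = torch.zeros_like(token_level_rewards)
        running = torch.zeros(token_level_rewards.shape[0], dtype=token_level_rewards.dtype)
        for t in reversed(range(token_level_rewards.shape[1])):
            running = token_level_rewards[:, t] + gamma * running
            returns[:, t] = running
        return returns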
