Fix two substantial edge cases in PPO's value target calculation #59958
base: master
Conversation
Code Review
This pull request addresses two significant edge cases in PPO's value target calculation, particularly when lambda is zero. The changes in rllib/connectors/learner/add_one_ts_to_episodes_and_truncate.py correctly adjust how terminateds flags are generated. This ensures that terminal rewards are properly included in value calculations and that truncated episodes are correctly handled for value bootstrapping. The addition of unit tests is a great way to prevent regressions. I've found a potential issue in the new test file where the test helper function seems to have a logic error in data preparation, which I've provided a suggestion to fix. Overall, the core logic fix is sound and this PR is a valuable improvement.
Thanks @MatthewCWeston, I'll try to have a look at this soon.
This bug seems to massively affect our performance. What's the timeline for getting the fix merged and released? Switching from truncated to complete episodes helps, but it would be great to get the proper fix as soon as possible.
@MatthewCWeston I generated a script that would let us review the batches passed to the learner, to help identify whether there are any issues.
Batch 0 of 2
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Observation | 0.0 | 1.0 | 2.0 | 3.0 | 4.0 | 5.0 | 6.0 | 7.0 | 8.0 | 9.0 | 10.0 | 0.0 | 1.0 | 2.0 | 3.0 | 4.0 |
Reward | 1.0000 | 1.0000 | 1.0000 | 1.0000 | 1.0000 | 1.0000 | 1.0000 | 1.0000 | 1.0000 | 1.0000 | 0.0000 | 1.0000 | 1.0000 | 1.0000 | 1.0000 | 1.0000 |
Terminated | F | F | F | F | F | F | F | F | F | T | T | F | F | F | F | F |
Truncated | F | F | F | F | F | F | F | F | F | F | F | F | F | F | F | F |
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Advantage | 1.9252 | 1.6178 | 1.2847 | 0.9509 | 0.6153 | 0.2766 | -0.0654 | -0.4107 | -0.7594 | -1.1115 | -1.4670 | 0.6140 | 0.2933 | -0.0531 | -0.4004 | -0.7497 |
Value Target | 9.5618 | 8.6483 | 7.7255 | 6.7935 | 5.8520 | 4.9010 | 3.9404 | 2.9701 | 1.9900 | 1.0000 | 0.0000 | 5.8520 | 4.9010 | 3.9404 | 2.9701 | 1.9900 |
Value Pred | -0.0565 | -0.1002 | -0.0807 | -0.0683 | -0.0602 | -0.0529 | -0.0460 | -0.0394 | -0.0329 | -0.0266 | -0.0208 | -0.0565 | -0.1002 | -0.0807 | -0.0683 | -0.0602 |
VF Error | 92.5107 | 76.5359 | 60.9380 | 47.0845 | 34.9541 | 24.5412 | 15.8917 | 9.0569 | 4.0920 | 1.0540 | 0.0004 | 34.9097 | 25.0120 | 16.1696 | 9.2322 | 4.2034 |
Surrogate Loss | -1.9252 | -1.6178 | -1.2847 | -0.9509 | -0.6153 | -0.2766 | 0.0654 | 0.4107 | 0.7594 | 1.1115 | 1.4753 | -0.6140 | -0.2933 | 0.0531 | 0.4004 | 0.7497 |
Entropy | 0.6919 | 0.6930 | 0.6931 | 0.6925 | 0.6920 | 0.6918 | 0.6918 | 0.6919 | 0.6921 | 0.6923 | 0.6925 | 0.6919 | 0.6930 | 0.6931 | 0.6925 | 0.6920 |
Log Prob Ratio | 1.0000 | 1.0000 | 1.0000 | 1.0000 | 1.0000 | 1.0000 | 1.0000 | 1.0000 | 1.0000 | 1.0000 | 1.0056 | 1.0000 | 1.0000 | 1.0000 | 1.0000 | 1.0000 |
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
In Loss | T | T | T | T | T | T | T | T | T | T | F | T | T | T | T | T |
Batch 1 of 2
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Observation | 5.0 | 6.0 | 0.0 | 1.0 | 2.0 | 3.0 | 4.0 | 5.0 | 6.0 | 7.0 | 8.0 | 9.0 | 10.0 | 0.0 | 1.0 | 2.0 |
Reward | 1.0000 | 0.0000 | 1.0000 | 1.0000 | 1.0000 | 1.0000 | 1.0000 | 1.0000 | 1.0000 | 1.0000 | 1.0000 | 1.0000 | 0.0000 | 1.0000 | 1.0000 | 1.0000 |
Terminated | F | T | F | F | F | F | F | F | F | F | F | T | T | F | F | F |
Truncated | F | T | F | F | F | F | F | F | F | F | F | F | F | F | F | F |
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Advantage | -1.1022 | -1.4581 | 1.9252 | 1.6178 | 1.2847 | 0.9509 | 0.6153 | 0.2766 | -0.0654 | -0.4107 | -0.7594 | -1.1115 | -1.4670 | 0.6140 | 0.2933 | -0.0531 |
Value Target | 1.0000 | 0.0000 | 9.5618 | 8.6483 | 7.7255 | 6.7935 | 5.8520 | 4.9010 | 3.9404 | 2.9701 | 1.9900 | 1.0000 | 0.0000 | 5.8520 | 4.9010 | 3.9404 |
Value Pred | 0.0219 | 0.0305 | -0.0529 | -0.0630 | -0.0229 | -0.0010 | 0.0119 | 0.0219 | 0.0305 | 0.0383 | 0.0455 | 0.0522 | 0.0584 | -0.0529 | -0.0630 | -0.0229 |
VF Error | 0.9567 | 0.0009 | 92.4424 | 75.8860 | 60.0388 | 46.1647 | 34.1069 | 23.8057 | 15.2875 | 8.5957 | 3.7810 | 0.8982 | 0.0034 | 34.8677 | 24.6411 | 15.7081 |
Surrogate Loss | 1.0615 | 1.4009 | -1.9621 | -1.6139 | -1.3090 | -0.9235 | -0.6343 | -0.2663 | 0.0628 | 0.3942 | 0.7284 | 1.0656 | 1.4137 | -0.6257 | -0.2940 | 0.0541 |
Entropy | 0.6893 | 0.6891 | 0.6907 | 0.6930 | 0.6926 | 0.6912 | 0.6900 | 0.6893 | 0.6891 | 0.6893 | 0.6896 | 0.6899 | 0.6903 | 0.6907 | 0.6930 | 0.6926 |
Log Prob Ratio | 0.9630 | 0.9608 | 1.0192 | 0.9976 | 1.0189 | 0.9712 | 1.0309 | 0.9630 | 0.9612 | 0.9599 | 0.9592 | 0.9586 | 0.9637 | 1.0192 | 1.0025 | 1.0189 |
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
In Loss | T | F | T | T | T | T | T | T | T | T | T | T | F | T | T | T |
I suspect that there might be an issue with the truncated episode in the second batch at observation = 6.0, as both terminated and truncated are True. What are your thoughts?
Here is the code that I used:
"""
The counting environment has:
- Observations from 0 to 10 (integer values as observation)
- Reward of 1.0 for every timestep
Note on PPO batch structure:
- PPO uses GAE (Generalized Advantage Estimation) for computing advantages and
value targets, so it doesn't use next_observations directly in the loss.
- The batch includes an artificial "bootstrap" timestep (with loss_mask=False)
added for value function computation at episode boundaries.
- Elements with loss_mask=False are excluded from the loss computation.
"""
import gymnasium as gym
import numpy as np
from typing import Any, Dict, Optional
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.ppo.torch.ppo_torch_learner import PPOTorchLearner
from ray.rllib.core.columns import Columns
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ModuleID, TensorType
torch, nn = try_import_torch()
# ANSI color codes
GREEN = "\033[32m"
RED = "\033[31m"
RESET = "\033[0m"
class CountingEnv(gym.Env):
"""Simple counting environment for debugging PPO loss observation.
The environment counts from 0 to max_count (default 10):
- Observation: current count value (0-indexed)
- Reward: 1.0 for every timestep
- Episode ends when count reaches max_count
"""
def __init__(self, config: Optional[Dict] = None):
config = config or {}
self.max_count = config.get("max_count", 10)
self.use_termination = config.get("use_termination", True)
self.observation_space = gym.spaces.Box(
low=0, high=self.max_count, shape=(1,), dtype=np.float32
)
# (actions don't affect environment)
self.action_space = gym.spaces.Discrete(2)
self.count = 0
def reset(self, *, seed=None, options=None):
super().reset(seed=seed)
self.count = 0
return np.array([self.count], dtype=np.float32), {}
def step(self, action):
self.count += 1
obs = np.array([self.count], dtype=np.float32)
reward = 1.0
terminated = self.use_termination and self.count >= self.max_count
truncated = (not self.use_termination) and self.count >= self.max_count
return obs, reward, terminated, truncated, {}
class PPOLossObservationLearner(PPOTorchLearner):
"""Custom PPO learner that captures per-element batch data for inspection.
Uses super() to delegate loss computation to PPOTorchLearner, then captures
batch data for observation. Access captured data via self.last_batch_info.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.all_batches = [] # List of all captured batches
@override(PPOTorchLearner)
def compute_loss_for_module(
self,
*,
module_id: ModuleID,
config: PPOConfig,
batch: Dict[str, Any],
fwd_out: Dict[str, TensorType],
) -> TensorType:
# Delegate to parent for actual loss computation
total_loss = super().compute_loss_for_module(
module_id=module_id, config=config, batch=batch, fwd_out=fwd_out
)
# Capture batch data for observation (doesn't affect loss)
self._capture_batch_info(module_id, config, batch, fwd_out)
return total_loss
def _capture_batch_info(
self,
module_id: ModuleID,
config: PPOConfig,
batch: Dict[str, Any],
fwd_out: Dict[str, TensorType],
) -> None:
"""Capture per-element batch data for observation."""
module = self.module[module_id].unwrapped()
# Get action distributions for computing per-element metrics
curr_dist = module.get_train_action_dist_cls().from_logits(
fwd_out[Columns.ACTION_DIST_INPUTS]
)
# Compute per-element loss components
logp_ratio = torch.exp(
curr_dist.logp(batch[Columns.ACTIONS]) - batch[Columns.ACTION_LOGP]
)
advantages = batch[Postprocessing.ADVANTAGES]
surrogate_loss = torch.min(
advantages * logp_ratio,
advantages * torch.clamp(logp_ratio, 1 - config.clip_param, 1 + config.clip_param),
)
# Value function predictions and error
if config.use_critic:
vf_preds = module.compute_values(batch, embeddings=fwd_out.get(Columns.EMBEDDINGS))
vf_targets = batch[Postprocessing.VALUE_TARGETS]
vf_error = torch.pow(vf_preds - vf_targets, 2.0)
else:
vf_preds = vf_targets = vf_error = None
# Append batch info to list
self.all_batches.append({
"batch_size": batch[Columns.OBS].shape[0],
# Core batch data
"observations": batch[Columns.OBS].detach().cpu().numpy(),
"actions": batch[Columns.ACTIONS].detach().cpu().numpy(),
"rewards": batch[Columns.REWARDS].detach().cpu().numpy(),
"terminateds": batch[Columns.TERMINATEDS].detach().cpu().numpy(),
"truncateds": batch[Columns.TRUNCATEDS].detach().cpu().numpy(),
# GAE-computed values
"advantages": advantages.detach().cpu().numpy(),
"value_targets": vf_targets.detach().cpu().numpy(),
"value_predictions": vf_preds.detach().cpu().numpy(),
# Per-element loss components
"surrogate_loss": (-surrogate_loss).detach().cpu().numpy(),
"vf_error": vf_error.detach().cpu().numpy(),
"entropy": curr_dist.entropy().detach().cpu().numpy(),
"log_prob_ratio": logp_ratio.detach().cpu().numpy(),
# Mask (False = bootstrap timestep, excluded from loss)
"loss_mask": batch.get(Columns.LOSS_MASK).detach().cpu().numpy(),
})
def print_batch_info(self, batch_index: int = -1, max_elements: int = 16):
info = self.all_batches[batch_index]
n = min(max_elements, info["batch_size"])
batch_num = batch_index if batch_index >= 0 else len(self.all_batches) + batch_index
col_width = 10
label_width = 15
row_length = label_width + 1 + (col_width + 1) * n
# Count masked vs unmasked elements
if info["loss_mask"] is not None:
mask = info["loss_mask"]
else:
mask = np.ones(info["batch_size"], dtype=bool)
# Helper to format a row
def print_row(label: str, values):
row = f"{label:<{label_width}}|"
for i in range(n):
val = values[i]
if isinstance(val, (np.floating, float)):
formatted = f"{val:.4f}".center(col_width)
elif isinstance(val, np.ndarray):
if val.size == 1:
formatted = f"{val.item():.1f}".center(col_width)
else:
formatted = f"{val[0]:.1f}...".center(col_width)
elif isinstance(val, (bool, np.bool_)):
# Color: green for True, red for False
text = "T" if val else "F"
color = GREEN if val else RED
formatted = f"{color}{text.center(col_width)}{RESET}"
else:
formatted = str(val).center(col_width)
row += formatted + "|"
print(row)
print(f"Batch {batch_num} of {len(self.all_batches)}")
print("-" * row_length)
print_row("Observation", info["observations"])
# print_row("Action", info["actions"])
print_row("Reward", info["rewards"])
print_row("Terminated", info["terminateds"])
print_row("Truncated", info["truncateds"])
print("-" * row_length)
print_row("Advantage", info["advantages"])
print_row("Value Target", info["value_targets"])
if info["value_predictions"] is not None:
print_row("Value Pred", info["value_predictions"])
if info["vf_error"] is not None:
print_row("VF Error", info["vf_error"])
print_row("Surrogate Loss", info["surrogate_loss"])
print_row("Entropy", info["entropy"])
print_row("Log Prob Ratio", info["log_prob_ratio"])
print("-" * row_length)
print_row("In Loss", mask)
print(f"\n--- Summary (valid elements only) ---")
print(f"Mean Surrogate Loss: {np.mean(info['surrogate_loss'][mask]):.4f}")
valid_vf_error = info["vf_error"][mask] if info["vf_error"] is not None else None
if valid_vf_error is not None and len(valid_vf_error) > 0:
print(f"Mean VF Error: {np.mean(valid_vf_error):.4f}")
print(f"Mean Entropy: {np.mean(info['entropy'][mask]):.4f}")
print(f"Mean Advantage: {np.mean(info['advantages'][mask]):.4f}\n")
def print_all_batches(self, max_elements: int = 16):
"""Print all captured batches."""
for i in range(len(self.all_batches)):
self.print_batch_info(batch_index=i, max_elements=max_elements)
def run_ppo_with_loss_observation(
max_count: int = 10,
use_termination: bool = True,
max_elements: int = 16,
):
config = (
PPOConfig()
.environment(
CountingEnv,
env_config={
"max_count": max_count,
"use_termination": use_termination,
},
)
.env_runners(
num_env_runners=0,
num_envs_per_env_runner=1,
rollout_fragment_length=max_elements,
)
.training(
train_batch_size_per_learner=max_elements,
minibatch_size=max_elements,
num_epochs=1,
shuffle_batch_per_epoch=False,
gamma=0.99,
use_gae=True,
lambda_=1.0,
)
.learners(
num_learners=0,
learner_class=PPOLossObservationLearner,
)
)
# Build and train the algorithm
algo = config.build()
algo.train()
# Get the learner and print all batches
learner = algo.learner_group._learner
learner.print_all_batches(max_elements=max_elements)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(
description="Observe PPO loss function with counting environment"
)
parser.add_argument(
"--max-count",
type=int,
default=10,
help="Maximum count value (observations go from 0 to max_count)",
)
parser.add_argument(
"--use-termination",
type=lambda x: x.lower() in ("true", "1", "yes"),
default=True,
help="End episodes with terminated (True) or truncation (False)",
)
parser.add_argument(
"--max-elements",
type=int,
default=16,
help="Maximum number of batch elements to display in the table",
)
args = parser.parse_args()
run_ppo_with_loss_observation(
max_count=args.max_count,
use_termination=args.use_termination,
max_elements=args.max_elements,
)
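For reference, a hypothetical way to drive the script above from Python (the module name observe_ppo_loss is my own assumption for wherever the script is saved, not part of the original comment):

# Hypothetical usage of the script above; `observe_ppo_loss` is an assumed filename.
from observe_ppo_loss import run_ppo_with_loss_observation

# Terminated episodes (the default) and truncated episodes, matching the two cases discussed.
run_ppo_with_loss_observation(max_count=10, use_termination=True, max_elements=16)
run_ppo_with_loss_observation(max_count=10, use_termination=False, max_elements=16)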
@pseudo-rnd-thoughts I took a manual look at the second batch of your results, and the terminated flag shouldn't be set there; however, the results you posted appear to be for the production code rather than the code in this PR. With the PR code enabled, the episode is correctly marked as truncated and not terminated, and the value target calculations come out correctly. Namely, the episode that gets truncated after acting on 5.0 inherits the value prediction from 6.0, and the step that acts upon 5.0 gets a value target that reflects that bootstrap. I've made a notebook that allows this to be tested on Colab:
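To spell out the bootstrapping being described, a generic one-step illustration of my own (not the exact numbers from the run, which depend on the current critic; the placeholder value is an assumption):

# Generic one-step bootstrap at a truncation boundary (illustrative, not RLlib code).
# The last real step before truncation keeps the critic's prediction for the next
# observation, rather than being treated as terminal.
gamma = 0.99
reward_last_real_step = 1.0   # CountingEnv always returns a reward of 1.0
v_pred_next_obs = 0.5         # placeholder for the critic's prediction at obs 6.0

value_target = reward_last_real_step + gamma * v_pred_next_obs
print(value_target)  # 1.0 + 0.99 * V(6.0); it would collapse to 1.0 if wrongly marked terminated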
If this is the case, then my interpretation would be that the problem is not with the adding of the extra timestep but with how the GAE implementation handles truncated timesteps.
@pseudo-rnd-thoughts I think there might be a miscommunication - the value targets look exactly as they should to me, conditional on the termination/truncation values. The production code incorrectly marks the episode as terminated, causing the value head prediction to be discarded when it shouldn't be. The PR prevents this flag from being added unless the episode really is terminated, yielding a correct value target.
Ah sorry, yes, I think I misunderstood. I've just spent 5 minutes working out why, only to realise that you were on the right lines the whole time. Apologies if I came across as an ass. For the terminated episode case we lose the terminated flag from observation 9 (keeping it on 10), while for the truncated episode case we lose the terminated signal from 6. I think we want the second change but not the first; however, I need to double-check that. EDIT: This is a middle-ground solution that does both.
@pseudo-rnd-thoughts The first case, in batch one, correctly computes a value target of 1.00 for the last real timestep in both the production implementation and the PR. Both are correct, but the production code only gets the correct answer here because of the lambda setting used. Setting the termination flag in the next-to-last timestep breaks propagation of terminal rewards when dealing with pure value bootstrapping (you can set lambda to zero to reproduce this). Edit: Added results of testing with your script under the specified conditions:
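As one concrete way to exercise that pure-bootstrapping regime with the script above (my suggestion; only the lambda_ value differs from the script's config):

# Hypothetical tweak: rebuild the script's PPOConfig with lambda_=0.0 so GAE collapses
# to one-step bootstrapped targets, the regime in which the misplaced terminated flag
# discards the terminal reward.
from ray.rllib.algorithms.ppo import PPOConfig

config = PPOConfig().training(gamma=0.99, use_gae=True, lambda_=0.0)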
@MatthewCWeston I think your solution is correct. I just want to implement my own test to independently compute the expected loss and compare it against the PPOLearner loss, to double-check that everything is working as expected.
@MatthewCWeston I incorporated the fix into #59007 to run against a new suite of examples / tests that I've been building, and Gemini raised a point about your change breaking VTrace (according to the docstring in AddOneTsToEpisodesAndTruncate, line 95).
@pseudo-rnd-thoughts I had a quick look at the VTrace code. I'm not intimately familiar with that algorithm, but the relevant logic is here (where the discount is set to zero on terminal states) and here (where discounts are applied to the next states' values when calculating the advantage). Seeing as the tagging scheme that works depends on the advantage-calculation scheme an algorithm uses, with GAE requiring the one in this PR and VTrace (as far as I know) requiring the one in production, my first instinct is to make the flagging behavior configurable so each algorithm can request the scheme it needs. This amounts to changing line 511 of the relevant file and adding an option for it. That seems like the cleanest set of changes that makes both algorithms work as intended. Thoughts?
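For readers who want to see the shape of the logic being referenced, a simplified standalone sketch of my own (NOT RLlib's vtrace implementation; the importance-sampling rho/c clipping is omitted) of zeroing discounts on terminal states and then applying them to the next states' values:

# Simplified sketch of the two referenced steps (not RLlib's vtrace code).
import numpy as np

gamma = 0.99
terminateds = np.array([0.0, 0.0, 1.0])
rewards = np.array([1.0, 1.0, 10.0])
values = np.array([5.0, 5.0, 5.0])
next_values = np.array([5.0, 5.0, 0.0])

discounts = gamma * (1.0 - terminateds)                  # discount is zeroed on terminal states
td_errors = rewards + discounts * next_values - values   # discounts scale the next states' values
print(td_errors)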
Description
AddOneTsToEpisodesAndTruncate currently has two edge cases which cause significant errors in value target calculation. The first is that, when lambda is set to zero (pure value bootstrapping), terminal rewards are discarded from the value target calculation for earlier states. This is because a terminated flag is placed in the second-to-last timestep of the modified episodes, resulting in said rewards being masked out.
The second is that truncated episodes do not make use of value head calculations from the last observation, under the same conditions. This is because the terminated flag is appended to every episode, regardless of whether the episode actually terminates.
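A schematic sketch of the intended flag handling for the artificial extra timestep (my paraphrase of the behavior described above, not the actual connector diff; the function name is illustrative):

# Illustrative sketch only: how the extra timestep's flags should depend on how the
# episode actually ended (not the real AddOneTsToEpisodesAndTruncate code).
def flags_for_extra_timestep(episode_terminated: bool, episode_truncated: bool):
    terminated = episode_terminated                        # only truly terminated episodes get the flag
    truncated = episode_truncated and not episode_terminated
    return terminated, truncated

print(flags_for_extra_timestep(True, False))   # terminated episode -> (True, False)
print(flags_for_extra_timestep(False, True))   # truncated episode  -> (False, True): bootstrap stays usable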
This PR fixes both issues and includes a pair of unit tests to prevent regression. One is a lightweight test of an environment whose state values can be quickly learned, which checks for both correct learning and plausible state values. The other directly evaluates the correctness of value target calculations on a toy dataset.
This Colab notebook allows both of the above bugs to be readily verified. The topmost cell patches in the fix from this PR. Run it before the cells below to test the fix, or skip it altogether to test the code that is currently in production.
Related issues
Resolves the bug outlined in #57683.
Resolves the additional issue raised in #57895.
Additional information
A prior PR that I submitted failed to address the second edge case and was closed as a result. This PR should correctly resolve both edge cases.