
Conversation

@MatthewCWeston
Contributor

@MatthewCWeston MatthewCWeston commented Jan 8, 2026

Description

AddOneTsToEpisodesAndTruncate currently has two edge cases which cause significant errors in value target calculation.

  • The first is that, when lambda is set to zero (pure value bootstrapping), terminal rewards are discarded from the value target calculation for earlier states. This happens because a terminated flag is placed on the second-to-last timestep of the modified episodes, causing those rewards to be masked out.

    • In effect, this means that, in an environment with terminal rewards, we only learn a policy for the very last timestep!
  • The second is that, under the same conditions, truncated episodes do not bootstrap from the value head's prediction for the last observation. This is because the terminated flag is appended to every episode, regardless of whether the episode actually terminates.

    • In effect, this means that truncation is treated identically to termination (apart from the bug described in the previous bullet), and value targets fail to include any of the potential future rewards beyond the truncation point.

This PR fixes both issues and includes a pair of unit tests to prevent regression. One is a lightweight test of an environment whose state values can be quickly learned, which checks for both correct learning and plausible state values. The other directly evaluates the correctness of value target calculations on a toy dataset.
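
Roughly, the change to the termination flags can be sketched as follows (a sketch only; variable names follow the snippets discussed later in this thread, and the actual diff in rllib/connectors/learner/add_one_ts_to_episodes_and_truncate.py is authoritative):

# Current (production) behavior, as described above: the terminated flag lands
# on the second-to-last timestep, and the appended extra timestep is always
# flagged terminated, even for truncated episodes.
terminateds = (
    [False for _ in range(len_ - 1)]
    + [bool(sa_episode.is_terminated)]
    + [True]  # extra timestep
)

# Behavior with this PR: only the appended extra timestep carries the flag,
# and only if the episode actually terminated.
terminateds = (
    [False for _ in range(len_ - 1)]
    + [False]
    + [bool(sa_episode.is_terminated)]  # extra timestep
)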

This Colab notebook allows both of the above bugs to be readily verified. The topmost cell patches in the fix from this PR: run it before the cells below to test the fix, or skip it altogether to test the code that is currently in production.

Related issues

Resolves the bug outlined in #57683.

Resolves the additional issue raised in #57895.

Additional information

A prior PR I submitted failed to address the second edge case and was closed as a result. This PR resolves both.

@MatthewCWeston MatthewCWeston requested a review from a team as a code owner January 8, 2026 02:28
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request addresses two significant edge cases in PPO's value target calculation, particularly when lambda is zero. The changes in rllib/connectors/learner/add_one_ts_to_episodes_and_truncate.py correctly adjust how terminateds flags are generated. This ensures that terminal rewards are properly included in value calculations and that truncated episodes are correctly handled for value bootstrapping. The addition of unit tests is a great way to prevent regressions. I've found a potential issue in the new test file: the test helper function seems to have a logic error in data preparation, for which I've provided a suggested fix. Overall, the core logic fix is sound and this PR is a valuable improvement.

@ray-gardener ray-gardener bot added rllib RLlib related issues community-contribution Contributed by the community labels Jan 8, 2026
@pseudo-rnd-thoughts
Member

Thanks @MatthewCWeston, I'll try to have a look at this soon

@stefanbschneider
Member

> This is because the terminated flag is appended to every episode, regardless of whether the episode actually terminates.

This bug seems to massively affect our performance. What's the timeline for getting the fix merged and released?

Switching from truncated to complete episodes helps, but it would be great to get the proper fix as soon as possible.
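
For reference, the workaround amounts to something like the following (a sketch; the exact config location may differ across RLlib versions):

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("MyEnv-v0")  # placeholder env name
    # Collect whole episodes instead of truncated fragments, so the
    # truncation edge case described above is not hit while sampling.
    .env_runners(batch_mode="complete_episodes")
)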

@pseudo-rnd-thoughts
Member

@MatthewCWeston I generated a script that would let us review the batches passed to the learner to help identify if there are any issues.

Batch 0 of 2
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Observation    |   0.0    |   1.0    |   2.0    |   3.0    |   4.0    |   5.0    |   6.0    |   7.0    |   8.0    |   9.0    |   10.0   |   0.0    |   1.0    |   2.0    |   3.0    |   4.0    |
Reward         |  1.0000  |  1.0000  |  1.0000  |  1.0000  |  1.0000  |  1.0000  |  1.0000  |  1.0000  |  1.0000  |  1.0000  |  0.0000  |  1.0000  |  1.0000  |  1.0000  |  1.0000  |  1.0000  |
Terminated     |    F     |    F     |    F     |    F     |    F     |    F     |    F     |    F     |    F     |    T     |    T     |    F     |    F     |    F     |    F     |    F     |
Truncated      |    F     |    F     |    F     |    F     |    F     |    F     |    F     |    F     |    F     |    F     |    F     |    F     |    F     |    F     |    F     |    F     |
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Advantage      |  1.9252  |  1.6178  |  1.2847  |  0.9509  |  0.6153  |  0.2766  | -0.0654  | -0.4107  | -0.7594  | -1.1115  | -1.4670  |  0.6140  |  0.2933  | -0.0531  | -0.4004  | -0.7497  |
Value Target   |  9.5618  |  8.6483  |  7.7255  |  6.7935  |  5.8520  |  4.9010  |  3.9404  |  2.9701  |  1.9900  |  1.0000  |  0.0000  |  5.8520  |  4.9010  |  3.9404  |  2.9701  |  1.9900  |
Value Pred     | -0.0565  | -0.1002  | -0.0807  | -0.0683  | -0.0602  | -0.0529  | -0.0460  | -0.0394  | -0.0329  | -0.0266  | -0.0208  | -0.0565  | -0.1002  | -0.0807  | -0.0683  | -0.0602  |
VF Error       | 92.5107  | 76.5359  | 60.9380  | 47.0845  | 34.9541  | 24.5412  | 15.8917  |  9.0569  |  4.0920  |  1.0540  |  0.0004  | 34.9097  | 25.0120  | 16.1696  |  9.2322  |  4.2034  |
Surrogate Loss | -1.9252  | -1.6178  | -1.2847  | -0.9509  | -0.6153  | -0.2766  |  0.0654  |  0.4107  |  0.7594  |  1.1115  |  1.4753  | -0.6140  | -0.2933  |  0.0531  |  0.4004  |  0.7497  |
Entropy        |  0.6919  |  0.6930  |  0.6931  |  0.6925  |  0.6920  |  0.6918  |  0.6918  |  0.6919  |  0.6921  |  0.6923  |  0.6925  |  0.6919  |  0.6930  |  0.6931  |  0.6925  |  0.6920  |
Log Prob Ratio |  1.0000  |  1.0000  |  1.0000  |  1.0000  |  1.0000  |  1.0000  |  1.0000  |  1.0000  |  1.0000  |  1.0000  |  1.0056  |  1.0000  |  1.0000  |  1.0000  |  1.0000  |  1.0000  |
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
In Loss        |    T     |    T     |    T     |    T     |    T     |    T     |    T     |    T     |    T     |    T     |    F     |    T     |    T     |    T     |    T     |    T     |

Batch 1 of 2
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Observation    |   5.0    |   6.0    |   0.0    |   1.0    |   2.0    |   3.0    |   4.0    |   5.0    |   6.0    |   7.0    |   8.0    |   9.0    |   10.0   |   0.0    |   1.0    |   2.0    |
Reward         |  1.0000  |  0.0000  |  1.0000  |  1.0000  |  1.0000  |  1.0000  |  1.0000  |  1.0000  |  1.0000  |  1.0000  |  1.0000  |  1.0000  |  0.0000  |  1.0000  |  1.0000  |  1.0000  |
Terminated     |    F     |    T     |    F     |    F     |    F     |    F     |    F     |    F     |    F     |    F     |    F     |    T     |    T     |    F     |    F     |    F     |
Truncated      |    F     |    T     |    F     |    F     |    F     |    F     |    F     |    F     |    F     |    F     |    F     |    F     |    F     |    F     |    F     |    F     |
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Advantage      | -1.1022  | -1.4581  |  1.9252  |  1.6178  |  1.2847  |  0.9509  |  0.6153  |  0.2766  | -0.0654  | -0.4107  | -0.7594  | -1.1115  | -1.4670  |  0.6140  |  0.2933  | -0.0531  |
Value Target   |  1.0000  |  0.0000  |  9.5618  |  8.6483  |  7.7255  |  6.7935  |  5.8520  |  4.9010  |  3.9404  |  2.9701  |  1.9900  |  1.0000  |  0.0000  |  5.8520  |  4.9010  |  3.9404  |
Value Pred     |  0.0219  |  0.0305  | -0.0529  | -0.0630  | -0.0229  | -0.0010  |  0.0119  |  0.0219  |  0.0305  |  0.0383  |  0.0455  |  0.0522  |  0.0584  | -0.0529  | -0.0630  | -0.0229  |
VF Error       |  0.9567  |  0.0009  | 92.4424  | 75.8860  | 60.0388  | 46.1647  | 34.1069  | 23.8057  | 15.2875  |  8.5957  |  3.7810  |  0.8982  |  0.0034  | 34.8677  | 24.6411  | 15.7081  |
Surrogate Loss |  1.0615  |  1.4009  | -1.9621  | -1.6139  | -1.3090  | -0.9235  | -0.6343  | -0.2663  |  0.0628  |  0.3942  |  0.7284  |  1.0656  |  1.4137  | -0.6257  | -0.2940  |  0.0541  |
Entropy        |  0.6893  |  0.6891  |  0.6907  |  0.6930  |  0.6926  |  0.6912  |  0.6900  |  0.6893  |  0.6891  |  0.6893  |  0.6896  |  0.6899  |  0.6903  |  0.6907  |  0.6930  |  0.6926  |
Log Prob Ratio |  0.9630  |  0.9608  |  1.0192  |  0.9976  |  1.0189  |  0.9712  |  1.0309  |  0.9630  |  0.9612  |  0.9599  |  0.9592  |  0.9586  |  0.9637  |  1.0192  |  1.0025  |  1.0189  |
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
In Loss        |    T     |    F     |    T     |    T     |    T     |    T     |    T     |    T     |    T     |    T     |    T     |    T     |    F     |    T     |    T     |    T     |

I suspect that there might be an issue with the truncated episode in the second batch (observation = 6.0), as both terminated and truncated are True.
However, I still don't think the problem is with the extra timestep; I suspect the problem could be with the GAE implementation instead. I haven't checked that yet.

What are your thoughts?

Here is the code that I used:

"""
The counting environment has:
- Observations from 0 to 10 (integer values as observation)
- Reward of 1.0 for every timestep

Note on PPO batch structure:
- PPO uses GAE (Generalized Advantage Estimation) for computing advantages and
  value targets, so it doesn't use next_observations directly in the loss.
- The batch includes an artificial "bootstrap" timestep (with loss_mask=False)
  added for value function computation at episode boundaries.
- Elements with loss_mask=False are excluded from the loss computation.
"""
import gymnasium as gym
import numpy as np
from typing import Any, Dict, Optional

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.ppo.torch.ppo_torch_learner import PPOTorchLearner
from ray.rllib.core.columns import Columns
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ModuleID, TensorType

torch, nn = try_import_torch()

# ANSI color codes
GREEN = "\033[32m"
RED = "\033[31m"
RESET = "\033[0m"


class CountingEnv(gym.Env):
    """Simple counting environment for debugging PPO loss observation.

    The environment counts from 0 to max_count (default 10):
    - Observation: current count value (0-indexed)
    - Reward: 1.0 for every timestep
    - Episode ends when count reaches max_count
    """

    def __init__(self, config: Optional[Dict] = None):
        config = config or {}
        self.max_count = config.get("max_count", 10)
        self.use_termination = config.get("use_termination", True)

        self.observation_space = gym.spaces.Box(
            low=0, high=self.max_count, shape=(1,), dtype=np.float32
        )
        # (actions don't affect environment)
        self.action_space = gym.spaces.Discrete(2)

        self.count = 0

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        self.count = 0
        return np.array([self.count], dtype=np.float32), {}

    def step(self, action):
        self.count += 1

        obs = np.array([self.count], dtype=np.float32)
        reward = 1.0
        terminated = self.use_termination and self.count >= self.max_count
        truncated = (not self.use_termination) and self.count >= self.max_count

        return obs, reward, terminated, truncated, {}


class PPOLossObservationLearner(PPOTorchLearner):
    """Custom PPO learner that captures per-element batch data for inspection.

    Uses super() to delegate loss computation to PPOTorchLearner, then captures
    batch data for observation. Access captured data via self.all_batches.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.all_batches = []  # List of all captured batches

    @override(PPOTorchLearner)
    def compute_loss_for_module(
        self,
        *,
        module_id: ModuleID,
        config: PPOConfig,
        batch: Dict[str, Any],
        fwd_out: Dict[str, TensorType],
    ) -> TensorType:
        # Delegate to parent for actual loss computation
        total_loss = super().compute_loss_for_module(
            module_id=module_id, config=config, batch=batch, fwd_out=fwd_out
        )
        # Capture batch data for observation (doesn't affect loss)
        self._capture_batch_info(module_id, config, batch, fwd_out)
        return total_loss

    def _capture_batch_info(
        self,
        module_id: ModuleID,
        config: PPOConfig,
        batch: Dict[str, Any],
        fwd_out: Dict[str, TensorType],
    ) -> None:
        """Capture per-element batch data for observation."""
        module = self.module[module_id].unwrapped()

        # Get action distributions for computing per-element metrics
        curr_dist = module.get_train_action_dist_cls().from_logits(
            fwd_out[Columns.ACTION_DIST_INPUTS]
        )

        # Compute per-element loss components
        logp_ratio = torch.exp(
            curr_dist.logp(batch[Columns.ACTIONS]) - batch[Columns.ACTION_LOGP]
        )
        advantages = batch[Postprocessing.ADVANTAGES]
        surrogate_loss = torch.min(
            advantages * logp_ratio,
            advantages * torch.clamp(logp_ratio, 1 - config.clip_param, 1 + config.clip_param),
        )

        # Value function predictions and error
        if config.use_critic:
            vf_preds = module.compute_values(batch, embeddings=fwd_out.get(Columns.EMBEDDINGS))
            vf_targets = batch[Postprocessing.VALUE_TARGETS]
            vf_error = torch.pow(vf_preds - vf_targets, 2.0)
        else:
            vf_preds = vf_targets = vf_error = None

        # Append batch info to list
        self.all_batches.append({
            "batch_size": batch[Columns.OBS].shape[0],
            # Core batch data
            "observations": batch[Columns.OBS].detach().cpu().numpy(),
            "actions": batch[Columns.ACTIONS].detach().cpu().numpy(),
            "rewards": batch[Columns.REWARDS].detach().cpu().numpy(),
            "terminateds": batch[Columns.TERMINATEDS].detach().cpu().numpy(),
            "truncateds": batch[Columns.TRUNCATEDS].detach().cpu().numpy(),
            # GAE-computed values
            "advantages": advantages.detach().cpu().numpy(),
            "value_targets": vf_targets.detach().cpu().numpy(),
            "value_predictions": vf_preds.detach().cpu().numpy(),
            # Per-element loss components
            "surrogate_loss": (-surrogate_loss).detach().cpu().numpy(),
            "vf_error": vf_error.detach().cpu().numpy(),
            "entropy": curr_dist.entropy().detach().cpu().numpy(),
            "log_prob_ratio": logp_ratio.detach().cpu().numpy(),
            # Mask (False = bootstrap timestep, excluded from loss)
            "loss_mask": (
                batch[Columns.LOSS_MASK].detach().cpu().numpy()
                if Columns.LOSS_MASK in batch
                else None
            ),
        })

    def print_batch_info(self, batch_index: int = -1, max_elements: int = 16):
        info = self.all_batches[batch_index]
        n = min(max_elements, info["batch_size"])
        batch_num = batch_index if batch_index >= 0 else len(self.all_batches) + batch_index

        col_width = 10
        label_width = 15
        row_length = label_width + 1 + (col_width + 1) * n

        # Count masked vs unmasked elements
        if info["loss_mask"] is not None:
            mask = info["loss_mask"]
        else:
            mask = np.ones(info["batch_size"], dtype=bool)

        # Helper to format a row
        def print_row(label: str, values):
            row = f"{label:<{label_width}}|"
            for i in range(n):
                val = values[i]
                if isinstance(val, (np.floating, float)):
                    formatted = f"{val:.4f}".center(col_width)
                elif isinstance(val, np.ndarray):
                    if val.size == 1:
                        formatted = f"{val.item():.1f}".center(col_width)
                    else:
                        formatted = f"{val[0]:.1f}...".center(col_width)
                elif isinstance(val, (bool, np.bool_)):
                    # Color: green for True, red for False
                    text = "T" if val else "F"
                    color = GREEN if val else RED
                    formatted = f"{color}{text.center(col_width)}{RESET}"
                else:
                    formatted = str(val).center(col_width)
                row += formatted + "|"
            print(row)

        print(f"Batch {batch_num} of {len(self.all_batches)}")
        print("-" * row_length)
        print_row("Observation", info["observations"])
        # print_row("Action", info["actions"])
        print_row("Reward", info["rewards"])
        print_row("Terminated", info["terminateds"])
        print_row("Truncated", info["truncateds"])
        print("-" * row_length)
        print_row("Advantage", info["advantages"])
        print_row("Value Target", info["value_targets"])
        if info["value_predictions"] is not None:
            print_row("Value Pred", info["value_predictions"])
        if info["vf_error"] is not None:
            print_row("VF Error", info["vf_error"])
        print_row("Surrogate Loss", info["surrogate_loss"])
        print_row("Entropy", info["entropy"])
        print_row("Log Prob Ratio", info["log_prob_ratio"])
        print("-" * row_length)
        print_row("In Loss", mask)

        print(f"\n--- Summary (valid elements only) ---")
        print(f"Mean Surrogate Loss: {np.mean(info['surrogate_loss'][mask]):.4f}")
        valid_vf_error = info["vf_error"][mask] if info["vf_error"] is not None else None
        if valid_vf_error is not None and len(valid_vf_error) > 0:
            print(f"Mean VF Error:       {np.mean(valid_vf_error):.4f}")
        print(f"Mean Entropy:        {np.mean(info['entropy'][mask]):.4f}")
        print(f"Mean Advantage:      {np.mean(info['advantages'][mask]):.4f}\n")

    def print_all_batches(self, max_elements: int = 16):
        """Print all captured batches."""
        for i in range(len(self.all_batches)):
            self.print_batch_info(batch_index=i, max_elements=max_elements)


def run_ppo_with_loss_observation(
    max_count: int = 10,
    use_termination: bool = True,
    max_elements: int = 16,
):
    config = (
        PPOConfig()
        .environment(
            CountingEnv,
            env_config={
                "max_count": max_count,
                "use_termination": use_termination,
            },
        )
        .env_runners(
            num_env_runners=0,
            num_envs_per_env_runner=1,
            rollout_fragment_length=max_elements,
        )
        .training(
            train_batch_size_per_learner=max_elements,
            minibatch_size=max_elements,
            num_epochs=1,
            shuffle_batch_per_epoch=False,
            gamma=0.99,
            use_gae=True,
            lambda_=1.0,
        )
        .learners(
            num_learners=0,
            learner_class=PPOLossObservationLearner,
        )
    )

    # Build and train the algorithm
    algo = config.build()
    algo.train()

    # Get the learner and print all batches
    learner = algo.learner_group._learner
    learner.print_all_batches(max_elements=max_elements)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Observe PPO loss function with counting environment"
    )
    parser.add_argument(
        "--max-count",
        type=int,
        default=10,
        help="Maximum count value (observations go from 0 to max_count)",
    )
    parser.add_argument(
        "--use-termination",
        type=lambda x: x.lower() in ("true", "1", "yes"),
        default=True,
        help="End episodes with terminated (True) or truncation (False)",
    )
    parser.add_argument(
        "--max-elements",
        type=int,
        default=16,
        help="Maximum number of batch elements to display in the table",
    )

    args = parser.parse_args()
    run_ppo_with_loss_observation(
        max_count=args.max_count,
        use_termination=args.use_termination,
        max_elements=args.max_elements,
    )

@MatthewCWeston
Contributor Author

MatthewCWeston commented Jan 20, 2026

> @MatthewCWeston I generated a script that would let us review the batches passed to the learner to help identify if there are any issues.
>
> I suspect that there might be an issue with the truncated episode in the second batch (observation = 6.0), as both terminated and truncated are True. However, I still don't think the problem is with the extra timestep; I suspect the problem could be with the GAE implementation instead. I haven't checked that yet.
>
> What are your thoughts?

@pseudo-rnd-thoughts I took a manual look at the second batch of your results, and the terminated flag shouldn't be set there, but the results you posted appear to be for the production code rather than the code in this PR. With the PR code enabled, the episode is correctly marked as truncated and not terminated, and the value target calculations come out correctly.

Namely, the episode that gets truncated after acting on 5.0 inherits the value prediction from 6.0, and the step that acts upon 5.0 gets a value target of 1.0+vf_head(OBS=6), which is what we'd expect, instead of 1.0, which it gets in the results above.
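
To spell out the arithmetic (gamma=0.99 from the script config; 0.0305 is the untrained value prediction the table above shows for OBS=6):

gamma = 0.99
reward = 1.0
v_obs6 = 0.0305  # value-head prediction for OBS=6 from the posted table

# Truncation handled correctly (this PR): bootstrap from the value head.
target_with_bootstrap = reward + gamma * v_obs6  # ~1.03

# Truncation treated as termination (production): no bootstrap.
target_without_bootstrap = reward  # 1.0, as shown in the table above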

I've made a notebook that allows this to be tested on Colab:

bootstrapping_debug.ipynb

@pseudo-rnd-thoughts
Member

pseudo-rnd-thoughts commented Jan 20, 2026

> Namely, the episode that gets truncated after acting on 5.0 inherits the value prediction from 6.0, and the step that acts upon 5.0 gets a value target of 1.0+vf_head(OBS=6), which is what we'd expect, instead of 1.0, which it gets in the results above.

If this is the case, then my interpretation would be that the problem is not with adding the extra timestep but with how the GAE implementation handles truncated timesteps.

@MatthewCWeston
Contributor Author

MatthewCWeston commented Jan 20, 2026

> If this is the case, then my interpretation would be that the problem is not with adding the extra timestep but with how the GAE implementation handles truncated timesteps.

@pseudo-rnd-thoughts I think there might be a miscommunication - the value targets look exactly as they should to me, conditional on the termination/truncation values. The production code incorrectly marks the episode as terminated, causing the value head prediction to be discarded when it shouldn't be. The PR prevents this flag from being added unless the episode really is terminated, yielding a correct value target.

@pseudo-rnd-thoughts
Member

pseudo-rnd-thoughts commented Jan 20, 2026

Ah sorry, yes, I think I misunderstood.
If I now understand correctly, we both agree that it is wrong for the batch data at the end of the rollout to be marked as both terminated and truncated.

I've just spent 5 minutes working out why, only to realise that you were on the right track the whole time. Apologies if I came across as an ass.
Checking your solution, I see two changes, and I want to confirm that both are intentional (I suspect that we were both right in a way).

Observation |   8.0    |   9.0    |   10.0   |   0.0    |      |   8.0    |   9.0    |   10.0   |   0.0    |
Reward      |  1.0000  |  1.0000  |  0.0000  |  1.0000  |  ->  |  1.0000  |  1.0000  |  0.0000  |  1.0000  |
Terminated  |    F     |    T     |    T     |    F     |      |    F     |    F     |    T     |    F     |
Truncated   |    F     |    F     |    F     |    F     |      |    F     |    F     |    F     |    F     |

Observation |   5.0    |   6.0    |   0.0    |      |   5.0    |   6.0    |   0.0    |
Reward      |  1.0000  |  0.0000  |  1.0000  |  ->  |  1.0000  |  0.0000  |  1.0000  |
Terminated  |    F     |    T     |    F     |      |    F     |    F     |    F     |
Truncated   |    F     |    T     |    F     |      |    F     |    T     |    F     |

For the terminated-episode case we lose the terminated flag from observation 9 (keeping it on 10), while for the truncated-episode case we lose the terminated signal from observation 6.

I think we want the second change but not the first; however, I need to double-check that.
What are your thoughts, @MatthewCWeston?

EDIT: Here is a middle-ground solution that does both:

terminateds = (
    [False for _ in range(len_ - 1)]
    + [bool(sa_episode.is_terminated)]
    + [bool(sa_episode.is_terminated)]  # extra timestep
)

@MatthewCWeston
Contributor Author

MatthewCWeston commented Jan 20, 2026

@pseudo-rnd-thoughts In the first case (batch one), both the production implementation and the PR correctly compute a value target of 1.00 for the last real timestep. However, the production code only gets the correct answer here because lambda_ is exactly 1.0 in your test case. For any other value of lambda_, this flag setup will yield an incorrect value target.

Setting the termination flag in the next-to-last timestep breaks propagation of terminal rewards when dealing with pure value bootstrapping (you can set lambda_ to something low to test this with your script; I think the automated unit tests added in this PR will also flag the issue).
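
One way to see the lambda_ dependence concretely is the textbook lambda-return recursion (the generic definition, not RLlib's GAE code; the value prediction below is a made-up number):

# target_t = r_t + gamma * ((1 - lam) * V(s_{t+1}) + lam * target_{t+1})
gamma, r = 0.99, 1.0
v_last = 0.5  # hypothetical value prediction for the last real observation

target_last = r  # terminal step: just the terminal reward

# lambda_ = 1.0: the previous step's target uses the terminal reward directly,
# so dropping the bootstrap value at the last step changes nothing.
target_prev_lam1 = r + gamma * target_last  # 1.99, matching the table above

# lambda_ = 0.0: the previous step's target bootstraps only from V(s_last);
# if a mis-placed terminated flag drops that value, the target collapses to
# the immediate reward and the terminal reward never propagates backwards.
target_prev_lam0_ok = r + gamma * v_last   # 1.495
target_prev_lam0_bug = r + gamma * 0.0     # 1.0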

Edit: Added results of testing with your script under the specified conditions:

Output with production code; lambda_=0 (error highlighted):
[image]

Output with PR; lambda_=0 (correction highlighted):
[image]

@pseudo-rnd-thoughts pseudo-rnd-thoughts added rllib-algorithms An RLlib algorithm/Trainer is not learning. go add ONLY when ready to merge, run all tests labels Jan 20, 2026
@pseudo-rnd-thoughts
Member

@MatthewCWeston I think your solution is correct; I just want to implement my own test that independently computes the expected loss and compares it against the PPOLearner loss, to double-check that everything is working as expected.

@pseudo-rnd-thoughts
Member

@MatthewCWeston I incorporated the fix into #59007 to run against a new suite of examples / tests that I've been building, and Gemini raised a point about your change breaking VTrace (according to the docstring in AddOneTsToEpisodesAndTruncate, line 95).
I still agree with you that this is the correct solution; I just suspect that it might have the unintended consequence of breaking VTrace.
Before we merge, I want to check that this isn't the case. Have you tested the FrozenLake example with and without VTrace enabled?

@MatthewCWeston
Contributor Author

MatthewCWeston commented Jan 27, 2026

@pseudo-rnd-thoughts I had a quick look at the VTrace code. I'm not intimately familiar with that algorithm, but the relevant logic is here (where discounts is set to zero on terminal states) and here (where discounts are applied to the next states' values when calculating advantage).

Seeing as the tagging scheme that works depends on the advantage-calculation scheme an algorithm uses (GAE requiring the one in this PR, and VTrace, as far as I know, requiring the one in production), my first instinct is to have AddOneTsToEpisodesAndTruncate take an argument and have that argument toggle between tagging schemes.

This amounts to changing line 511 of impala.py (the only instantiation of AddOneTsToEpisodesAndTruncate that's upstream from VTrace) to...

            connector.prepend(AddOneTsToEpisodesAndTruncate(vtrace=True))

...and adding an __init__ override to AddOneTsToEpisodesAndTruncate that uses the vtrace argument (defaulting to False) to set self.vtrace. self.vtrace then gets used to toggle the set of labels that get applied, as follows:

terminateds = (
    [False for _ in range(len_ - 1)]
    + [bool(sa_episode.is_terminated) and self.vtrace] # false unless VTrace
    + [bool(sa_episode.is_terminated) or self.vtrace]  # always true if using VTrace
)
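
For concreteness, the __init__ part could look something like this (illustrative only, not actual code from the repo; base class abbreviated):

class AddOneTsToEpisodesAndTruncate(ConnectorV2):
    def __init__(self, *args, vtrace: bool = False, **kwargs):
        super().__init__(*args, **kwargs)
        # Toggles between the GAE-style tagging from this PR (False) and the
        # VTrace-style tagging currently in production (True).
        self.vtrace = vtrace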

Seems like the cleanest set of changes that makes both algorithms work as intended. Thoughts?
