[algo] feat: add optimal token baseline and variance proxy (verl-project#4678)
# Optimal Token Baseline
## Main features
- Register `AdvantageEstimator.OPTIMAL_TOKEN_BASELINE`.
- Extend the DP actor to emit `sum_pi_squared`, expose
`calculate_sum_pi_squared` and checkpointing toggles across configs, and
add a reusable `calculate_sum_pi_squared_from_logits` function.
- Introduce `compute_variance_proxy_metrics` to surface signal/total
power/noise diagnostics during training.
- Document the method in `docs/algo/otb.md` and ship an executable
example at `examples/otb_trainer/run_qwen2_5-7b.sh`.
## Usage
- Enable OTB by overriding config keys (OmegaConf overlay):
```yaml
algorithm:
  adv_estimator: optimal_token_baseline
actor_rollout_ref:
  actor:
    calculate_sum_pi_squared: true
    sum_pi_squared_checkpointing: false  # optional for long contexts
  rollout:
    n: 8
```
- Run the example script (adjust dataset paths & WandB project as
needed):
```bash
bash examples/otb_trainer/run_qwen2_5-7b.sh
```
- Monitor the new variance proxies in trainer logs:
`variance_proxy/proxy1_signal_strength`, `proxy2_total_power`,
`proxy3_pure_noise`.
## Key notes
- `actor.calculate_sum_pi_squared` requires
`actor_rollout_ref.model.use_fused_kernels=False`; the fused kernels do
not surface logits, which OTB needs, so OTB cannot run with them yet.
- Group sampling is mandatory (`rollout.n > 1`); with single-rollout
batches OTB collapses to vanilla returns.
---
UPDATE (@tongyx361): `compute_sum_pi_squared` was renamed to
`calculate_sum_pi_squared` for consistency with `calculate_entropy`.
---------
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Shawn/Yuxuan Tong <[email protected]>
Optimal Token Baseline (OTB) is a dynamic token-level baseline for variance reduction. It weights updates by "Realized Energy": essentially, how much uncertainty has accumulated up to that specific token. It downweights the noisy parts and trusts the clear signals. Read the [Optimal Token Baseline blog](https://richardli.xyz/optimal-token-baseline) for more details.
## The method: OTB
- OTB builds a _dynamic_ baseline that adapts to each token by tracking the “Realized Energy”—the uncertainty that has accumulated up to that token. It downweights the noisy parts and trusts the clear signals.
- Unlike standard group means, which ineffectively average over padded `EOS` tokens, OTB handles variable-length responses naturally by computing baselines only over valid tokens.
## Logit-Gradient Proxy
- Computing true uncertainty requires expensive backward passes (calculating gradient norms per token). Instead, OTB introduces the **Logit-Gradient Proxy**: the realized energy can be estimated entirely from forward probabilities.
- This means zero extra backward calls and effectively no additional runtime overhead.
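
As a minimal sketch of the quantity behind `calculate_sum_pi_squared_from_logits` (illustrative only, assuming PyTorch; the repo's implementation may differ), Σπ² per token can be computed from forward logits alone, in log-space for numerical stability:

```python
import torch
import torch.nn.functional as F

def sum_pi_squared_from_logits(logits: torch.Tensor) -> torch.Tensor:
    """Per-token sum of squared probabilities, sum_a pi(a)^2.

    logits: (batch, seq_len, vocab_size) -> returns (batch, seq_len).
    Uses sum(pi^2) = exp(logsumexp(2 * log_pi)) to stay numerically stable.
    """
    log_probs = F.log_softmax(logits, dim=-1)
    return torch.exp(torch.logsumexp(2.0 * log_probs, dim=-1))
```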
## Mechanics at a glance
For each prompt group of size `N`, OTB computes rewards-to-go `G_t` and cumulative variance weights `W_t`. The optimal baseline per token is the variance-weighted group mean

`B*_t = Σ_i W_t^(i) · G_t^(i) / Σ_i W_t^(i)`, with `i` running over the `N` rollouts in the group.

The final advantage is `(G_t - B*_t) · mask_t`, so padding tokens stay at zero.
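
An illustrative sketch of the formula above (not verl's actual implementation; the `(N, T)` tensor shapes are an assumption):

```python
import torch

def otb_advantage(G: torch.Tensor, W: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """G, W, mask: (N, T) rewards-to-go, cumulative variance weights,
    and valid-token mask for one prompt group of size N."""
    W = W * mask                                                   # padding carries no weight
    baseline = (W * G).sum(dim=0) / W.sum(dim=0).clamp_min(1e-8)   # B*_t, shape (T,)
    return (G - baseline) * mask                                   # (G_t - B*_t) * mask_t
```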
## Integration in VERL
- `AdvantageEstimator.OPTIMAL_TOKEN_BASELINE` registers `compute_optimal_token_baseline_advantage`, invoked whenever `algorithm.adv_estimator` is set to `optimal_token_baseline`.
- `ActorRolloutRefWorker.compute_log_prob` emits an additional tensor `sum_pi_squared` (Σπ² per token) when `actor.calculate_sum_pi_squared=True`. This requires disabling fused log-prob kernels, because they do not surface logits.
- Trainers assert that `sum_pi_squared` exists, regroup trajectories by `non_tensor_batch["uid"]`, and run the OTB calculation. If rollout importance sampling is active, they rescale the weights by `rollout_is_weights**2` before aggregating.
- In Ulysses sequence-parallel setups, the actor gathers, unpads, and returns Σπ² in the same way it handles log-probabilities, so OTB supports sharded sequence-parallel models out of the box.
- `sum_pi_squared_checkpointing` is available to trade compute for memory when Σπ² tensors become large (e.g., lengthy chain-of-thought reasoning).
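
As a hypothetical illustration of that trade-off (not the flag's actual implementation), activation checkpointing could recompute the Σπ² forward pass during backward instead of caching its vocab-sized intermediates, reusing the `sum_pi_squared_from_logits` sketch above:

```python
import torch
from torch.utils.checkpoint import checkpoint

def sum_pi_squared_checkpointed(logits: torch.Tensor) -> torch.Tensor:
    # Recompute sum_pi_squared_from_logits during the backward pass rather
    # than storing its (batch, seq_len, vocab_size) activations.
    return checkpoint(sum_pi_squared_from_logits, logits, use_reentrant=False)
```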