Commit 67a127b

jiawei415, gemini-code-assist[bot], and tongyx361 authored
[algo] feat: add optimal token baseline and variance proxy (verl-project#4678)
# Optimal Token Baseline

## Main feature

- Register `AdvantageEstimator.OPTIMAL_TOKEN_BASELINE`.
- Extend the DP actor to emit `sum_pi_squared`, expose `calculate_sum_pi_squared` and checkpointing toggles across configs, and add a reusable `calculate_sum_pi_squared_from_logits` function.
- Introduce `compute_variance_proxy_metrics` to surface signal/total power/noise diagnostics during training.
- Document the method in `docs/algo/otb.md` and ship an executable example at `examples/otb_trainer/run_qwen2_5-7b.sh`.

## Usage

- Enable OTB by overriding config keys (OmegaConf overlay):

  ```yaml
  algorithm.adv_estimator: optimal_token_baseline
  actor_rollout_ref:
    actor:
      calculate_sum_pi_squared: true
      sum_pi_squared_checkpointing: false  # optional for long contexts
    rollout:
      n: 8
  ```

- Run the example script (adjust dataset paths and WandB project as needed):

  ```bash
  bash examples/otb_trainer/run_qwen2_5-7b.sh
  ```

- Monitor the new variance proxies in trainer logs: `variance_proxy/proxy1_signal_strength`, `variance_proxy/proxy2_total_power`, `variance_proxy/proxy3_pure_noise`.

## Key notes

- `actor.calculate_sum_pi_squared` requires `actor_rollout_ref.model.use_fused_kernels=False`; fused kernels must surface logits before OTB can run there.
- Group sampling is mandatory (`rollout.n > 1`); with single-rollout batches OTB collapses to vanilla returns.

---

UPDATE (@tongyx361): `compute_sum_pi_squared` is changed to `calculate_sum_pi_squared` for consistency with `calculate_entropy`.

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Shawn/Yuxuan Tong <[email protected]>
1 parent f747011 commit 67a127b

File tree

13 files changed (+685, -39 lines)


docs/algo/otb.md

Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
# Optimal Token Baseline (OTB)

Last updated: 12/25/2025.

Optimal Token Baseline (OTB) is a dynamic token-level baseline for variance reduction. It weights updates based on "Realized Energy"—essentially, how much uncertainty has accumulated up to that specific token. It downweights the noisy parts and trusts the clear signals. Read the [Optimal Token Baseline blog](https://richardli.xyz/optimal-token-baseline) for more details.

## The method: OTB

- OTB builds a _dynamic_ baseline that adapts to each token by tracking the "Realized Energy"—the uncertainty that has accumulated up to that token. It downweights the noisy parts and trusts the clear signals.
- Unlike standard group means, which ineffectively include padding `EOS` tokens in the average, OTB handles padding naturally by computing baselines only over valid tokens.
## Logit-Gradient Proxy
13+
14+
- Computing true uncertainty requires expensive backward passes (calculating gradient norms per token). Instead, OTB introduces the **Logit-Gradient Proxy**: the realized energy can be estimated entirely from forward probabilities.
15+
- This means zero extra backward calls and effectively no additional runtime overhead.
16+
17+
## Mechanics at a glance
18+
19+
For each prompt group of size `N`, OTB computes rewards-to-go `G_t` and cumulative variance weights `W_t`. The optimal baseline per token is
20+
21+
```
22+
B*_t = (Σ_i G_t^{(i)} · W_t^{(i)}) / (Σ_i W_t^{(i)} + ε),
23+
W_t = Σ_{j=1}^t (1 - 2π_j + Σπ_j²),
24+
Σπ_j² = exp(logsumexp(2·logits_j) - 2·logsumexp(logits_j)).
25+
```
26+
27+
The final advantage is `(G_t - B*_t) · mask_t`, so padding tokens stay at zero.
28+
29+
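
A minimal PyTorch sketch of these formulas, assuming per-token rewards, sampled-token probabilities `pi`, a precomputed `sum_pi_squared` tensor, and a batch laid out as consecutive groups of `group_size` responses per prompt. The names here are illustrative; the actual implementation is `compute_optimal_token_baseline_advantage` in verl.

```python
import torch


def sum_pi_squared_from_logits(logits: torch.Tensor) -> torch.Tensor:
    """Σπ² per token from logits of shape (B, T, V):
    exp(logsumexp(2·logits) - 2·logsumexp(logits))."""
    return torch.exp(
        torch.logsumexp(2.0 * logits, dim=-1) - 2.0 * torch.logsumexp(logits, dim=-1)
    )


def otb_advantage(token_rewards, pi, sum_pi_squared, mask, group_size, eps=1e-8):
    """OTB advantages for tensors of shape (B, T), with B a multiple of group_size."""
    B, T = token_rewards.shape
    # Rewards-to-go G_t: reverse cumulative sum over valid tokens.
    G = torch.flip(torch.cumsum(torch.flip(token_rewards * mask, dims=[1]), dim=1), dims=[1])
    # Realized energy W_t = Σ_{j<=t} (1 - 2π_j + Σπ_j²), accumulated over valid tokens.
    W = torch.cumsum((1.0 - 2.0 * pi + sum_pi_squared) * mask, dim=1)
    # Per-token optimal baseline, aggregated over the N responses of each prompt group.
    Gg, Wg = G.view(-1, group_size, T), W.view(-1, group_size, T)
    baseline = (Gg * Wg).sum(dim=1, keepdim=True) / (Wg.sum(dim=1, keepdim=True) + eps)
    # Final advantage (G_t - B*_t) · mask_t, so padding tokens stay at zero.
    return (Gg - baseline).view(B, T) * mask
```

The `view(-1, group_size, T)` grouping assumes responses for the same prompt are contiguous in the batch; the trainer instead regroups trajectories by `non_tensor_batch["uid"]`, as described in the next section.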
## Integration in VERL
30+
31+
- `AdvantageEstimator.OPTIMAL_TOKEN_BASELINE` registers `compute_optimal_token_baseline_advantage`, invoked whenever `algorithm.adv_estimator` is set to `optimal_token_baseline`.
32+
- `ActorRolloutRefWorker.compute_log_prob` emits an additional tensor `sum_pi_squared` (Σπ² per token) when `actor.calculate_sum_pi_squared=True`. This requires disabling fused log-prob kernels, because they do not surface logits.
33+
- Trainers assert `sum_pi_squared` exists, regroup trajectories by `non_tensor_batch["uid"]`, and run the OTB calculation. If rollout IS is active, they rescale the weights by `rollout_is_weights**2` before aggregating.
34+
- In Ulysses sequence-parallel setups, the actor gathers, unpads, and returns Σπ² in the same way it handles log-probabilities, so OTB supports sharded sequence-parallel models out of the box.
35+
- `sum_pi_squared_checkpointing` is available to trade compute for memory when Σπ² tensors become large (e.g., lengthy chain-of-thought reasoning).
36+
37+
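
A schematic of that trainer-side flow (regrouping by uid and the optional IS rescale), assuming the same tensors as the sketch above; the function names are illustrative rather than verl's actual code path:

```python
from collections import defaultdict


def regroup_by_uid(uids):
    """Collect batch indices that share a prompt uid (cf. non_tensor_batch["uid"])."""
    groups = defaultdict(list)
    for idx, uid in enumerate(uids):
        groups[uid].append(idx)
    return list(groups.values())


def otb_baseline_for_group(G, W, rollout_is_weights=None, eps=1e-8):
    """Per-token baseline for one uid group, G and W of shape (n, T); optionally
    rescale the realized-energy weights by squared rollout IS weights first."""
    if rollout_is_weights is not None:
        W = W * rollout_is_weights ** 2
    return (G * W).sum(dim=0, keepdim=True) / (W.sum(dim=0, keepdim=True) + eps)
```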

## Configuration checklist

- `actor_rollout_ref.actor.calculate_sum_pi_squared: true` (mandatory).
- `actor_rollout_ref.model.use_fused_kernels: false` (required until fused kernels emit logits).
- `algorithm.adv_estimator: optimal_token_baseline`.
- Group sampling (`actor_rollout_ref.rollout.n > 1`) to unlock OTB's variance reduction; with `n=1` the baseline collapses to returns.

Example OmegaConf overlay:

```yaml
algorithm:
  adv_estimator: optimal_token_baseline

actor_rollout_ref:
  actor:
    calculate_sum_pi_squared: true
    sum_pi_squared_checkpointing: false  # optional memory saver
  rollout:
    n: 8
```

## Example script

- `examples/otb_trainer/run_qwen2_5-7b.sh`.

## Gradient Variance Proxy Metrics

All gradient-variance analysis in the Optimal Token Baseline work starts from the variance identity

```
Var(ĝ) = E[||ĝ||²] - ||E[ĝ]||²,
```

which states that the variance of any stochastic gradient equals the mean squared magnitude minus the squared norm of its expectation.

For a trajectory `τ`, the policy-gradient estimator is

```
ĝ(τ) = ∇ log π_θ(τ) · A(τ),  A(τ) = R(τ) - B.
```

The logit-gradient proxy approximates the squared gradient norm without an extra backward pass:

```
||ĝ(τ)||² ≈ Ŵ(τ) · A(τ)²,
```

where `Ŵ(τ)` is the realized energy built from the logit-gradient proxy. Given a mini-batch `{τ_i}` of size `N`, we decompose its statistics into three diagnostics:

- **Signal strength (squared norm of the mean gradient)**
  ```
  S = || (1/N) · Σ ĝ(τ_i) ||²
  ```
- **Total power (signal + noise)**
  ```
  P_total = (1/N) · Σ Ŵ(τ_i) · A(τ_i)²
  ```
- **Pure noise (estimated variance of the batch mean)**
  ```
  Var_proxy = (1/(N-1)) · (P_total - S)
  ```

`verl/trainer/ppo/metric_utils.py#L306` implements these diagnostics via `compute_variance_proxy_metrics`, emitting
`variance_proxy/proxy1_signal_strength`,
`variance_proxy/proxy2_total_power`, and
`variance_proxy/proxy3_pure_noise`.

Tracking these metrics provides a forward-only, low-overhead view of gradient health for any advantage estimator that supplies `sum_pi_squared`.
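
A minimal sketch of these three quantities, assuming per-trajectory realized energy `W` and advantages `A` as 1-D tensors. `P_total` and `Var_proxy` follow the formulas above directly; for `S`, this sketch substitutes the signed scalar proxy `sqrt(W)·A` for each trajectory's gradient before averaging, which is an assumption made here for illustration; the exact estimator is `compute_variance_proxy_metrics` in `verl/trainer/ppo/metric_utils.py`.

```python
import torch


def variance_proxy_metrics(W: torch.Tensor, A: torch.Tensor) -> dict:
    """Batch diagnostics from per-trajectory realized energy W(τ_i) and advantages
    A(τ_i), both of shape (N,). S uses sqrt(W)·A as a scalar stand-in for ĝ(τ_i)."""
    n = A.shape[0]
    g_proxy = torch.sqrt(W.clamp_min(0)) * A              # scalar stand-in for ĝ(τ_i)
    signal = g_proxy.mean().pow(2)                        # S = ||(1/N)·Σ ĝ(τ_i)||² (proxy)
    total_power = (W * A.pow(2)).mean()                   # P_total = (1/N)·Σ Ŵ(τ_i)·A(τ_i)²
    pure_noise = (total_power - signal) / max(n - 1, 1)   # Var_proxy = (P_total - S)/(N-1)
    return {
        "variance_proxy/proxy1_signal_strength": signal.item(),
        "variance_proxy/proxy2_total_power": total_power.item(),
        "variance_proxy/proxy3_pure_noise": pure_noise.item(),
    }
```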

docs/index.rst

Lines changed: 1 addition & 0 deletions
@@ -80,6 +80,7 @@ verl is fast with:
    algo/gpg.md
    algo/rollout_corr.md
    algo/rollout_corr_math.md
+   algo/otb.md

 .. toctree::
    :maxdepth: 1
examples/otb_trainer/run_qwen2_5-7b.sh

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
set -x

gsm8k_train_path=$HOME/data/gsm8k/train.parquet
gsm8k_test_path=$HOME/data/gsm8k/test.parquet
math_train_path=$HOME/data/math/train.parquet
math_test_path=$HOME/data/math/test.parquet

train_files="['$gsm8k_train_path', '$math_train_path']"
test_files="['$gsm8k_test_path', '$math_test_path']"

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=optimal_token_baseline \
    data.train_files="$train_files" \
    data.val_files="$test_files" \
    data.train_batch_size=128 \
    data.max_prompt_length=1024 \
    data.max_response_length=2048 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    actor_rollout_ref.model.path=Qwen/Qwen2.5-7B \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.model.use_fused_kernels=False \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.use_dynamic_bsz=False \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.ppo_mini_batch_size=128 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \
    actor_rollout_ref.actor.use_kl_loss=False \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.calculate_sum_pi_squared=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.75 \
    actor_rollout_ref.rollout.n=8 \
    trainer.logger='["console","wandb"]' \
    trainer.project_name='verl_grpo_example_gsm8k' \
    trainer.experiment_name='qwen2_5-7b-otb' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=-1 \
    trainer.test_freq=5 \
    trainer.total_epochs=15 $@

tests/workers/actor/test_special_dp_actor.py

Lines changed: 18 additions & 13 deletions
@@ -174,7 +174,9 @@ def test_compute_log_prob(self):
         """Test compute_log_prob method"""
         data = self._create_test_data_for_compute_log_prob()

-        log_probs, entropies = self.actor.compute_log_prob(data, calculate_entropy=True)
+        outputs = self.actor.compute_log_prob(data, calculate_entropy=True)
+        log_probs = outputs["log_probs"]
+        entropys = outputs["entropys"]

         batch_size = data.batch["responses"].shape[0]
         response_length = data.batch["responses"].shape[1]
@@ -183,25 +185,26 @@ def test_compute_log_prob(self):
         self.assertEqual(log_probs.shape, (batch_size, response_length))
         self.assertTrue(torch.all(torch.isfinite(log_probs)))

-        self.assertIsInstance(entropies, torch.Tensor)
-        self.assertEqual(entropies.shape, (batch_size, response_length))
-        self.assertTrue(torch.all(torch.isfinite(entropies)))
-        self.assertTrue(torch.all(entropies >= 0))  # Entropy should be non-negative
+        self.assertIsInstance(entropys, torch.Tensor)
+        self.assertEqual(entropys.shape, (batch_size, response_length))
+        self.assertTrue(torch.all(torch.isfinite(entropys)))
+        self.assertTrue(torch.all(entropys >= 0))  # Entropy should be non-negative

     def test_compute_log_prob_without_entropy(self):
         """Test compute_log_prob method without entropy calculation"""
         data = self._create_test_data_for_compute_log_prob()

-        log_probs, entropies = self.actor.compute_log_prob(data, calculate_entropy=False)
+        outputs = self.actor.compute_log_prob(data, calculate_entropy=False)
+        log_probs = outputs["log_probs"]
+        entropys = outputs.get("entropys", None)

         batch_size = data.batch["responses"].shape[0]
         response_length = data.batch["responses"].shape[1]

         self.assertIsInstance(log_probs, torch.Tensor)
         self.assertEqual(log_probs.shape, (batch_size, response_length))
         self.assertTrue(torch.all(torch.isfinite(log_probs)))
-
-        self.assertIsNone(entropies)
+        self.assertIsNone(entropys)

     def test_update_policy(self):
         """Test update_policy method"""
@@ -259,7 +262,9 @@ def test_dataparallelppoactor_with_qwen3_model(self):
         qwen_actor = DataParallelPPOActor(config=self.config, actor_module=qwen_model, actor_optimizer=qwen_optimizer)

         data = self._create_test_data_for_compute_log_prob()
-        log_probs, entropies = qwen_actor.compute_log_prob(data, calculate_entropy=True)
+        outputs = qwen_actor.compute_log_prob(data, calculate_entropy=True)
+        log_probs = outputs["log_probs"]
+        entropys = outputs["entropys"]

         batch_size = data.batch["responses"].shape[0]
         response_length = data.batch["responses"].shape[1]
@@ -268,10 +273,10 @@ def test_dataparallelppoactor_with_qwen3_model(self):
         self.assertEqual(log_probs.shape, (batch_size, response_length))
         self.assertTrue(torch.all(torch.isfinite(log_probs)))

-        self.assertIsInstance(entropies, torch.Tensor)
-        self.assertEqual(entropies.shape, (batch_size, response_length))
-        self.assertTrue(torch.all(torch.isfinite(entropies)))
-        self.assertTrue(torch.all(entropies >= 0))
+        self.assertIsInstance(entropys, torch.Tensor)
+        self.assertEqual(entropys.shape, (batch_size, response_length))
+        self.assertTrue(torch.all(torch.isfinite(entropys)))
+        self.assertTrue(torch.all(entropys >= 0))

         policy_data = self._create_test_data_for_update_policy()
         metrics = qwen_actor.update_policy(policy_data)

verl/trainer/config/_generated_ppo_trainer.yaml

Lines changed: 2 additions & 0 deletions
@@ -123,6 +123,8 @@ actor_rollout_ref:
     entropy_from_logits_with_chunking: false
     entropy_checkpointing: false
     use_remove_padding: ${oc.select:actor_rollout_ref.model.use_remove_padding,false}
+    calculate_sum_pi_squared: false
+    sum_pi_squared_checkpointing: false
   ref:
     rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}
     strategy: ${actor_rollout_ref.actor.strategy}

verl/trainer/config/actor/dp_actor.yaml

Lines changed: 8 additions & 1 deletion
@@ -40,4 +40,11 @@ entropy_from_logits_with_chunking: False
 entropy_checkpointing: False

 # Whether to remove padding tokens in inputs during training
-use_remove_padding: ${oc.select:actor_rollout_ref.model.use_remove_padding,false}
+use_remove_padding: ${oc.select:actor_rollout_ref.model.use_remove_padding,false}
+
+# This computes Σπ² needed for the Logit-Gradient Norm proxy W(τ) = Σ_t[1 - 2π_t + Σπ²]
+# c.f. https://yingru.notion.site/The-Optimal-Token-Baseline-399211a558b782cfa936014c0d42dfb8
+calculate_sum_pi_squared: False
+
+# Enable gradient checkpointing for sum_pi_squared computation (saves memory)
+sum_pi_squared_checkpointing: False
