
Commit f245b40

Merge branch 'main' into dev/fix_microbatch_loss_scale

2 parents: a914078 + 1120aed

19 files changed: +193 -29 lines

Binary asset changes (previews not shown): -147 KB, -15.8 KB, 464 KB, -50.4 KB

docs/sphinx_doc/source/tutorial/example_search_email.md

Lines changed: 1 addition & 0 deletions
@@ -48,5 +48,6 @@ The results are shown in the following figure (the accuracy ranges from -0.1 to

 ![](../../assets/email_rollout_accuracy.png)

+![](../../assets/email_reward_mean.png)

 ![](../../assets/email_eval_accuracy.png)

docs/sphinx_doc/source/tutorial/trinity_installation.md

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@ For installing Trinity-RFT, you have three options: from source (recommended), v
 Before installing, ensure your system meets the following requirements:

 - **Python**: Version 3.10 to 3.12 (inclusive)
-- **CUDA**: Version 12.4 to 12.8 (inclusive)
+- **CUDA**: Version >= 12.6
 - **GPUs**: At least 2 GPUs

 ---
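For readers who want to sanity-check the updated requirements from an existing PyTorch environment, a minimal sketch is shown below. It is illustrative only and not part of the Trinity-RFT installer; note that `torch.version.cuda` reports the CUDA version the PyTorch build was compiled against, not the system toolkit.

```python
import sys

import torch

# Illustrative check mirroring the updated requirements
# (Python 3.10-3.12, CUDA >= 12.6, at least 2 GPUs). Not part of Trinity-RFT.
assert (3, 10) <= sys.version_info[:2] <= (3, 12), "Python 3.10-3.12 required"

cuda = torch.version.cuda  # CUDA version this PyTorch build was compiled with, e.g. "12.6"
assert cuda is not None and tuple(map(int, cuda.split("."))) >= (12, 6), "CUDA >= 12.6 required"

assert torch.cuda.device_count() >= 2, "at least 2 GPUs required"
print(f"Python {sys.version.split()[0]}, CUDA {cuda}, {torch.cuda.device_count()} GPUs: OK")
```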

docs/sphinx_doc/source_zh/tutorial/example_search_email.md

Lines changed: 2 additions & 0 deletions
@@ -44,4 +44,6 @@ trinity run --config examples/grpo_email_search/email_search.yaml

 ![](../../assets/email_rollout_accuracy.png)

+![](../../assets/email_reward_mean.png)
+
 ![](../../assets/email_eval_accuracy.png)

docs/sphinx_doc/source_zh/tutorial/trinity_installation.md

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 Before installing, please ensure your system meets the following requirements:

 - **Python**: 3.10 to 3.12 (inclusive)
-- **CUDA**: 12.4 to 12.8 (inclusive)
+- **CUDA**: >= 12.6
 - **GPU**: At least 2 GPUs

 ---

examples/grpo_email_search/email_search.yaml

Lines changed: 24 additions & 2 deletions
@@ -6,6 +6,20 @@ algorithm:
   repeat_times: 8
   optimizer:
     lr: 1e-6
+  policy_loss_fn: "rec"
+  policy_loss_fn_args:
+    epsilon_low: 0.2
+    epsilon_high: 0.2
+    clip_mode: "one-side"
+    weight: "none"
+    temp: 1.0
+    regularizer: "none"
+    regularizer_coef: 0.0
+  kl_loss_fn: 'k2'
+  kl_loss_fn_args:
+    kl_coef: 0.0
+  advantage_fn_args:
+    std_cal_level: 'batch'
 model:
   model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-4B-Instruct-2507}
   max_response_tokens: 4096
@@ -15,8 +29,8 @@ cluster:
   gpu_per_node: 8
 buffer:
   total_epochs: 1
-  batch_size: 16
-  train_batch_size: 640 # 16*8*5
+  batch_size: 64
+  train_batch_size: 2560 # 64*8*5
   explorer_input:
     taskset:
       name: enron_train
@@ -56,6 +70,12 @@ buffer:
      storage_type: queue
      replay_buffer:
        enable: true
+       # reuse_cooldown_time is None
+       priority_fn: 'decay_limit_randomization'
+       priority_fn_args:
+         decay: 2.0
+         use_count_limit: 3
+         sigma: 2.0
 explorer:
   eval_interval: 10
   max_repeat_times_per_runner: 1
@@ -93,3 +113,5 @@ trainer:
   use_dynamic_bsz: true
   max_token_len_per_gpu: 16384
   ulysses_sequence_parallel_size: 1
+monitor:
+  monitor_type: wandb
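The new batch sizes keep the relationship spelled out in the inline comment: one explorer batch of 64 tasks, each rolled out 8 times (`repeat_times`), with what the comment implies is 5 steps per rollout, yields 2560 training experiences. A minimal sketch of that arithmetic follows; the factor of 5 is taken only from the `# 64*8*5` comment, not from any documented setting.

```python
# Batch-size bookkeeping implied by the inline comments in email_search.yaml.
# The steps-per-rollout factor of 5 is an assumption read off the "# 64*8*5" comment.
batch_size = 64         # tasks explored per batch
repeat_times = 8        # rollouts per task (algorithm.repeat_times)
steps_per_rollout = 5   # assumed from the comment

train_batch_size = batch_size * repeat_times * steps_per_rollout
assert train_batch_size == 2560

# The previous values followed the same rule: 16 * 8 * 5 == 640.
assert 16 * repeat_times * steps_per_rollout == 640
```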

tests/algorithm/advantage_fn_test.py

Lines changed: 68 additions & 0 deletions
@@ -326,3 +326,71 @@ def test_batch_level_step_wise_grpo_advantage(self):
         expected_advantages = expected_advantage_value * target_exp.action_mask
         self.assertTrue(torch.allclose(target_exp.advantages, expected_advantages, atol=1e-6))
         self.assertTrue(torch.allclose(target_exp.returns, expected_advantages, atol=1e-6))
+
+    def test_step_wise_grpo_with_std_threshold(self):
+        advantage_fn_cls = ADVANTAGE_FN.get("step_wise_grpo")
+        self.assertIsNotNone(advantage_fn_cls)
+        advantage_fn = advantage_fn_cls(epsilon=1e-6, std_threshold=0.0001)
+        repeat_times = 5
+        step_num = 4
+
+        # Create experiences with mixed reward patterns:
+        # - task 0: all runs have same reward (0.5) -> should be filtered
+        # - task 1: all runs have same reward (1.0) -> should be filtered
+        # - task 2: runs have different rewards (0, 1, 2, 3, 4) -> should NOT be filtered
+        exps = []
+
+        # Task 0: constant reward 0.5
+        for k in range(step_num):
+            for i in range(repeat_times):
+                exps.append(
+                    Experience(
+                        eid=EID(batch=0, task=0, run=i, step=k),
+                        tokens=torch.zeros(5),
+                        prompt_length=2,
+                        reward=0.5,
+                    )
+                )
+
+        # Task 1: constant reward 1.0
+        for k in range(step_num):
+            for i in range(repeat_times):
+                exps.append(
+                    Experience(
+                        eid=EID(batch=0, task=1, run=i, step=k),
+                        tokens=torch.zeros(5),
+                        prompt_length=2,
+                        reward=1.0,
+                    )
+                )
+
+        # Task 2: varying rewards
+        for k in range(step_num):
+            for i in range(repeat_times):
+                exps.append(
+                    Experience(
+                        eid=EID(batch=0, task=2, run=i, step=k),
+                        tokens=torch.zeros(5),
+                        prompt_length=2,
+                        reward=float(i),
+                    )
+                )
+
+        processed_exps, metrics = advantage_fn(exps)
+
+        # Only task 2 should remain (task 0 and task 1 filtered due to zero std)
+        expected_remaining = repeat_times * step_num  # task 2 only
+        expected_filtered = 2 * repeat_times * step_num  # task 0 and task 1
+
+        self.assertEqual(len(processed_exps), expected_remaining)
+        self.assertIn("filtered_count", metrics)
+        self.assertEqual(metrics["filtered_count"], expected_filtered)
+
+        # Verify skipped group ratio: 2 out of 3 tasks were skipped
+        self.assertIn("skipped_group_ratio", metrics)
+        expected_ratio = 2.0 / 3.0  # task 0 and task 1 skipped out of 3 total tasks
+        self.assertAlmostEqual(metrics["skipped_group_ratio"], expected_ratio, places=6)
+
+        # Verify that all remaining experiences are from task 2
+        for exp in processed_exps:
+            self.assertEqual(exp.eid.task, 2)
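The assertions pin down what `std_threshold` is meant to do: experiences are grouped per task, groups whose rewards have (near-)zero standard deviation are dropped before advantages are computed, and the metrics report the number of dropped experiences (`filtered_count`) and the fraction of skipped groups (`skipped_group_ratio`). Below is a self-contained sketch of just that filtering step, assuming a plain list-of-`Experience` interface; it is not the Trinity-RFT implementation, which also computes the step-wise GRPO advantages themselves.

```python
from collections import defaultdict
from statistics import pstdev


def filter_low_std_groups(exps, std_threshold=1e-4):
    """Drop per-task groups with (near-)constant rewards.

    Illustrative sketch of the behaviour asserted by the test above,
    not the Trinity-RFT implementation.
    """
    groups = defaultdict(list)
    for exp in exps:
        groups[(exp.eid.batch, exp.eid.task)].append(exp)

    kept, filtered_count, skipped_groups = [], 0, 0
    for group in groups.values():
        if pstdev(e.reward for e in group) <= std_threshold:
            skipped_groups += 1
            filtered_count += len(group)  # 20 per constant-reward task in the test
        else:
            kept.extend(group)

    metrics = {
        "filtered_count": filtered_count,                      # 40 for tasks 0 and 1 above
        "skipped_group_ratio": skipped_groups / len(groups),   # 2/3 above
    }
    return kept, metrics
```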
