Skip to content

Commit cea6f0f

Browse files
committed
Merge branch 'main' into doc/add_exp_replay_example
2 parents fca0a70 + cd82bfc commit cea6f0f

File tree

21 files changed

+180
-91
lines changed

21 files changed

+180
-91
lines changed

docs/sphinx_doc/source/tutorial/example_mix_algo.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ class MixSampleStrategy(SampleStrategy):
105105

106106
async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
107107
metrics = {}
108-
with Timer(metrics, "read_time"):
108+
with Timer(metrics, "time/read_experience"):
109109
usual_exp_list = await self.usual_exp_buffer.read_async()
110110
for exp in usual_exp_list:
111111
if exp.info is None:
@@ -131,7 +131,7 @@ class MixSampleStrategy(SampleStrategy):
131131
exp_list = usual_exp_list + expert_exp_list
132132
repr_samples = representative_sample(exp_list)
133133

134-
with Timer(metrics, "gather_time"):
134+
with Timer(metrics, "time/gather_experience"):
135135
exps = Experiences.gather_experiences(
136136
experiences=exp_list,
137137
pad_token_id=self.pad_token_id, # type: ignore [arg-type]

docs/sphinx_doc/source_zh/tutorial/example_mix_algo.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ class MixSampleStrategy(SampleStrategy):
9797

9898
async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
9999
metrics = {}
100-
with Timer(metrics, "read_time"):
100+
with Timer(metrics, "time/read_experience"):
101101
usual_exp_list = await self.usual_exp_buffer.read_async()
102102
for exp in usual_exp_list:
103103
if exp.info is None:
@@ -123,7 +123,7 @@ class MixSampleStrategy(SampleStrategy):
123123
exp_list = usual_exp_list + expert_exp_list
124124
repr_samples = representative_sample(exp_list)
125125

126-
with Timer(metrics, "gather_time"):
126+
with Timer(metrics, "time/gather_experience"):
127127
exps = Experiences.gather_experiences(
128128
experiences=exp_list,
129129
pad_token_id=self.pad_token_id, # type: ignore [arg-type]

tests/algorithm/advantage_fn_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,8 @@ def test_grpo_reward_std(self):
107107

108108
exps, metrics = advantage_fn(exps)
109109
self.assertEqual(len(exps), 0)
110-
self.assertIn("group_advantages/skipped_count/mean", metrics)
111-
self.assertEqual(metrics["group_advantages/skipped_count/mean"], 5)
110+
self.assertIn("filtered_count", metrics)
111+
self.assertEqual(metrics["filtered_count"], 15)
112112

113113
def test_grpo_correct_bias(self):
114114
advantage_fn_cls = ADVANTAGE_FN.get("grpo")

tests/buffer/experience_pipeline_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ async def test_experience_pipeline(self):
7373
experiences = get_experiences(task_num=task_num, repeat_times=repeat_times)
7474
metrics = await pipeline.process.remote(experiences)
7575
self.assertEqual(
76-
metrics["pipeline/experience_count"], task_num * (repeat_times - 1)
76+
metrics["experience_pipeline/experience_count"], task_num * (repeat_times - 1)
7777
) # first experience of each task will be filtered out by the reward filter
7878

7979
# tests

tests/buffer/queue_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ async def test_priority_queue_reuse_count_control(self):
326326
path=BUFFER_FILE_PATH,
327327
replay_buffer=ReplayBufferConfig(
328328
enable=True,
329-
priority_fn="linear_decay_use_count_control_randomization",
329+
priority_fn="decay_limit_randomization",
330330
reuse_cooldown_time=0.5,
331331
priority_fn_args={"decay": 1.2, "use_count_limit": 2, "sigma": 0.0},
332332
),

tests/explorer/explorer_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,8 @@ def test_explorer(self):
108108
eval_metrics = parser.metric_list("eval")
109109
self.assertTrue(len(eval_metrics) == 0)
110110
self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4)
111-
self.assertTrue(parser.metric_exist("pipeline/experience_count"))
112-
experience_counts = parser.metric_values("pipeline/experience_count")
111+
self.assertTrue(parser.metric_exist("experience_pipeline/experience_count"))
112+
experience_counts = parser.metric_values("experience_pipeline/experience_count")
113113
self.assertTrue(len(experience_counts) == 4)
114114
for count in experience_counts:
115115
self.assertTrue(count >= 0)

tests/trainer/trainer_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -819,7 +819,7 @@ def test_trainer(self):
819819
self.assertTrue(len(rollout_metrics) > 0)
820820
self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4)
821821
self.assertEqual(
822-
parser.metric_values("pipeline/experience_count")[1], 16
822+
parser.metric_values("experience_pipeline/experience_count")[1], 16
823823
) # 16 rft experiences
824824
# test actor metrics
825825
actor_metrics = parser.metric_list("actor")

trinity/algorithm/advantage_fn/advantage_fn.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,7 @@ def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
7676
for group_id, group_exps in exp_groups.items():
7777
group_exps, group_metrics = self.calculate_group_advantage(group_id, group_exps)
7878
metric_list.append(group_metrics)
79-
try:
80-
metrics = gather_metrics(metric_list, "group_advantages")
81-
except ValueError:
82-
metrics = {} # empty metric list causes ValueError, ignore it
79+
metrics = gather_metrics(metric_list, "group_advantages")
8380
exps = [exp for group in exp_groups.values() for exp in group] # Flatten the list
8481
return exps, metrics
8582

trinity/algorithm/advantage_fn/grpo_advantage.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -214,11 +214,11 @@ def process(self, exps):
214214
group_id, group_exps, precomputed_std=precomputed_std
215215
)
216216
metric_list.append(group_metrics)
217-
try:
218-
# TODO: sum skipped count
219-
metrics = gather_metrics(metric_list, "group_advantages")
220-
except ValueError:
221-
metrics = {} # empty metric list causes ValueError, ignore it
217+
218+
# Update the filtered_count metric
219+
filtered_count = sum(metric.pop("skipped_count", 0) for metric in metric_list)
220+
metrics = gather_metrics(metric_list, "group_advantages")
221+
metrics["filtered_count"] = filtered_count
222222
if self.duplicate_experiences and self.std_threshold is not None:
223223
exps = self._duplicate_experiences(exp_groups)
224224
else:

trinity/algorithm/advantage_fn/multi_step_grpo_advantage.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -142,11 +142,8 @@ def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
142142
cnt += len(exps)
143143
result_exps.extend(exps)
144144

145-
try:
146-
metrics = gather_metrics(metric_list, "group_advantages")
147-
metrics["experience_count"] = cnt
148-
except ValueError:
149-
metrics = {} # empty metric list causes ValueError, ignore it
145+
metrics = gather_metrics(metric_list, "group_advantages")
146+
metrics["experience_count"] = cnt
150147
return result_exps, metrics
151148

152149
def __call__(self, exps, **kwargs):

0 commit comments

Comments
 (0)