Commit 3736c74

Update yaml and doc; fix bugs
1 parent 7ad62a4 commit 3736c74

6 files changed: +71 -48 lines changed

examples/grpo_gsm8k_trainable_ruler/README.md

Lines changed: 15 additions & 6 deletions
@@ -4,22 +4,31 @@
 Ref: ART's RULER; Kimi-k2.
 
 
-Simulate a scenario where only a fraction of tasks have ground-truth answers for rule-based reward.
-
+Simulate a scenario where only a fraction (`PROBABILITY_GROUND_TRUTH_AVAILABLE = 0.2`) of tasks have ground-truth answers.
+Two RL objectives are optimized jointly: one for solution generation, the other for RULER-reward generation.
 
 
 ## Configurations and Metrics
 
-The config files are located in [`gsm8k_ruler.yaml`](gsm8k_ruler.yaml) and [`train_gsm8k_ruler.yaml`](train_gsm8k_ruler.yaml).
+The config files are located in [`gsm8k_ruler.yaml`](gsm8k_ruler.yaml) and [`train_gsm8k_trainable_ruler.yaml`](train_gsm8k_trainable_ruler.yaml).
 
 Some key configs in this example are:
 
-(TODO)
+* `default_workflow_type`: set to `math_trainable_ruler_workflow`
+* `std_threshold` for the GRPO advantage: set to a small value to filter out groups of experiences with identical rewards (e.g., when RULER fails to return valid scores, all scores are set to zero)
+* `sync_style`: use `dynamic_by_explorer`, due to the filtering of experiences
+* `train_batch_size`: set to 960; note that one explore step can generate more than 96 * 8 = 768 experiences
+* `lr`: set to a small value (2e-6) for stability, as rewards can be noisy
+
 
 
 Some important metrics to pay attention to are:
 
-(TODO)
+* `reward`: the reward calculated by rule or by RULER
+* `gold_reward`: the sum of `accuracy_reward` and `format_reward`, computed by the rule-based reward function using the ground truth
+* `judge_success`: whether RULER successfully returns a valid score (a coarse estimate, since it mixes the two types of experiences)
+* `reward_for_judger`: the reward for the LLM acting as the RULER reward model, calculated as the mean absolute error (MAE) distance from the gold scores
+* `eval_accuracy`: accuracy on the evaluation set (the ultimate metric for RL success)
 
 
 ## Results
@@ -32,4 +41,4 @@ Compare with baseline: previous RULER workflow with Qwen2.5-1.5B-Instruct as LLM
 
 ## Potential improvements
 
-balance number of samples / loss weights for generation vs RULER
+balance the number of samples / loss weights for generation vs. for RULER
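
A note on the `std_threshold` config above: in GRPO, advantages are computed within each group of rollouts as (reward − group mean) / group std, so a group whose rewards are all identical (for example, all zero after RULER fails to return valid scores) carries no learning signal and is filtered out. A minimal sketch of this filtering, with hypothetical helper names rather than Trinity-RFT's actual implementation:

```python
import numpy as np

def grpo_group_advantages(rewards, std_threshold=1e-6, eps=1e-6):
    """Illustrative GRPO-style group advantages with std filtering.

    Returns None when the group's reward std falls below std_threshold,
    e.g. when RULER failed and every reward in the group was set to zero.
    """
    rewards = np.asarray(rewards, dtype=np.float64)
    std = rewards.std()
    if std < std_threshold:
        return None  # drop the whole group: no gradient signal here
    return (rewards - rewards.mean()) / (std + eps)

print(grpo_group_advantages([0.0] * 8))             # None -> group filtered out
print(grpo_group_advantages([1.0, 0.1, 0.1, 1.0]))  # non-trivial advantages
```

Because whole groups can be dropped this way, the number of experiences reaching the trainer fluctuates from step to step, which is why the configs above pair the filter with `sync_style: dynamic_by_explorer`.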

examples/grpo_gsm8k_trainable_ruler/gsm8k_ruler.yaml

Lines changed: 9 additions & 15 deletions
@@ -1,5 +1,5 @@
-project: "Trinity-RFT-gsm8k-ruler"
-name: "qwen2.5-1.5B-gsm8k-ruler"
+project: "Trinity-RFT-gsm8k-trainable-ruler"
+name: "qwen2.5-1.5B-gsm8k-trainable-ruler"
 checkpoint_root_dir: /PATH/TO/CHECKPOINT/
 algorithm:
   algorithm_type: grpo
@@ -8,14 +8,16 @@ algorithm:
   repeat_times: 8
 model:
   model_path: /PATH/TO/MODEL/
-  max_response_tokens: 1024
-  max_model_len: 1280
+  max_prompt_tokens: 12288
+  max_response_tokens: 12288
+  max_model_len: 16000  # slightly smaller than ppo_max_token_len_per_gpu (16384)
 cluster:
   node_num: 1
   gpu_per_node: 8
 buffer:
   total_epochs: 1
   batch_size: 96
+  train_batch_size: 960
   explorer_input:
     taskset:
       name: gsm8k
@@ -37,7 +39,7 @@ buffer:
       format:
         prompt_key: 'question'
        response_key: 'answer'
-      default_workflow_type: 'math_ruler_workflow'
+      default_workflow_type: 'math_trainable_ruler_workflow'
   trainer_input:
     experience_buffer:
       name: gsm8k_buffer
@@ -47,26 +49,18 @@ explorer:
   runner_num: 32
   rollout_model:
     engine_type: vllm_async
-    engine_num: 2
+    engine_num: 4
     tensor_parallel_size: 1
     enable_prefix_caching: false
     enforce_eager: true
     dtype: bfloat16
     seed: 42
-  auxiliary_models:
-    - model_path: /PATH/TO/Qwen2.5-32B-Instruct
-      engine_num: 1
-      tensor_parallel_size: 2
-      enable_thinking: false
-      max_prompt_tokens: 12288
-      max_response_tokens: 12288
-      max_model_len: 16384
 synchronizer:
   sync_style: dynamic_by_explorer
   sync_method: 'nccl'
   sync_interval: 5
   sync_timeout: 3600
 trainer:
   trainer_type: 'verl'
-  trainer_config_path: 'examples/grpo_gsm8k_ruler/train_gsm8k_ruler.yaml'
+  trainer_config_path: 'examples/grpo_gsm8k_trainable_ruler/train_gsm8k_trainable_ruler.yaml'
   save_interval: 100
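
As a rough sanity check on `batch_size: 96` together with `train_batch_size: 960`: one explore step yields 96 × 8 = 768 solution experiences, and tasks for which ground truth is treated as available additionally yield RULER-judging experiences (assumed below to be `repeat_times` per such task, following the workflow's comments). This is back-of-the-envelope arithmetic only, not measured numbers:

```python
# Illustrative arithmetic for the batch sizes in this config (assumption:
# each ground-truth task contributes repeat_times extra RULER experiences).
batch_size = 96       # tasks per explore step
repeat_times = 8      # rollouts per task
p_gt = 0.2            # PROBABILITY_GROUND_TRUTH_AVAILABLE

solution_exps = batch_size * repeat_times           # 768
avg_ruler_exps = batch_size * p_gt * repeat_times   # 153.6 on average
max_total = 2 * solution_exps                       # 1536 if every task had ground truth

print(solution_exps, solution_exps + avg_ruler_exps, max_total)  # 768 921.6 1536
```

Under these assumptions, `train_batch_size: 960` sits slightly above the expected per-step yield.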

examples/grpo_gsm8k_trainable_ruler/train_gsm8k_ruler.yaml renamed to examples/grpo_gsm8k_trainable_ruler/train_gsm8k_trainable_ruler.yaml

File renamed without changes.

trinity/common/workflows/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,7 @@
 from .eval_workflow import MathEvalWorkflow
 from .math_rm_workflow import MathRMWorkflow
 from .math_ruler_workflow import MathRULERWorkflow
+from .math_trainable_ruler_workflow import MathTrainableRULERWorkflow
 from .simple_mm_workflow import SimpleMMWorkflow
 from .workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow, Task, Workflow
 
@@ -34,5 +35,6 @@
     "AgentScopeReactV2MathWorkflow",
     "EmailSearchWorkflow",
     "MathRULERWorkflow",
+    "MathTrainableRULERWorkflow",
     "SimpleMMWorkflow",
 ]
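
For context on why this single import matters: the new class registers itself under the name `math_trainable_ruler_workflow` through the `WORKFLOWS.register_module(...)` decorator (see the workflow file below), and importing the module in `__init__.py` is what executes that decorator, so the `default_workflow_type` string in the YAML can be resolved to the class. A simplified stand-in for the registry pattern, not Trinity-RFT's actual `WORKFLOWS` object:

```python
from typing import Callable, Dict, Type

class Registry:
    """Minimal name-to-class registry, for illustration only."""

    def __init__(self) -> None:
        self._modules: Dict[str, Type] = {}

    def register_module(self, name: str) -> Callable[[Type], Type]:
        def _register(cls: Type) -> Type:
            self._modules[name] = cls  # runs when the defining module is imported
            return cls
        return _register

    def get(self, name: str) -> Type:
        return self._modules[name]

WORKFLOWS = Registry()

@WORKFLOWS.register_module("math_trainable_ruler_workflow")
class MathTrainableRULERWorkflow:  # placeholder body for the sketch
    pass

# A config value like default_workflow_type: 'math_trainable_ruler_workflow'
# can then be resolved to the registered class:
print(WORKFLOWS.get("math_trainable_ruler_workflow"))
```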

trinity/common/workflows/math_ruler_workflow.py

Lines changed: 5 additions & 0 deletions
@@ -146,6 +146,11 @@ def get_ruler_scores(
        try:
            scores = ast.literal_eval(lst_as_str)
            scores = [max(0.0, min(1.0, score)) for score in scores]  # clip to range [0, 1]
+           if len(scores) != num_responses:
+               logger.warning(
+                   "The length of list in judger response does not match num_responses."
+               )
+               return False, [0.0 for _ in range(num_responses)]
            return True, scores
        except Exception:
            logger.warning("Unable to parse the list in judger response, set scores to all zero.")

trinity/common/workflows/math_trainable_ruler_workflow.py

Lines changed: 40 additions & 27 deletions
@@ -3,8 +3,8 @@
 import ast
 from copy import deepcopy
 from typing import Any, List, Optional, Tuple
-import numpy as np
 
+import numpy as np
 import openai
 
 from trinity.common.experience import Experience
@@ -16,7 +16,8 @@
 logger = get_logger(__name__)
 
 # the probability that the ground truth is assumed to be available for RL
-PROB_GROUND_TRUTH_AVAILABLE = 0.2
+PROBABILITY_GROUND_TRUTH_AVAILABLE = 0.2
+
 
 @WORKFLOWS.register_module("math_trainable_ruler_workflow")
 class MathTrainableRULERWorkflow(SimpleWorkflow):
@@ -71,74 +72,77 @@ def run(self) -> List[Experience]:
             gold_reward = sum(gold_reward_dict.values())
             response.metrics.update({"gold_reward": gold_reward})
 
-            response.eid.task = self.task.task_id  # task_id is set explicitly within workflow!
+            # set task_id explicitly within workflow!
+            response.eid.task = str(self.task.task_id)
             response.eid.run = i + self.run_id_base
 
             gold_rewards.append(gold_reward)
-            gold_scores_scaled.append((gold_reward + 0.1) / 1.2)  # scale from range [-0.1, 1.1] to [0, 1]
+            gold_scores_scaled.append(
+                (gold_reward + 0.1) / 1.2
+            )  # scale from range [-0.1, 1.1] to [0, 1]
 
         # Part 2: get and use RULER scores
         ruler_rollout_args = deepcopy(self.rollout_args)
-        ground_truth_is_available = np.random.rand() < PROB_GROUND_TRUTH_AVAILABLE
+        ground_truth_is_available = np.random.rand() < PROBABILITY_GROUND_TRUTH_AVAILABLE
 
         if ground_truth_is_available:
+            # Assuming that ground truth is accessible to RL:
             # - set exp's reward to gold reward
             # - generate RULER scores for repeat_times, construct ruler_responses
             # - return responses + ruler_responses
 
             judge_success_rate, ruler_responses, ruler_scores = self.get_ruler_responses(
-                responses=responses, 
+                responses=responses,
                 judger=self.model,  # use the policy model itself as judger!
                 ruler_rollout_args=ruler_rollout_args,
                 gold_scores=gold_scores_scaled,
             )
 
             for i, response in enumerate(responses):
                 response.reward = gold_rewards[i]
-                response.metrics.update({"judge_success": float(judge_success_rate)})
- 
+                response.metrics.update({"judge_success": judge_success_rate})
+
             for i, ruler_response in enumerate(ruler_responses):
-                if ruler_response.metrics is None:
-                    ruler_response.metrics = {}
-                ruler_response.metrics.update(
-                    {
-                        "judge_success": judge_success_rate,
-                        "reward_for_judger": ruler_response.reward,
-                    }
-                )
-                ruler_response.eid.task = -1 * self.task.task_id  # HACK to distinguish two types of experiences
+                # if ruler_response.metrics is None:
+                #     ruler_response.metrics = {}
+                # ruler_response.metrics.update({"judge_success": judge_success_rate})
+                # ruler_response.metrics.update({"reward_for_judger": ruler_response.reward})
+
+                # set task_id explicitly, to distinguish two types of experiences!
+                ruler_response.eid.task = str(self.task.task_id) + "-ruler"
                 ruler_response.eid.run = i + self.run_id_base
 
             return responses + ruler_responses
 
         else:
+            # Assuming that ground truth is not accessible to RL:
             # - generate RULER scores only once
-            # - set exp's reward to RULER scores
+            # - set exp's reward to RULER score
             # - return responses
 
             ruler_rollout_args.n = 1
             judge_success_rate, ruler_responses, ruler_scores = self.get_ruler_responses(
-                responses=responses, 
+                responses=responses,
                 judger=self.model,  # use the policy model itself as judger!
                 ruler_rollout_args=ruler_rollout_args,
                 gold_scores=None,
            )
 
             for i, response in enumerate(responses):
                 response.reward = ruler_scores[i]
-                response.metrics.update({"judge_success": float(judge_success_rate)})
+                response.metrics.update({"judge_success": judge_success_rate})
 
             return responses
 
     def get_ruler_responses(
-        self, 
-        responses: List[Experience], 
+        self,
+        responses: List[Experience],
         judger: Any,
         ruler_rollout_args: Any,
        gold_scores: Optional[List[float]] = None,
-    ) -> Tuple[bool, List[float]]:
+    ) -> Tuple[float, List[Experience], List[float]]:
         """Get RULER scores
- 
+
         Returns:
             judge_success_rate: float
             ruler_responses: List[Experience]
@@ -194,6 +198,7 @@ def get_ruler_responses(
 
            if (idx1 == -1) or (idx2 == -1) or (idx1 > idx2):
                logger.warning("Unable to extract a list from judger response.")
+               break
 
            lst_as_str = ruler_response_text[idx1 : (idx2 + 1)]
            try:
@@ -203,15 +208,23 @@ def get_ruler_responses(
                    judge_success_count += 1
                    ruler_scores = [ruler_scores[i] + scores[i] for i in range(len(ruler_scores))]
                    if gold_scores:
-                        mae_error = (np.array(ruler_scores) - np.array(gold_scores)).abs().mean()
+                        mae_error = np.abs(np.array(scores) - np.array(gold_scores)).mean()
                        ruler_response.reward = 1.0 - mae_error
                else:
-                    logger.warning("The length of list in judger response does not match num_responses.")
+                    logger.warning(
+                        "The length of list in judger response does not match num_responses."
+                    )
            except Exception:
                logger.warning("Unable to parse the list in judger response.")
- 
+
        if judge_success_count > 0:
            ruler_scores = [score / judge_success_count for score in ruler_scores]
        judge_success_rate = 1.0 * judge_success_count / len(ruler_responses)
 
+       for ruler_response in ruler_responses:
+           if ruler_response.metrics is None:
+               ruler_response.metrics = {}
+           ruler_response.metrics.update({"judge_success": judge_success_rate})
+           ruler_response.metrics.update({"reward_for_judger": ruler_response.reward})
+
        return judge_success_rate, ruler_responses, ruler_scores
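
To make the two reward computations in this file concrete, here is a tiny worked example of the gold-score scaling and of `reward_for_judger` (values invented for illustration; note the commit's fix is that the MAE is computed against the current judger response's `scores`, not the accumulated `ruler_scores`):

```python
import numpy as np

# Gold reward -> gold score: scale from range [-0.1, 1.1] to [0, 1].
gold_reward = 1.1                 # say, full accuracy plus the format bonus
gold_score = (gold_reward + 0.1) / 1.2
print(gold_score)                 # 1.0

# Judger reward: 1 - mean absolute error between its scores and the gold scores.
scores = [0.8, 0.2, 1.0]          # one judger response's scores
gold_scores = [1.0, 0.0, 1.0]     # scaled gold scores for the same rollouts
mae_error = np.abs(np.array(scores) - np.array(gold_scores)).mean()
reward_for_judger = 1.0 - mae_error
print(round(float(reward_for_judger), 3))  # 0.867
```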
