Commit 1c7daa4

Init toy implementation of RULER

1 parent d193f20 commit 1c7daa4

4 files changed: +292 -0 lines changed
Lines changed: 22 additions & 0 deletions (README)

# RL on GSM8K with RULER reward

A toy implementation of ART's RULER reward on the GSM8K task, trained with GRPO.

References:

* https://github.com/OpenPipe/ART/blob/main/src/art/rewards/ruler.py
* https://art.openpipe.ai/fundamentals/ruler

The config files are located in [`gsm8k_ruler.yaml`](gsm8k_ruler.yaml) and [`train_gsm8k_ruler.yaml`](train_gsm8k_ruler.yaml).

Configs to pay attention to:

* `default_workflow_type`: set to `math_ruler_workflow`
* `auxiliary_models`: the LLM-as-a-judge used by RULER; set `max_prompt_tokens`, `max_response_tokens`, and `max_model_len` appropriately
* `std_threshold` for the GRPO advantage: set to a small value to filter out groups of experiences whose rewards are all identical, e.g., when RULER fails to return valid scores and every score in the group defaults to zero (see the sketch below)
* `sync_style`: use `dynamic_by_explorer`, since filtering means the number of usable experiences per step varies
* `lr`: set to a small value (2e-6) for stability, as RULER rewards can be noisy

wandb metrics to pay attention to:

* `reward`: the reward calculated by RULER, used for RL training
* `gold_reward`: the sum of `accuracy_reward` and `format_reward`, a rule-based calculation against the ground truth (as in the original GSM8k example), logged for monitoring
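
To make the `std_threshold` bullet concrete, here is a minimal sketch of a GRPO-style group advantage with a standard-deviation filter. This is not Trinity-RFT's actual implementation; the function name and the epsilon are illustrative only.

```python
import numpy as np

def grpo_group_advantages(rewards, std_threshold=1e-4, eps=1e-6):
    """GRPO-style group advantages with a std-threshold filter (illustrative)."""
    rewards = np.asarray(rewards, dtype=np.float64)
    std = rewards.std()
    if std < std_threshold:
        # Degenerate group, e.g. RULER failed and every score defaulted to 0.0:
        # no learning signal, so the whole group is dropped.
        return None
    return (rewards - rewards.mean()) / (std + eps)

print(grpo_group_advantages([0.0] * 8))                                 # None -> group filtered out
print(grpo_group_advantages([0.9, 0.2, 0.7, 0.1, 0.8, 0.3, 0.6, 0.4]))  # centered, normalized advantages
```

With `std_threshold: 0.0001`, only groups whose rewards are essentially constant are dropped, which is exactly the failure mode described above.
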
Lines changed: 75 additions & 0 deletions (presumably `gsm8k_ruler.yaml`, referenced in the README above)

project: "Trinity-RFT-gsm8k-ruler"
name: "qwen2.5-1.5B-gsm8k-ruler"
checkpoint_root_dir: /PATH/TO/CHECKPOINT/
algorithm:
  algorithm_type: grpo
  advantage_fn_args:
    std_threshold: 0.0001  # effectively zero
  repeat_times: 8
model:
  model_path: /PATH/TO/MODEL/
  max_response_tokens: 1024
  max_model_len: 1280
cluster:
  node_num: 1
  gpu_per_node: 8
buffer:
  total_epochs: 1
  batch_size: 96
  max_retry_times: 3
  max_retry_interval: 1
  explorer_input:
    taskset:
      name: gsm8k
      storage_type: file
      path: 'openai/gsm8k'
      subset_name: 'main'
      split: 'train'
      format:
        prompt_key: 'question'
        response_key: 'answer'
      rollout_args:
        temperature: 1.0
    eval_tasksets:
      - name: gsm8k-eval
        storage_type: file
        path: 'openai/gsm8k'
        subset_name: 'main'
        split: 'test'
        format:
          prompt_key: 'question'
          response_key: 'answer'
    default_workflow_type: 'math_ruler_workflow'
  trainer_input:
    experience_buffer:
      name: gsm8k_buffer
      storage_type: queue
      # path: 'sqlite:///gsm8k.db'
explorer:
  eval_interval: 50
  runner_num: 32
  rollout_model:
    engine_type: vllm_async
    engine_num: 2
    tensor_parallel_size: 1
    enable_prefix_caching: false
    enforce_eager: true
    dtype: bfloat16
    seed: 42
  auxiliary_models:
    - model_path: /PATH/TO/Qwen2.5-32B-Instruct
      engine_num: 1
      tensor_parallel_size: 2
      enable_thinking: false
      max_prompt_tokens: 12288
      max_response_tokens: 12288
      max_model_len: 16384
synchronizer:
  sync_style: dynamic_by_explorer
  sync_method: 'nccl'
  sync_interval: 5
  sync_timeout: 3600
trainer:
  trainer_type: 'verl'
  trainer_config_path: 'examples/grpo_gsm8k/train_gsm8k.yaml'
  save_interval: 100
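
One thing the judge budget has to absorb: each RULER call packs all `repeat_times: 8` candidate solutions (each up to `max_response_tokens: 1024` tokens from the policy model) plus the question and instructions into a single prompt. A rough back-of-the-envelope check follows; the overhead figure is an assumption for illustration, not a measured value.

```python
# Rough token-budget check for the RULER judge prompt (illustrative arithmetic only).
repeat_times = 8              # candidate solutions scored per call (group size)
max_response_tokens = 1024    # per-candidate budget of the policy model
prompt_overhead = 2048        # ASSUMED allowance for question, instructions, formatting

judge_prompt_tokens = repeat_times * max_response_tokens + prompt_overhead
print(judge_prompt_tokens)    # 10240 -- fits under the judge's max_prompt_tokens = 12288
```
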
Lines changed: 49 additions & 0 deletions (presumably `train_gsm8k_ruler.yaml`, the verl trainer config referenced in the README above)

actor_rollout_ref:
  hybrid_engine: True
  model:
    external_lib: null
    override_config: { }
    enable_gradient_checkpointing: True
    use_remove_padding: True # False
  actor:
    strategy: fsdp # This is for backward-compatibility
    ppo_micro_batch_size_per_gpu: 4
    use_dynamic_bsz: True # False
    ppo_max_token_len_per_gpu: 16384
    grad_clip: 1.0
    ppo_epochs: 1
    shuffle: False
    ulysses_sequence_parallel_size: 1 # sp size
    optim:
      lr: 2e-6
      lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
      # min_lr_ratio: null # only useful for warmup with cosine
      warmup_style: constant # select from constant/cosine
      total_training_steps: -1 # must be overridden by the program
    fsdp_config:
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
      param_offload: False
      optimizer_offload: False
      fsdp_size: -1
  ref:
    fsdp_config:
      param_offload: False
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
    log_prob_micro_batch_size_per_gpu: 4
    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
    ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size

trainer:
  balance_batch: True
  # total_training_steps: null
  # auto: find the last ckpt to resume. If can't find, start from scratch
  resume_mode: auto # or disable, or resume_path to resume from a specific checkpoint
  default_hdfs_dir: null
  remove_previous_ckpt_in_save: False
  del_local_ckpt_after_load: False
  val_before_train: False
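
A side note on `use_dynamic_bsz` with `ppo_max_token_len_per_gpu`: micro-batches are formed by a per-GPU token budget rather than a fixed sequence count. The sketch below only illustrates the idea, not verl's actual packing logic.

```python
def pack_by_token_budget(seq_lens, max_tokens_per_gpu=16384):
    """Group sequence indices into micro-batches under a per-GPU token budget."""
    batches, current, current_tokens = [], [], 0
    for idx, n_tokens in enumerate(seq_lens):
        if current and current_tokens + n_tokens > max_tokens_per_gpu:
            batches.append(current)
            current, current_tokens = [], 0
        current.append(idx)
        current_tokens += n_tokens
    if current:
        batches.append(current)
    return batches

# 96 rollouts of up to max_model_len = 1280 tokens each -> 8 micro-batches of 12 sequences
print([len(b) for b in pack_by_token_budget([1280] * 96)])
```
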
Lines changed: 146 additions & 0 deletions (the workflow implementation registered as `math_ruler_workflow`)

# -*- coding: utf-8 -*-
"""Math workflow with RULER."""

from typing import Any, List, Optional

import openai

from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.rewards.math_reward import MathRewardFn
from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task
from trinity.utils.log import get_logger

logger = get_logger(__name__)


@WORKFLOWS.register_module("math_ruler_workflow")
class MathRULERWorkflow(SimpleWorkflow):
    """A workflow for math with RULER reward function.

    Modified from `MathWorkflow`.
    """

    def __init__(
        self,
        *,
        task: Task,
        model: ModelWrapper,
        auxiliary_models: Optional[List[openai.OpenAI]] = None,
    ):
        self.reset(task)
        super().__init__(
            task=task,
            model=model,
            auxiliary_models=auxiliary_models,
        )

    def reset(self, task: Task):
        """
        Note that in this workflow, MathRewardFn is only used for calculating the 'gold_reward' metric,
        whereas the rewards used for RL training are calculated by RULER.
        """
        if task.reward_fn is None:
            task.reward_fn = MathRewardFn
        if task.reward_fn == MathRewardFn and task.format_args.system_prompt is None:
            task.format_args.system_prompt = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e.,
<think> reasoning process here </think>
<answer> answer here </answer>.
"""
        # call the SimpleWorkflow.reset
        super().reset(task)

    def run(self) -> List[Experience]:
        """Modified from SimpleWorkflow.run"""
        messages = self.format_messages()

        self.logger.debug("start chat")
        responses = self.model.chat(messages, **self.rollout_args)

        for i, response in enumerate(responses):
            # Rule-based reward against the ground truth; logged as a metric only.
            gold_reward_dict = self.reward_fn(  # type: ignore [misc]
                response=response.response_text,  # type: ignore [arg-type]
                truth=self.truth,
            )

            if response.metrics is None:
                response.metrics = {}

            response.metrics.update(gold_reward_dict)
            gold_reward = sum(gold_reward_dict.values())
            response.metrics.update({"gold_reward": gold_reward})
            response.eid.run = i + self.run_id_base

            self.logger.debug(
                f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, gold_reward: {gold_reward}"
            )

        # === RULER scores as rewards ===
        ruler_scores = self.get_ruler_scores(responses=responses, judger=self.auxiliary_models[0])
        for i, response in enumerate(responses):
            response.reward = ruler_scores[i]

        return responses

    def get_ruler_scores(self, responses: List, judger: Any) -> List[float]:
        """Get RULER scores"""
        num_responses = len(responses)

        # Step 1: format prompt for judge
        ruler_system_prompt = f"You are a fair judge. The user will provide a question and {num_responses} candidate solutions to it. Your task is to compare the solutions, see how well they resolve the question, and assign a score within the range [0, 1] for each solution."

        question_prompt = (
            f"Question: {self.task_desc}\n\n"
            f"""Solution format requirement: first think about the reasoning process in the mind and then provide the final answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e.,
<think> reasoning process here </think>
<answer> answer here </answer>."""
        )

        solutions_prompt_parts = [
            f"Candidate solution {i + 1}: {response.response_text}"
            for i, response in enumerate(responses)
        ]
        solutions_prompt = "\n\n".join(solutions_prompt_parts)

        ruler_user_prompt = f"""
Below is a question and several candidate solutions.

{question_prompt}

{solutions_prompt}

Please assign a score within the range [0, 1] for each of them, reflecting how well they solve the question.
You may compare them against each other and think step by step before returning your final scores, but keep your reasoning process brief and concise when possible.

Conclude your response with a list of scores, in the following format: [score for solution 1, score for solution 2, ..., score for solution {num_responses}]
"""

        # Step 2: invoke the judge LLM
        messages = [
            {"role": "system", "content": ruler_system_prompt},
            {"role": "user", "content": ruler_user_prompt},
        ]
        completion = judger.chat.completions.create(
            model=judger.model_path, messages=messages, stream=False
        )
        judger_response = completion.choices[0].message.content
        logger.info(f"LLM judge response: {judger_response}")

        # Step 3: extract scores from the judge's response
        idx1, idx2 = judger_response.rfind("["), judger_response.rfind("]")
        if (idx1 == -1) or (idx2 == -1) or (idx1 > idx2):
            logger.warning("Unable to extract a list from judger response, set scores to all zero.")
            return [0.0 for _ in range(num_responses)]
        lst_as_str = judger_response[idx1 : (idx2 + 1)]
        try:
            scores = eval(lst_as_str)
            if len(scores) != num_responses:
                # Guard against a score list of the wrong length, which would break indexing in run().
                logger.warning("Number of scores does not match number of responses, set scores to all zero.")
                return [0.0 for _ in range(num_responses)]
            scores = [max(0.0, min(1.0, score)) for score in scores]  # clip to range [0, 1]
            return scores
        except Exception:
            logger.warning("Unable to parse the list in judger response, set scores to all zero.")
            return [0.0 for _ in range(num_responses)]
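
Since the judge's reply is parsed with `eval` on whatever text sits between the last `[` and `]`, a malformed (or adversarial) reply could break things or, in principle, execute arbitrary expressions. Below is a safer parsing sketch using `ast.literal_eval`; it is an illustration, not part of the commit, and the helper name is made up.

```python
import ast
import re
from typing import List


def parse_scores(judger_response: str, num_responses: int) -> List[float]:
    """Parse the last '[s1, s2, ...]' in the judge's reply without eval(); fall back to zeros."""
    candidates = re.findall(r"\[[^\[\]]*\]", judger_response)
    if not candidates:
        return [0.0] * num_responses
    try:
        scores = ast.literal_eval(candidates[-1])  # only accepts Python literals
        if not isinstance(scores, (list, tuple)) or len(scores) != num_responses:
            return [0.0] * num_responses
        return [max(0.0, min(1.0, float(s))) for s in scores]  # clip to range [0, 1]
    except (ValueError, SyntaxError, TypeError):
        return [0.0] * num_responses


print(parse_scores("Final scores: [0.9, 0.2, 1.3]", 3))  # [0.9, 0.2, 1.0]
print(parse_scores("no list here", 3))                   # [0.0, 0.0, 0.0]
```
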
