# -*- coding: utf-8 -*-
"""Math workflow with RULER."""

import ast
from typing import Any, List, Optional

import openai

from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.rewards.math_reward import MathRewardFn
from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task

from trinity.utils.log import get_logger

logger = get_logger(__name__)
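# Overall flow: generate multiple responses per math task, record a rule-based
# "gold reward" from MathRewardFn for monitoring, then have an auxiliary LLM judge
# (RULER) score the responses as a group; those scores become the RL training rewards.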


@WORKFLOWS.register_module("math_ruler_workflow")
class MathRULERWorkflow(SimpleWorkflow):
    """A workflow for math with the RULER reward function.

    Modified from `MathWorkflow`.
    """

    def __init__(
        self,
        *,
        task: Task,
        model: ModelWrapper,
        auxiliary_models: Optional[List[openai.OpenAI]] = None,
    ):
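        # reset() runs before the parent constructor, presumably so that the default
        # reward_fn and system prompt are already attached to the task when
        # SimpleWorkflow.__init__ consumes it.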
        self.reset(task)
        super().__init__(
            task=task,
            model=model,
            auxiliary_models=auxiliary_models,
        )

    def reset(self, task: Task):
        """
        Note that in this workflow, MathRewardFn is only used for calculating the 'golden reward',
        whereas the rewards used for RL training are calculated by RULER.
        """

        if task.reward_fn is None:
            task.reward_fn = MathRewardFn
        if task.reward_fn == MathRewardFn and task.format_args.system_prompt is None:
            task.format_args.system_prompt = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e.,
<think> reasoning process here </think>
<answer> answer here </answer>.
"""
        # call the SimpleWorkflow.reset
        super().reset(task)

    def run(self) -> List[Experience]:
        """Modified from SimpleWorkflow.run"""

        messages = self.format_messages()

        self.logger.debug("start chat")
        responses = self.model.chat(messages, **self.rollout_args)
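        # NOTE: `self.rollout_args` is expected to request multiple rollouts, so
        # `responses` typically holds several candidate answers for the same task.
        # The rule-based MathRewardFn result computed in the loop below is logged
        # only as a "gold reward" metric; the reward used for RL training is
        # assigned from the RULER scores afterwards.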

        for i, response in enumerate(responses):
            gold_reward_dict = self.reward_fn(  # type: ignore [misc]
                response=response.response_text,  # type: ignore [arg-type]
                truth=self.truth,
            )

            if response.metrics is None:
                response.metrics = {}

            response.metrics.update(gold_reward_dict)
            gold_reward = sum(gold_reward_dict.values())
            response.metrics.update({"gold_reward": gold_reward})
            response.eid.run = i + self.run_id_base

            self.logger.debug(
                f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, gold_reward: {gold_reward}"
            )

        # === RULER scores as rewards ===
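        # RULER: ask an LLM judge to score all candidate responses jointly and use
        # these relative scores as the training rewards. The first auxiliary model is
        # taken as the judge, so `auxiliary_models` must contain at least one client.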
        ruler_scores = self.get_ruler_scores(responses=responses, judger=self.auxiliary_models[0])
        for i, response in enumerate(responses):
            response.reward = ruler_scores[i]

        return responses

    def get_ruler_scores(self, responses: List, judger: Any) -> List[float]:
        """Get RULER scores"""

        num_responses = len(responses)

        # Step 1: format prompt for judge
        ruler_system_prompt = f"You are a fair judge. The user will provide a question and {num_responses} candidate solutions to it. Your task is to compare the solutions, see how well they resolve the question, and assign a score within the range [0, 1] for each solution."

        question_prompt = (
            f"Question: {self.task_desc}\n\n"
            f"""Solution format requirement: first think about the reasoning process in the mind and then provide the final answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e.,
<think> reasoning process here </think>
<answer> answer here </answer>."""
        )

        solutions_prompt_parts = [
            f"Candidate solution {i + 1}: {response.response_text}"
            for i, response in enumerate(responses)
        ]
        solutions_prompt = "\n\n".join(solutions_prompt_parts)
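        # Candidate solutions are numbered in the same order as `responses`, and the
        # judge is asked to return scores in that order, so score i maps back to
        # responses[i] in run().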

        ruler_user_prompt = f"""
Below is a question and several candidate solutions.

{question_prompt}

{solutions_prompt}

Please assign a score within the range [0, 1] for each of them, reflecting how well they solve the question.
You may compare them against each other and think step by step before returning your final scores, but keep your reasoning process brief and concise when possible.

Conclude your response with a list of scores, in the following format: [score for solution 1, score for solution 2, ..., score for solution {num_responses}]
"""

        # Step 2: invoke judger LLM
        messages = [
            {"role": "system", "content": ruler_system_prompt},
            {"role": "user", "content": ruler_user_prompt},
        ]
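        # `judger` is an OpenAI-compatible client for the auxiliary judge model;
        # `model_path` is assumed to be attached to the client when Trinity sets up
        # the auxiliary models.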
        completion = judger.chat.completions.create(
            model=judger.model_path, messages=messages, stream=False
        )
        judger_response = completion.choices[0].message.content
        logger.info(f"LLM judge response: {judger_response}")

        # Step 3: extract scores from judger's response
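        # Take the last bracketed span in the reply, since the judge is instructed to
        # conclude its response with the score list.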
        idx1, idx2 = judger_response.rfind("["), judger_response.rfind("]")
        if (idx1 == -1) or (idx2 == -1) or (idx1 > idx2):
            logger.warning("Unable to extract a list from judger response, set scores to all zero.")
            return [0.0 for _ in range(num_responses)]
        lst_as_str = judger_response[idx1 : (idx2 + 1)]
        try:
            scores = ast.literal_eval(lst_as_str)  # safer than eval on LLM output
            if len(scores) != num_responses:
                raise ValueError(f"expected {num_responses} scores, got {len(scores)}")
            scores = [float(max(0.0, min(1.0, score))) for score in scores]  # clip to range [0, 1]
            return scores
        except Exception:
            logger.warning("Unable to parse a valid score list from judger response, set scores to all zero.")
            return [0.0 for _ in range(num_responses)]