 from trinity.common.models.model import ModelWrapper
 from trinity.common.rewards.math_reward import MathRewardFn
 from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task
-
 from trinity.utils.log import get_logger
 
 logger = get_logger(__name__)
 
 
 @WORKFLOWS.register_module("math_ruler_workflow")
 class MathRULERWorkflow(SimpleWorkflow):
     """A workflow for math with RULER reward function.
-
+
     Modified from `MathWorkflow`.
     """
 
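For readers unfamiliar with the decorator above: `WORKFLOWS.register_module(...)` follows the common decorator-based registry pattern, mapping a string name to the class for later lookup. The sketch below is a minimal, hypothetical reconstruction of that pattern; Trinity's actual `Registry` implementation is not part of this diff.

from typing import Callable, Dict, Type


class Registry:
    """Minimal decorator-based registry; illustrative only."""

    def __init__(self) -> None:
        self._modules: Dict[str, Type] = {}

    def register_module(self, name: str) -> Callable[[Type], Type]:
        def decorator(cls: Type) -> Type:
            self._modules[name] = cls  # record the class under its string name
            return cls  # the class itself is returned unchanged

        return decorator

    def get(self, name: str) -> Type:
        return self._modules[name]


WORKFLOWS = Registry()


@WORKFLOWS.register_module("math_ruler_workflow")
class MathRULERWorkflow:  # stand-in; the real class subclasses SimpleWorkflow
    pass


assert WORKFLOWS.get("math_ruler_workflow") is MathRULERWorkflow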
@@ -52,10 +51,9 @@ def reset(self, task: Task):
         # call the SimpleWorkflow.reset
         super().reset(task)
 
-
     def run(self) -> List[Experience]:
         """Modified from SimpleWorkflow.run"""
-
+
         messages = self.format_messages()
 
         self.logger.debug("start chat")
@@ -78,8 +76,11 @@ def run(self) -> List[Experience]:
         self.logger.debug(
             f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, gold_reward: {gold_reward}"
         )
-
+
         # === RULER scores as rewards ===
+        assert (
+            self.auxiliary_models is not None
+        ), "Current implementation of RULER requires that auxiliary_models is not None."
         ruler_scores = self.get_ruler_scores(responses=responses, judger=self.auxiliary_models[0])
         for i, response in enumerate(responses):
             response.reward = ruler_scores[i]
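The hunk above is the core of RULER: the gold-answer reward is computed only for logging, while the judger's relative scores overwrite each response's reward. A self-contained sketch of that overwrite step, using a hypothetical `FakeResponse` stand-in for Trinity's response objects:

from dataclasses import dataclass
from typing import List


@dataclass
class FakeResponse:
    response_text: str
    reward: float = 0.0


def apply_ruler_scores(responses: List[FakeResponse], ruler_scores: List[float]) -> None:
    assert len(ruler_scores) == len(responses), "expect one score per response"
    for response, score in zip(responses, ruler_scores):
        response.reward = score  # judger score replaces the gold-answer reward


responses = [FakeResponse("x = 2"), FakeResponse("x = 3")]
apply_ruler_scores(responses, [1.0, 0.0])
print([r.reward for r in responses])  # prints [1.0, 0.0]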
@@ -120,7 +121,7 @@ def get_ruler_scores(self, responses: List, judger: Any) -> List[float]:
 Conclude your response with a list of scores, in the following format: [score for solution 1, score for solution 2, ..., score for solution {num_responses + 1}]
 """
 
-        # Step 2: invoke judger LLM
+        # Step 2: invoke judger LLM
         messages = [
             {"role": "system", "content": ruler_system_prompt},
             {"role": "user", "content": ruler_user_prompt},
@@ -139,7 +140,7 @@ def get_ruler_scores(self, responses: List, judger: Any) -> List[float]:
             lst_as_str = judger_response[idx1 : (idx2 + 1)]
             try:
                 scores = eval(lst_as_str)
-                scores = [max(0.0, min(1.0, score)) for score in scores] # clip to range [0, 1]
+                scores = [max(0.0, min(1.0, score)) for score in scores]  # clip to range [0, 1]
                 return scores
             except Exception:
                 logger.warning("Unable to parse the list in judger response, set scores to all zero.")
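One caveat with the hunk above: `eval` executes arbitrary Python from the judger's output. A hedged alternative sketch using `ast.literal_eval`, which parses only Python literals, with the same clip-and-fallback behavior; the `parse_ruler_scores` helper and its signature are hypothetical:

import ast
from typing import List


def parse_ruler_scores(judger_response: str, num_scores: int) -> List[float]:
    idx1 = judger_response.find("[")
    idx2 = judger_response.rfind("]")
    try:
        scores = ast.literal_eval(judger_response[idx1 : idx2 + 1])
        return [max(0.0, min(1.0, float(s))) for s in scores]  # clip to [0, 1]
    except (ValueError, SyntaxError, TypeError):
        return [0.0] * num_scores  # matches the "set scores to all zero" fallback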