Skip to content

Commit d2664e2

Browse files
committed
Fix pre-commit
1 parent 1c7daa4 commit d2664e2

File tree

1 file changed

+8 -7 lines changed

1 file changed

+8
-7
lines changed

trinity/common/workflows/math_ruler_workflow.py

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,6 @@
99
from trinity.common.models.model import ModelWrapper
1010
from trinity.common.rewards.math_reward import MathRewardFn
1111
from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task
12-
1312
from trinity.utils.log import get_logger
1413

1514
logger = get_logger(__name__)
@@ -18,7 +17,7 @@
1817
@WORKFLOWS.register_module("math_ruler_workflow")
1918
class MathRULERWorkflow(SimpleWorkflow):
2019
"""A workflow for math with RULER reward function.
21-
20+
2221
Modified from `MathWorkflow`.
2322
"""
2423

@@ -52,10 +51,9 @@ def reset(self, task: Task):
5251
# call the SimpleWorkflow.reset
5352
super().reset(task)
5453

55-
5654
def run(self) -> List[Experience]:
5755
"""Modified from SimpleWorkflow.run"""
58-
56+
5957
messages = self.format_messages()
6058

6159
self.logger.debug("start chat")
@@ -78,8 +76,11 @@ def run(self) -> List[Experience]:
7876
self.logger.debug(
7977
f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, gold_reward: {gold_reward}"
8078
)
81-
79+
8280
# === RULER scores as rewards ===
81+
assert (
82+
self.auxiliary_models is not None
83+
), "Current implementation of RULER requires that auxiliary_models is not None."
8384
ruler_scores = self.get_ruler_scores(responses=responses, judger=self.auxiliary_models[0])
8485
for i, response in enumerate(responses):
8586
response.reward = ruler_scores[i]
@@ -120,7 +121,7 @@ def get_ruler_scores(self, responses: List, judger: Any) -> List[float]:
120121
Conclude your response with a list of scores, in the following format: [score for solution 1, score for solution 2, ..., score for solution {num_responses + 1}]
121122
"""
122123

123-
# Step 2: invoke judger LLM
124+
# Step 2: invoke judger LLM
124125
messages = [
125126
{"role": "system", "content": ruler_system_prompt},
126127
{"role": "user", "content": ruler_user_prompt},
@@ -139,7 +140,7 @@ def get_ruler_scores(self, responses: List, judger: Any) -> List[float]:
139140
lst_as_str = judger_response[idx1 : (idx2 + 1)]
140141
try:
141142
scores = eval(lst_as_str)
142-
scores = [max(0.0, min(1.0, score)) for score in scores] # clip to range [0, 1]
143+
scores = [max(0.0, min(1.0, score)) for score in scores] # clip to range [0, 1]
143144
return scores
144145
except Exception:
145146
logger.warning("Unable to parse the list in judger response, set scores to all zero.")

0 commit comments

Comments
 (0)