Skip to content

Commit 0bb3fd4

Browse files
authored
Refactor ProgramOfThought and add unit tests (#7878)
* refactor ProgramOfThought and add tests
1 parent 7e786ef commit 0bb3fd4

File tree

5 files changed

+125
-30
lines changed

5 files changed

+125
-30
lines changed

dspy/predict/program_of_thought.py

Lines changed: 56 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,37 @@
1+
import logging
12
import re
3+
from typing import Union, Type
24

35
import dspy
4-
from dspy.signatures.signature import ensure_signature
6+
from dspy.signatures.signature import ensure_signature, Signature
57

68
from dspy.primitives.program import Module
79
from dspy.primitives.python_interpreter import PythonInterpreter
810

11+
logger = logging.getLogger(__name__)
912

1013
class ProgramOfThought(Module):
11-
def __init__(self, signature, max_iters=3):
14+
"""
15+
A DSPy module that runs Python programs to solve a problem.
16+
This module reuires deno to be installed. Please install deno following https://docs.deno.com/runtime/getting_started/installation/
17+
18+
Example:
19+
```
20+
import dspy
21+
22+
lm = dspy.LM('openai/gpt-4o-mini')
23+
dspy.configure(lm=lm)
24+
pot = dspy.ProgramOfThought("question -> answer")
25+
pot(question="what is 1+1?")
26+
```
27+
"""
28+
29+
def __init__(self, signature: Union[str, Type[Signature]], max_iters=3):
30+
"""
31+
Args:
32+
signature: The signature of the module.
33+
max_iters: The maximum number of iterations to retry code generation and execution.
34+
"""
1235
super().__init__()
1336
self.signature = signature = ensure_signature(signature)
1437
self.max_iters = max_iters
@@ -56,6 +79,10 @@ def __init__(self, signature, max_iters=3):
5679
self._generate_instruction("answer"),
5780
),
5881
)
82+
# Currently, the interpreter class checks the deno availability at execution time.
83+
# We may consider checking it at the initialization time for better instruction.
84+
self.interpreter = PythonInterpreter()
85+
5986
def _generate_signature(self, mode):
6087
signature_dict = dict(self.input_fields)
6188
fields_for_mode = {
@@ -125,7 +152,7 @@ def _generate_instruction(self, mode):
125152
return "\n".join(instr)
126153

127154

128-
def parse_code(self, code_data):
155+
def _parse_code(self, code_data):
129156
code = (
130157
code_data.get("generated_code", "").split("---", 1)[0].split("\n\n\n", 1)[0]
131158
)
@@ -148,35 +175,42 @@ def parse_code(self, code_data):
148175
)
149176
return code_block, None
150177

151-
def execute_code(self, code):
178+
def _execute_code(self, code):
179+
"""
180+
Execute the code using PythonInterpreter and return the output or error.
181+
"""
152182
if not code:
153-
return code, None, "Error: Empty code before execution."
154-
interpreter = PythonInterpreter()
183+
return None, "Error: Empty code before execution."
184+
155185
try:
156-
output = str(interpreter.execute(code))
157-
return code, output, None
186+
output = str(self.interpreter.execute(code))
187+
return output, None
158188
except Exception as e:
159-
return code, None, str(e)
189+
return None, str(e)
190+
160191
def forward(self, **kwargs):
161192
input_kwargs = {
162193
field_name: kwargs[field_name] for field_name in self.input_fields
163194
}
164195
code_data = self.code_generate(**input_kwargs)
165-
parsed_code, error = self.parse_code(code_data)
166-
# FIXME: Don't try to execute the code if it didn't parse
167-
code, output, error = self.execute_code(parsed_code)
168-
hop = 0
169-
while hop < self.max_iters and error:
170-
print("Error in code execution")
196+
output = None
197+
code, error = self._parse_code(code_data)
198+
if not error:
199+
output, error = self._execute_code(code)
200+
hop = 1
201+
# Retying code generation and execution until no error or reach max_iters
202+
while error is not None:
203+
logger.error(f"Error in code execution: {error}")
204+
if hop == self.max_iters:
205+
self.interpreter.shutdown()
206+
raise RuntimeError(f"Max hops reached. Failed to run ProgramOfThought: {error}")
171207
input_kwargs.update({"previous_code": code, "error": error})
172208
code_data = self.code_regenerate(**input_kwargs)
173-
parsed_code, error = self.parse_code(code_data)
174-
# FIXME: Don't try to execute the code if it didn't parse
175-
code, output, error = self.execute_code(parsed_code)
209+
code, error = self._parse_code(code_data)
210+
if not error:
211+
output, error = self._execute_code(code)
176212
hop += 1
177-
if hop == self.max_iters:
178-
print("Max hops reached. Error persists.")
179-
return None
180213
input_kwargs.update({"final_generated_code": code, "code_output": output})
181214
answer_gen_result = self.generate_answer(**input_kwargs)
182-
return answer_gen_result
215+
self.interpreter.shutdown()
216+
return answer_gen_result

dspy/primitives/python_interpreter.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,14 +117,14 @@ def execute(
117117
# If not valid JSON, just return raw text
118118
result = {"output": output_line}
119119

120-
# If we have an error, handle SyntaxError vs. other error
120+
# If we have an error, determine if it's a SyntaxError or other error using error.errorType.
121121
if "error" in result:
122122
error_msg = result["error"]
123-
error_type = result.get("errorType", "")
123+
error_type = result.get("errorType", "Sandbox Error")
124124
if error_type == "SyntaxError":
125-
raise SyntaxError(error_msg)
125+
raise SyntaxError(f"Invalid Python syntax. message: {error_msg}")
126126
else:
127-
raise InterpreterError(f"Sandbox Error: {error_msg}")
127+
raise InterpreterError(f"{error_type}: {error_msg}")
128128

129129
# If there's no error, return the "output" field
130130
return result.get("output", None)

dspy/primitives/runner.js

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,9 @@ sys.stderr = old_stderr
8686
console.log(JSON.stringify({ output }));
8787
} catch (error) {
8888
// We have an error => check if it's a SyntaxError or something else
89-
const errorType = error.name || "Error";
89+
// The Python error class name is stored in error.type: https://pyodide.org/en/stable/usage/api/js-api.html#pyodide.ffi.PythonError
90+
const errorType = error.type || "Error";
91+
// error.message is mostly blank.
9092
const errorMessage = (error.message || "").trim();
9193
console.log(JSON.stringify({
9294
error: errorMessage,

tests/predict/test_program_of_thought.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
1+
from unittest.mock import patch
2+
import pytest
3+
import shutil
4+
15
import dspy
26
from dspy import ProgramOfThought, Signature
37
from dspy.utils import DummyLM
48

9+
# This test suite requires deno to be installed. Please install deno following https://docs.deno.com/runtime/getting_started/installation/
10+
is_deno_available = shutil.which("deno") is not None
511

612
class BasicQA(Signature):
713
question = dspy.InputField()
814
answer = dspy.OutputField(desc="often between 1 and 5 words")
915

1016

17+
@pytest.mark.skipif(not is_deno_available, reason="Deno is not installed or not in PATH")
1118
def test_pot_code_generation():
1219
lm = DummyLM(
1320
[
@@ -19,9 +26,11 @@ def test_pot_code_generation():
1926
pot = ProgramOfThought(BasicQA)
2027
res = pot(question="What is 1+1?")
2128
assert res.answer == "2"
29+
assert pot.interpreter.deno_process is None
2230

2331

24-
def test_pot_code_generation_with_error():
32+
@pytest.mark.skipif(not is_deno_available, reason="Deno is not installed or not in PATH")
33+
def test_pot_code_generation_with_one_error():
2534
lm = DummyLM(
2635
[
2736
{"reasoning": "Reason_A", "generated_code": "```python\nresult = 1+0/0\n```"},
@@ -34,3 +43,35 @@ def test_pot_code_generation_with_error():
3443
pot = ProgramOfThought(BasicQA)
3544
res = pot(question="What is 1+1?")
3645
assert res.answer == "2"
46+
assert pot.interpreter.deno_process is None
47+
48+
49+
@pytest.mark.skipif(not is_deno_available, reason="Deno is not installed or not in PATH")
50+
def test_pot_code_generation_persistent_errors():
51+
max_iters = 3
52+
lm = DummyLM(
53+
[
54+
{"reasoning": "Reason_A", "generated_code": "```python\nresult = 1+0/0\n```"},
55+
] * max_iters
56+
)
57+
dspy.settings.configure(lm=lm)
58+
59+
pot = ProgramOfThought(BasicQA, max_iters=max_iters)
60+
with pytest.raises(RuntimeError, match="Max hops reached. Failed to run ProgramOfThought: ZeroDivisionError:"):
61+
pot(question="What is 1+1?")
62+
assert pot.interpreter.deno_process is None
63+
64+
65+
def test_pot_code_parse_error():
66+
max_iters = 3
67+
lm = DummyLM(
68+
[
69+
{"reasoning": "Reason_A", "generated_code": "```python\ninvalid=python=code\n```"},
70+
] * max_iters
71+
)
72+
dspy.settings.configure(lm=lm)
73+
74+
pot = ProgramOfThought(BasicQA, max_iters=max_iters)
75+
with patch("dspy.predict.program_of_thought.ProgramOfThought._execute_code") as mock_execute_code, pytest.raises(RuntimeError, match="Max hops reached. Failed to run ProgramOfThought: Error: Code format is not correct."):
76+
pot(question="What is 1+1?")
77+
mock_execute_code.assert_not_called()

tests/primitives/test_python_interpreter.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1-
from dspy.primitives.python_interpreter import PythonInterpreter
1+
import shutil
2+
import pytest
3+
from dspy.primitives.python_interpreter import PythonInterpreter, InterpreterError
4+
5+
# This test suite requires deno to be installed. Please install deno following https://docs.deno.com/runtime/getting_started/installation/
6+
if shutil.which("deno") is None:
7+
pytest.skip(reason="Deno is not installed or not in PATH")
28

39
def test_execute_simple_code():
410
interpreter = PythonInterpreter()
@@ -16,4 +22,16 @@ def test_user_variable_definitions():
1622
interpreter = PythonInterpreter()
1723
code = "result = number + 1\nresult"
1824
result = interpreter.execute(code, variables={'number': 4})
19-
assert result == 5, "User variable assignment should work"
25+
assert result == 5, "User variable assignment should work"
26+
27+
def test_failure_syntax_error():
28+
interpreter = PythonInterpreter()
29+
code = "+++"
30+
with pytest.raises(SyntaxError, match="Invalid Python syntax"):
31+
interpreter.execute(code)
32+
33+
def test_failure_zero_division():
34+
interpreter = PythonInterpreter()
35+
code = "1+0/0"
36+
with pytest.raises(InterpreterError, match="ZeroDivisionError"):
37+
interpreter.execute(code)

0 commit comments

Comments
 (0)