Skip to content

Commit 803dff0

Browse files
authored
Revamp ReAct, adjust Bootstrap, adjust ChatAdapter (#1713)
* JsonAdapter: Handle JSON formatting in demo's outputs * Adjustmetns for JsonAdapter * Revamp ReAct, adjust Bootstrap (handle repeat calls to a module; transpose order for max_rounds), adjust ChatAdapter (handle incomplete demos better) * Remove ReAct tests (outdated format) * Remove react tests (outdated)
1 parent 6fe6935 commit 803dff0

File tree

9 files changed

+444
-429
lines changed

9 files changed

+444
-429
lines changed

dspy/adapters/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from .base import Adapter
2-
from .chat_adapter import ChatAdapter
3-
from .json_adapter import JsonAdapter
1+
from dspy.adapters.base import Adapter
2+
from dspy.adapters.chat_adapter import ChatAdapter
3+
from dspy.adapters.json_adapter import JsonAdapter

dspy/adapters/chat_adapter.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,18 @@ def parse(self, signature, completion, _parse_values=True):
8585

8686
def format_turn(self, signature, values, role, incomplete=False):
8787
return format_turn(signature, values, role, incomplete)
88+
89+
def format_fields(self, signature, values):
90+
fields_with_values = {
91+
FieldInfoWithName(name=field_name, info=field_info): values.get(
92+
field_name, "Not supplied for this particular example."
93+
)
94+
for field_name, field_info in signature.fields.items()
95+
if field_name in values
96+
}
97+
98+
return format_fields(fields_with_values)
99+
88100

89101

90102
def format_blob(blob):
@@ -228,21 +240,22 @@ def format_turn(signature: SignatureMeta, values: Dict[str, Any], role, incomple
228240
content.append(formatted_fields)
229241

230242
if role == "user":
231-
# def type_info(v):
232-
# return f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" \
233-
# if v.annotation is not str else ""
234-
#
235-
# content.append(
236-
# "Respond with the corresponding output fields, starting with the field "
237-
# + ", then ".join(f"`[[ ## {f} ## ]]`{type_info(v)}" for f, v in signature.output_fields.items())
238-
# + ", and then ending with the marker for `[[ ## completed ## ]]`."
239-
# )
240-
241-
content.append(
242-
"Respond with the corresponding output fields, starting with the field "
243-
+ ", then ".join(f"`{f}`" for f in signature.output_fields)
244-
+ ", and then ending with the marker for `completed`."
245-
)
243+
def type_info(v):
244+
return f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" \
245+
if v.annotation is not str else ""
246+
247+
if not incomplete:
248+
content.append(
249+
"Respond with the corresponding output fields, starting with the field "
250+
+ ", then ".join(f"`[[ ## {f} ## ]]`{type_info(v)}" for f, v in signature.output_fields.items())
251+
+ ", and then ending with the marker for `[[ ## completed ## ]]`."
252+
)
253+
254+
# content.append(
255+
# "Respond with the corresponding output fields, starting with the field "
256+
# + ", then ".join(f"`{f}`" for f in signature.output_fields)
257+
# + ", and then ending with the marker for `completed`."
258+
# )
246259

247260
return {"role": role, "content": "\n\n".join(content).strip()}
248261

dspy/adapters/json_adapter.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,18 @@ def parse(self, signature, completion, _parse_values=True):
8989

9090
def format_turn(self, signature, values, role, incomplete=False):
9191
return format_turn(signature, values, role, incomplete)
92+
93+
def format_fields(self, signature, values):
94+
fields_with_values = {
95+
FieldInfoWithName(name=field_name, info=field_info): values.get(
96+
field_name, "Not supplied for this particular example."
97+
)
98+
for field_name, field_info in signature.fields.items()
99+
if field_name in values
100+
}
101+
102+
return format_fields(role='user', fields_with_values=fields_with_values)
103+
92104

93105

94106
def parse_value(value, annotation):
@@ -241,6 +253,7 @@ def type_info(v):
241253
return f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" \
242254
if v.annotation is not str else ""
243255

256+
# TODO: Consider if not incomplete:
244257
content.append(
245258
"Respond with a JSON object in the following order of fields: "
246259
+ ", then ".join(f"`{f}`{type_info(v)}" for f, v in signature.output_fields.items())

dspy/predict/react.py

Lines changed: 82 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -1,130 +1,101 @@
1-
import dsp
21
import dspy
3-
from dspy.signatures.signature import ensure_signature
4-
5-
from ..primitives.program import Module
6-
from .predict import Predict
2+
import inspect
73

8-
# TODO: Simplify a lot.
9-
# TODO: Divide Action and Action Input like langchain does for ReAct.
4+
from pydantic import BaseModel
5+
from dspy.primitives.program import Module
6+
from dspy.signatures.signature import ensure_signature
7+
from dspy.adapters.json_adapter import get_annotation_name
8+
from typing import Callable, Any, get_type_hints, get_origin, Literal
109

11-
# TODO: There's a lot of value in having a stopping condition in the LM calls at `\n\nObservation:`
10+
class Tool:
11+
def __init__(self, func: Callable, name: str = None, desc: str = None, args: dict[str, Any] = None):
12+
annotations_func = func if inspect.isfunction(func) else func.__call__
13+
self.func = func
14+
self.name = name or getattr(func, '__name__', type(func).__name__)
15+
self.desc = desc or getattr(func, '__doc__', None) or getattr(annotations_func, '__doc__', "No description")
16+
self.args = {
17+
k: v.schema() if isinstance((origin := get_origin(v) or v), type) and issubclass(origin, BaseModel)
18+
else get_annotation_name(v)
19+
for k, v in (args or get_type_hints(annotations_func)).items() if k != 'return'
20+
}
1221

13-
# TODO [NEW]: When max_iters is about to be reached, reduce the set of available actions to only the Finish action.
22+
def __call__(self, *args, **kwargs):
23+
return self.func(*args, **kwargs)
1424

1525

1626
class ReAct(Module):
17-
def __init__(self, signature, max_iters=5, num_results=3, tools=None):
18-
super().__init__()
27+
def __init__(self, signature, tools: list[Callable], max_iters=5):
28+
"""
29+
Tools is either a list of functions, callable classes, or dspy.Tool instances.
30+
"""
31+
1932
self.signature = signature = ensure_signature(signature)
2033
self.max_iters = max_iters
2134

22-
self.tools = tools or [dspy.Retrieve(k=num_results)]
23-
self.tools = {tool.name: tool for tool in self.tools}
24-
25-
self.input_fields = self.signature.input_fields
26-
self.output_fields = self.signature.output_fields
27-
28-
assert len(self.output_fields) == 1, "ReAct only supports one output field."
35+
tools = [t if isinstance(t, Tool) or hasattr(t, 'input_variable') else Tool(t) for t in tools]
36+
tools = {tool.name: tool for tool in tools}
2937

30-
inputs_ = ", ".join([f"`{k}`" for k in self.input_fields.keys()])
31-
outputs_ = ", ".join([f"`{k}`" for k in self.output_fields.keys()])
38+
inputs_ = ", ".join([f"`{k}`" for k in signature.input_fields.keys()])
39+
outputs_ = ", ".join([f"`{k}`" for k in signature.output_fields.keys()])
40+
instr = [f"{signature.instructions}\n"] if signature.instructions else []
3241

33-
instr = []
34-
35-
if self.signature.instructions is not None:
36-
instr.append(f"{self.signature.instructions}\n")
37-
3842
instr.extend([
39-
f"You will be given {inputs_} and you will respond with {outputs_}.\n",
40-
"To do this, you will interleave Thought, Action, and Observation steps.\n",
41-
"Thought can reason about the current situation, and Action can be the following types:\n",
43+
f"You will be given {inputs_} and your goal is to finish with {outputs_}.\n",
44+
"To do this, you will interleave Thought, Tool Name, and Tool Args, and receive a resulting Observation.\n",
45+
"Thought can reason about the current situation, and Tool Name can be the following types:\n",
4246
])
4347

44-
self.tools["Finish"] = dspy.Example(
45-
name="Finish",
46-
input_variable=outputs_.strip("`"),
47-
desc=f"returns the final {outputs_} and finishes the task",
48+
finish_desc = f"Signals that the final outputs, i.e. {outputs_}, are now available and marks the task as complete."
49+
finish_args = {} #k: v.annotation for k, v in signature.output_fields.items()}
50+
tools["finish"] = Tool(func=lambda **kwargs: kwargs, name="finish", desc=finish_desc, args=finish_args)
51+
52+
for idx, tool in enumerate(tools.values()):
53+
desc = tool.desc.replace("\n", " ")
54+
args = tool.args if hasattr(tool, 'args') else str({tool.input_variable: str})
55+
desc = f"whose description is <desc>{desc}</desc>. It takes arguments {args} in JSON format."
56+
instr.append(f"({idx+1}) {tool.name}, {desc}")
57+
58+
signature_ = (
59+
dspy.Signature({**signature.input_fields}, "\n".join(instr))
60+
.append("trajectory", dspy.InputField(), type_=str)
61+
.append("next_thought", dspy.OutputField(), type_=str)
62+
.append("next_tool_name", dspy.OutputField(), type_=Literal[tuple(tools.keys())])
63+
.append("next_tool_args", dspy.OutputField(), type_=dict[str, Any])
4864
)
4965

50-
for idx, tool in enumerate(self.tools):
51-
tool = self.tools[tool]
52-
instr.append(
53-
f"({idx+1}) {tool.name}[{tool.input_variable}], which {tool.desc}",
54-
)
55-
56-
instr = "\n".join(instr)
57-
self.react = [
58-
Predict(dspy.Signature(self._generate_signature(i), instr))
59-
for i in range(1, max_iters + 1)
60-
]
61-
62-
def _generate_signature(self, iters):
63-
signature_dict = {}
64-
for key, val in self.input_fields.items():
65-
signature_dict[key] = val
66-
67-
for j in range(1, iters + 1):
68-
IOField = dspy.OutputField if j == iters else dspy.InputField
69-
70-
signature_dict[f"Thought_{j}"] = IOField(
71-
prefix=f"Thought {j}:",
72-
desc="next steps to take based on last observation",
73-
)
74-
75-
tool_list = " or ".join(
76-
[
77-
f"{tool.name}[{tool.input_variable}]"
78-
for tool in self.tools.values()
79-
if tool.name != "Finish"
80-
],
81-
)
82-
signature_dict[f"Action_{j}"] = IOField(
83-
prefix=f"Action {j}:",
84-
desc=f"always either {tool_list} or, when done, Finish[<answer>], where <answer> is the answer to the question itself.",
85-
)
86-
87-
if j < iters:
88-
signature_dict[f"Observation_{j}"] = IOField(
89-
prefix=f"Observation {j}:",
90-
desc="observations based on action",
91-
format=dsp.passages2text,
92-
)
93-
94-
return signature_dict
95-
96-
def act(self, output, hop):
97-
try:
98-
action = output[f"Action_{hop+1}"]
99-
action_name, action_val = action.strip().split("\n")[0].split("[", 1)
100-
action_val = action_val.rsplit("]", 1)[0]
101-
102-
if action_name == "Finish":
103-
return action_val
104-
105-
result = self.tools[action_name](action_val) #result must be a str, list, or tuple
106-
# Handle the case where 'passages' attribute is missing
107-
output[f"Observation_{hop+1}"] = getattr(result, "passages", result)
108-
109-
except Exception:
110-
output[f"Observation_{hop+1}"] = (
111-
"Failed to parse action. Bad formatting or incorrect action name."
112-
)
113-
# raise e
114-
115-
def forward(self, **kwargs):
116-
args = {key: kwargs[key] for key in self.input_fields.keys() if key in kwargs}
117-
118-
for hop in range(self.max_iters):
119-
# with dspy.settings.context(show_guidelines=(i <= 2)):
120-
output = self.react[hop](**args)
121-
output[f'Action_{hop + 1}'] = output[f'Action_{hop + 1}'].split('\n')[0]
122-
123-
if action_val := self.act(output, hop):
124-
break
125-
args.update(output)
66+
fallback_signature = (
67+
dspy.Signature({**signature.input_fields, **signature.output_fields})
68+
.append("trajectory", dspy.InputField(), type_=str)
69+
)
12670

127-
observations = [args[key] for key in args if key.startswith("Observation")]
71+
self.tools = tools
72+
self.react = dspy.Predict(signature_)
73+
self.extract = dspy.ChainOfThought(fallback_signature)
74+
75+
def forward(self, **input_args):
76+
trajectory = {}
77+
78+
def format(trajectory_: dict[str, Any], last_iteration: bool):
79+
adapter = dspy.settings.adapter or dspy.ChatAdapter()
80+
blob = adapter.format_fields(dspy.Signature(f"{', '.join(trajectory_.keys())} -> x"), trajectory_)
81+
warning = f"\n\nWarning: The maximum number of iterations ({self.max_iters}) has been reached."
82+
warning += " You must now produce the finish action."
83+
return blob + (warning if last_iteration else "")
84+
85+
for idx in range(self.max_iters):
86+
pred = self.react(**input_args, trajectory=format(trajectory, last_iteration=(idx == self.max_iters-1)))
87+
88+
trajectory[f"thought_{idx}"] = pred.next_thought
89+
trajectory[f"tool_name_{idx}"] = pred.next_tool_name
90+
trajectory[f"tool_args_{idx}"] = pred.next_tool_args
91+
92+
try:
93+
trajectory[f"observation_{idx}"] = self.tools[pred.next_tool_name](**pred.next_tool_args)
94+
except Exception as e:
95+
trajectory[f"observation_{idx}"] = f"Failed to execute: {e}"
96+
97+
if pred.next_tool_name == "finish":
98+
break
12899

129-
# assumes only 1 output field for now - TODO: handling for multiple output fields
130-
return dspy.Prediction(observations=observations, **{list(self.output_fields.keys())[0]: action_val or ""})
100+
extract = self.extract(**input_args, trajectory=format(trajectory, last_iteration=False))
101+
return dspy.Prediction(trajectory=trajectory, **extract)

dspy/retrieve/retrieve.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,15 @@ def __call__(self, *args, **kwargs):
4545

4646
def forward(
4747
self,
48-
query_or_queries: Union[str, List[str]],
48+
query_or_queries: Union[str, List[str]] = None,
49+
query: Optional[str] = None,
4950
k: Optional[int] = None,
5051
by_prob: bool = True,
5152
with_metadata: bool = False,
5253
**kwargs,
5354
) -> Union[List[str], Prediction, List[Prediction]]:
55+
query_or_queries = query_or_queries or query
56+
5457
# queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries
5558
# queries = [query.strip().split('\n')[0].strip() for query in queries]
5659

0 commit comments

Comments
 (0)