|
1 | | -import dsp |
2 | 1 | import dspy |
3 | | -from dspy.signatures.signature import ensure_signature |
4 | | - |
5 | | -from ..primitives.program import Module |
6 | | -from .predict import Predict |
| 2 | +import inspect |
7 | 3 |
|
8 | | -# TODO: Simplify a lot. |
9 | | -# TODO: Divide Action and Action Input like langchain does for ReAct. |
| 4 | +from pydantic import BaseModel |
| 5 | +from dspy.primitives.program import Module |
| 6 | +from dspy.signatures.signature import ensure_signature |
| 7 | +from dspy.adapters.json_adapter import get_annotation_name |
| 8 | +from typing import Callable, Any, get_type_hints, get_origin, Literal |
10 | 9 |
|
11 | | -# TODO: There's a lot of value in having a stopping condition in the LM calls at `\n\nObservation:` |
| 10 | +class Tool: |
| 11 | + def __init__(self, func: Callable, name: str = None, desc: str = None, args: dict[str, Any] = None): |
| 12 | + annotations_func = func if inspect.isfunction(func) else func.__call__ |
| 13 | + self.func = func |
| 14 | + self.name = name or getattr(func, '__name__', type(func).__name__) |
| 15 | + self.desc = desc or getattr(func, '__doc__', None) or getattr(annotations_func, '__doc__', "No description") |
| 16 | + self.args = { |
| 17 | + k: v.schema() if isinstance((origin := get_origin(v) or v), type) and issubclass(origin, BaseModel) |
| 18 | + else get_annotation_name(v) |
| 19 | + for k, v in (args or get_type_hints(annotations_func)).items() if k != 'return' |
| 20 | + } |
12 | 21 |
|
13 | | -# TODO [NEW]: When max_iters is about to be reached, reduce the set of available actions to only the Finish action. |
| 22 | + def __call__(self, *args, **kwargs): |
| 23 | + return self.func(*args, **kwargs) |
14 | 24 |
|
15 | 25 |
|
16 | 26 | class ReAct(Module): |
17 | | - def __init__(self, signature, max_iters=5, num_results=3, tools=None): |
18 | | - super().__init__() |
| 27 | + def __init__(self, signature, tools: list[Callable], max_iters=5): |
| 28 | + """ |
| 29 | + Tools is either a list of functions, callable classes, or dspy.Tool instances. |
| 30 | + """ |
| 31 | + |
19 | 32 | self.signature = signature = ensure_signature(signature) |
20 | 33 | self.max_iters = max_iters |
21 | 34 |
|
22 | | - self.tools = tools or [dspy.Retrieve(k=num_results)] |
23 | | - self.tools = {tool.name: tool for tool in self.tools} |
24 | | - |
25 | | - self.input_fields = self.signature.input_fields |
26 | | - self.output_fields = self.signature.output_fields |
27 | | - |
28 | | - assert len(self.output_fields) == 1, "ReAct only supports one output field." |
| 35 | + tools = [t if isinstance(t, Tool) or hasattr(t, 'input_variable') else Tool(t) for t in tools] |
| 36 | + tools = {tool.name: tool for tool in tools} |
29 | 37 |
|
30 | | - inputs_ = ", ".join([f"`{k}`" for k in self.input_fields.keys()]) |
31 | | - outputs_ = ", ".join([f"`{k}`" for k in self.output_fields.keys()]) |
| 38 | + inputs_ = ", ".join([f"`{k}`" for k in signature.input_fields.keys()]) |
| 39 | + outputs_ = ", ".join([f"`{k}`" for k in signature.output_fields.keys()]) |
| 40 | + instr = [f"{signature.instructions}\n"] if signature.instructions else [] |
32 | 41 |
|
33 | | - instr = [] |
34 | | - |
35 | | - if self.signature.instructions is not None: |
36 | | - instr.append(f"{self.signature.instructions}\n") |
37 | | - |
38 | 42 | instr.extend([ |
39 | | - f"You will be given {inputs_} and you will respond with {outputs_}.\n", |
40 | | - "To do this, you will interleave Thought, Action, and Observation steps.\n", |
41 | | - "Thought can reason about the current situation, and Action can be the following types:\n", |
| 43 | + f"You will be given {inputs_} and your goal is to finish with {outputs_}.\n", |
| 44 | + "To do this, you will interleave Thought, Tool Name, and Tool Args, and receive a resulting Observation.\n", |
| 45 | + "Thought can reason about the current situation, and Tool Name can be the following types:\n", |
42 | 46 | ]) |
43 | 47 |
|
44 | | - self.tools["Finish"] = dspy.Example( |
45 | | - name="Finish", |
46 | | - input_variable=outputs_.strip("`"), |
47 | | - desc=f"returns the final {outputs_} and finishes the task", |
| 48 | + finish_desc = f"Signals that the final outputs, i.e. {outputs_}, are now available and marks the task as complete." |
| 49 | + finish_args = {} #k: v.annotation for k, v in signature.output_fields.items()} |
| 50 | + tools["finish"] = Tool(func=lambda **kwargs: kwargs, name="finish", desc=finish_desc, args=finish_args) |
| 51 | + |
| 52 | + for idx, tool in enumerate(tools.values()): |
| 53 | + desc = tool.desc.replace("\n", " ") |
| 54 | + args = tool.args if hasattr(tool, 'args') else str({tool.input_variable: str}) |
| 55 | + desc = f"whose description is <desc>{desc}</desc>. It takes arguments {args} in JSON format." |
| 56 | + instr.append(f"({idx+1}) {tool.name}, {desc}") |
| 57 | + |
| 58 | + signature_ = ( |
| 59 | + dspy.Signature({**signature.input_fields}, "\n".join(instr)) |
| 60 | + .append("trajectory", dspy.InputField(), type_=str) |
| 61 | + .append("next_thought", dspy.OutputField(), type_=str) |
| 62 | + .append("next_tool_name", dspy.OutputField(), type_=Literal[tuple(tools.keys())]) |
| 63 | + .append("next_tool_args", dspy.OutputField(), type_=dict[str, Any]) |
48 | 64 | ) |
49 | 65 |
|
50 | | - for idx, tool in enumerate(self.tools): |
51 | | - tool = self.tools[tool] |
52 | | - instr.append( |
53 | | - f"({idx+1}) {tool.name}[{tool.input_variable}], which {tool.desc}", |
54 | | - ) |
55 | | - |
56 | | - instr = "\n".join(instr) |
57 | | - self.react = [ |
58 | | - Predict(dspy.Signature(self._generate_signature(i), instr)) |
59 | | - for i in range(1, max_iters + 1) |
60 | | - ] |
61 | | - |
62 | | - def _generate_signature(self, iters): |
63 | | - signature_dict = {} |
64 | | - for key, val in self.input_fields.items(): |
65 | | - signature_dict[key] = val |
66 | | - |
67 | | - for j in range(1, iters + 1): |
68 | | - IOField = dspy.OutputField if j == iters else dspy.InputField |
69 | | - |
70 | | - signature_dict[f"Thought_{j}"] = IOField( |
71 | | - prefix=f"Thought {j}:", |
72 | | - desc="next steps to take based on last observation", |
73 | | - ) |
74 | | - |
75 | | - tool_list = " or ".join( |
76 | | - [ |
77 | | - f"{tool.name}[{tool.input_variable}]" |
78 | | - for tool in self.tools.values() |
79 | | - if tool.name != "Finish" |
80 | | - ], |
81 | | - ) |
82 | | - signature_dict[f"Action_{j}"] = IOField( |
83 | | - prefix=f"Action {j}:", |
84 | | - desc=f"always either {tool_list} or, when done, Finish[<answer>], where <answer> is the answer to the question itself.", |
85 | | - ) |
86 | | - |
87 | | - if j < iters: |
88 | | - signature_dict[f"Observation_{j}"] = IOField( |
89 | | - prefix=f"Observation {j}:", |
90 | | - desc="observations based on action", |
91 | | - format=dsp.passages2text, |
92 | | - ) |
93 | | - |
94 | | - return signature_dict |
95 | | - |
96 | | - def act(self, output, hop): |
97 | | - try: |
98 | | - action = output[f"Action_{hop+1}"] |
99 | | - action_name, action_val = action.strip().split("\n")[0].split("[", 1) |
100 | | - action_val = action_val.rsplit("]", 1)[0] |
101 | | - |
102 | | - if action_name == "Finish": |
103 | | - return action_val |
104 | | - |
105 | | - result = self.tools[action_name](action_val) #result must be a str, list, or tuple |
106 | | - # Handle the case where 'passages' attribute is missing |
107 | | - output[f"Observation_{hop+1}"] = getattr(result, "passages", result) |
108 | | - |
109 | | - except Exception: |
110 | | - output[f"Observation_{hop+1}"] = ( |
111 | | - "Failed to parse action. Bad formatting or incorrect action name." |
112 | | - ) |
113 | | - # raise e |
114 | | - |
115 | | - def forward(self, **kwargs): |
116 | | - args = {key: kwargs[key] for key in self.input_fields.keys() if key in kwargs} |
117 | | - |
118 | | - for hop in range(self.max_iters): |
119 | | - # with dspy.settings.context(show_guidelines=(i <= 2)): |
120 | | - output = self.react[hop](**args) |
121 | | - output[f'Action_{hop + 1}'] = output[f'Action_{hop + 1}'].split('\n')[0] |
122 | | - |
123 | | - if action_val := self.act(output, hop): |
124 | | - break |
125 | | - args.update(output) |
| 66 | + fallback_signature = ( |
| 67 | + dspy.Signature({**signature.input_fields, **signature.output_fields}) |
| 68 | + .append("trajectory", dspy.InputField(), type_=str) |
| 69 | + ) |
126 | 70 |
|
127 | | - observations = [args[key] for key in args if key.startswith("Observation")] |
| 71 | + self.tools = tools |
| 72 | + self.react = dspy.Predict(signature_) |
| 73 | + self.extract = dspy.ChainOfThought(fallback_signature) |
| 74 | + |
| 75 | + def forward(self, **input_args): |
| 76 | + trajectory = {} |
| 77 | + |
| 78 | + def format(trajectory_: dict[str, Any], last_iteration: bool): |
| 79 | + adapter = dspy.settings.adapter or dspy.ChatAdapter() |
| 80 | + blob = adapter.format_fields(dspy.Signature(f"{', '.join(trajectory_.keys())} -> x"), trajectory_) |
| 81 | + warning = f"\n\nWarning: The maximum number of iterations ({self.max_iters}) has been reached." |
| 82 | + warning += " You must now produce the finish action." |
| 83 | + return blob + (warning if last_iteration else "") |
| 84 | + |
| 85 | + for idx in range(self.max_iters): |
| 86 | + pred = self.react(**input_args, trajectory=format(trajectory, last_iteration=(idx == self.max_iters-1))) |
| 87 | + |
| 88 | + trajectory[f"thought_{idx}"] = pred.next_thought |
| 89 | + trajectory[f"tool_name_{idx}"] = pred.next_tool_name |
| 90 | + trajectory[f"tool_args_{idx}"] = pred.next_tool_args |
| 91 | + |
| 92 | + try: |
| 93 | + trajectory[f"observation_{idx}"] = self.tools[pred.next_tool_name](**pred.next_tool_args) |
| 94 | + except Exception as e: |
| 95 | + trajectory[f"observation_{idx}"] = f"Failed to execute: {e}" |
| 96 | + |
| 97 | + if pred.next_tool_name == "finish": |
| 98 | + break |
128 | 99 |
|
129 | | - # assumes only 1 output field for now - TODO: handling for multiple output fields |
130 | | - return dspy.Prediction(observations=observations, **{list(self.output_fields.keys())[0]: action_val or ""}) |
| 100 | + extract = self.extract(**input_args, trajectory=format(trajectory, last_iteration=False)) |
| 101 | + return dspy.Prediction(trajectory=trajectory, **extract) |
0 commit comments