Skip to content

Commit fb723d1

Browse files
authored
Bugfix: SGRTool calling agent derives _prepare_tools() from SGRAgent (#132)
* Update sgr_tool_calling_agent.py * test added * Update __init__.py
1 parent 7c396d1 commit fb723d1

File tree

3 files changed

+318
-3
lines changed

3 files changed

+318
-3
lines changed

sgr_agent_core/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"""
66

77
# Version info
8-
__version__ = "0.5.0"
8+
__version__ = "0.5.1"
99
__author__ = "sgr-agent-core-team"
1010

1111
# Base classes (from direct .py files)

sgr_agent_core/agents/sgr_tool_calling_agent.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from openai import AsyncOpenAI, pydantic_function_tool
44

55
from sgr_agent_core.agent_config import AgentConfig
6-
from sgr_agent_core.agents.sgr_agent import SGRAgent
6+
from sgr_agent_core.base_agent import BaseAgent
77
from sgr_agent_core.models import AgentStatesEnum
88
from sgr_agent_core.tools import (
99
BaseTool,
@@ -12,7 +12,7 @@
1212
)
1313

1414

15-
class SGRToolCallingAgent(SGRAgent):
15+
class SGRToolCallingAgent(BaseAgent):
1616
"""Agent that uses OpenAI native function calling to select and execute
1717
tools based on SGR like a reasoning scheme."""
1818

@@ -122,3 +122,12 @@ async def _select_action_phase(self, reasoning: ReasoningTool) -> BaseTool:
122122
f"{self._context.iteration}-action", tool.tool_name, tool.model_dump_json()
123123
)
124124
return tool
125+
126+
async def _action_phase(self, tool: BaseTool) -> str:
    """Execute the chosen tool and record its output.

    Runs *tool* against the current context and config, appends the result
    to the conversation as a ``tool`` message keyed by the current
    iteration, streams it to the client, logs the execution, and returns
    the raw result string.
    """
    outcome = await tool(self._context, self.config)
    tool_message = {
        "role": "tool",
        "content": outcome,
        "tool_call_id": f"{self._context.iteration}-action",
    }
    self.conversation.append(tool_message)
    self.streaming_generator.add_chunk_from_str(f"{outcome}\n")
    self._log_tool_execution(tool, outcome)
    return outcome

tests/test_agent_e2e.py

Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
"""End-to-end tests for agent execution workflow."""
2+
3+
from typing import Type
4+
from unittest.mock import Mock
5+
6+
import pytest
7+
from openai import AsyncOpenAI
8+
from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall
9+
from openai.types.chat.chat_completion import Choice
10+
11+
from sgr_agent_core.agent_definition import AgentConfig, ExecutionConfig, LLMConfig, PromptsConfig
12+
from sgr_agent_core.agents import SGRAgent, SGRToolCallingAgent, ToolCallingAgent
13+
from sgr_agent_core.models import AgentStatesEnum
14+
from sgr_agent_core.next_step_tool import NextStepToolsBuilder
15+
from sgr_agent_core.tools import AdaptPlanTool, FinalAnswerTool, ReasoningTool
16+
17+
18+
class MockStream:
    """Minimal stand-in for the OpenAI SDK's streaming response object.

    Implements the async context-manager and async-iterator protocols
    (yielding no chunks) and exposes ``get_final_completion`` assembled
    from the data supplied at construction time.
    """

    def __init__(self, final_completion_data: dict):
        self._data = final_completion_data

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        return None

    def __aiter__(self):
        return self

    async def __anext__(self):
        # Empty stream: iteration ends immediately.
        raise StopAsyncIteration

    async def get_final_completion(self) -> ChatCompletion:
        """Build a ChatCompletion from the stored completion data."""
        message = ChatCompletionMessage(
            role="assistant",
            content=self._data.get("content"),
            tool_calls=self._data.get("tool_calls"),
        )
        if "parsed" in self._data:
            # `parsed` is not a declared model field; attach it dynamically,
            # mirroring how the SDK exposes structured-output results.
            setattr(message, "parsed", self._data["parsed"])

        return ChatCompletion(
            id="test-completion-id",
            choices=[Choice(index=0, message=message, finish_reason="stop")],
            created=1234567890,
            model="gpt-4o-mini",
            object="chat.completion",
        )
52+
53+
54+
def _create_tool_call(tool: Type, call_id: str) -> ChatCompletionMessageToolCall:
    """Build a mocked tool-call object shaped like the OpenAI SDK's.

    The returned mock carries the tool's name and the tool instance itself
    as ``parsed_arguments``, which is what the agents consume.
    """
    fn = Mock()
    fn.name = tool.tool_name
    fn.parsed_arguments = tool

    call = Mock(spec=ChatCompletionMessageToolCall)
    call.id = call_id
    call.type = "function"
    call.function = fn
    return call
62+
63+
64+
def _create_next_step_tool_response(tool_class: Type, tool_data: dict, reasoning_data: dict) -> Type:
    """Construct a NextStepTools response wrapping *tool_class* and its payload."""
    next_step_model = NextStepToolsBuilder.build_NextStepTools([tool_class])
    # The discriminator field selects the concrete tool inside the union.
    function_payload = {**tool_data, "tool_name_discriminator": tool_class.tool_name}
    return next_step_model(function=function_payload, **reasoning_data)
69+
70+
71+
def create_mock_openai_client_for_sgr_agent(action_tool_1: Type, action_tool_2: Type) -> AsyncOpenAI:
    """Mock OpenAI client for SGRAgent structured-output flow.

    The first ``stream`` call yields a parsed *action_tool_1* (adapt-plan)
    decision; every subsequent call yields a parsed *action_tool_2*
    (final-answer) decision.
    """
    client = Mock(spec=AsyncOpenAI)

    adapt_plan_response = _create_next_step_tool_response(
        action_tool_1,
        {
            "reasoning": "Plan needs to be adapted based on initial findings",
            "original_goal": "Research task",
            "new_goal": "Updated research goal",
            "plan_changes": ["Change 1", "Change 2"],
            "next_steps": ["Step 1", "Step 2", "Step 3"],
        },
        {
            "reasoning_steps": ["Step 1: Analyze task", "Step 2: Plan adaptation"],
            "current_situation": "Initial research phase",
            "plan_status": "Plan needs adaptation",
            "enough_data": False,
            "remaining_steps": ["Adapt plan", "Continue research"],
            "task_completed": False,
        },
    )

    final_answer_response = _create_next_step_tool_response(
        action_tool_2,
        {
            "reasoning": "Task completed successfully",
            "completed_steps": ["Step 1", "Step 2"],
            "answer": "Final answer to the research task",
            "status": AgentStatesEnum.COMPLETED,
        },
        {
            "reasoning_steps": ["Step 1: Complete research", "Step 2: Finalize answer"],
            "current_situation": "Research completed",
            "plan_status": "All steps completed",
            "enough_data": True,
            "remaining_steps": ["Finalize"],
            "task_completed": True,
        },
    )

    state = {"calls": 0}

    def mock_stream(**kwargs):
        state["calls"] += 1
        parsed = adapt_plan_response if state["calls"] == 1 else final_answer_response
        return MockStream(final_completion_data={"parsed": parsed})

    client.chat.completions.stream = Mock(side_effect=mock_stream)
    return client
120+
121+
122+
def create_mock_openai_client_for_tool_calling_agent(action_tool_1: Type, action_tool_2: Type) -> AsyncOpenAI:
    """Mock OpenAI client for ToolCallingAgent native function-calling flow.

    The first ``stream`` call emits an *action_tool_1* (adapt-plan) tool
    call; every subsequent call emits an *action_tool_2* (final-answer)
    tool call.
    """
    client = Mock(spec=AsyncOpenAI)

    adapt_call = action_tool_1(
        reasoning="Plan needs to be adapted",
        original_goal="Research task",
        new_goal="Updated research goal",
        plan_changes=["Change 1", "Change 2"],
        next_steps=["Step 1", "Step 2", "Step 3"],
    )

    finish_call = action_tool_2(
        reasoning="Task completed successfully",
        completed_steps=["Step 1", "Step 2"],
        answer="Final answer to the research task",
        status=AgentStatesEnum.COMPLETED,
    )

    state = {"calls": 0}

    def mock_stream(**kwargs):
        state["calls"] += 1
        chosen = adapt_call if state["calls"] == 1 else finish_call
        tool_calls = [_create_tool_call(chosen, f"call_{state['calls']}")]
        return MockStream(
            final_completion_data={
                "content": None,
                "role": "assistant",
                "tool_calls": tool_calls,
            }
        )

    client.chat.completions.stream = Mock(side_effect=mock_stream)
    return client
155+
156+
157+
def create_mock_openai_client_for_sgr_tool_calling_agent(action_tool_1: Type, action_tool_2: Type) -> AsyncOpenAI:
    """Mock OpenAI client for the two-phase SGRToolCallingAgent flow.

    Reasoning-phase calls (forced ``tool_choice`` on ReasoningTool) are
    answered from ``reasoning_tools`` in order; action-phase calls are
    answered from ``action_tools`` in order. Action-phase requests also
    validate that the ``tools`` parameter is a list, which is exactly what
    breaks if ``_prepare_tools()`` is inherited from SGRAgent.
    """
    client = Mock(spec=AsyncOpenAI)

    reasoning_tools = [
        ReasoningTool(
            reasoning_steps=["Step 1: Analyze", "Step 2: Plan"],
            current_situation="Initial research phase",
            plan_status="Plan needs adaptation",
            enough_data=False,
            remaining_steps=["Adapt plan", "Continue"],
            task_completed=False,
        ),
        ReasoningTool(
            reasoning_steps=["Step 1: Complete", "Step 2: Finalize"],
            current_situation="Research completed",
            plan_status="All steps completed",
            enough_data=True,
            remaining_steps=["Finalize"],
            task_completed=True,
        ),
    ]

    action_tools = [
        action_tool_1(
            reasoning="Plan needs to be adapted",
            original_goal="Research task",
            new_goal="Updated research goal",
            plan_changes=["Change 1", "Change 2"],
            next_steps=["Step 1", "Step 2", "Step 3"],
        ),
        action_tool_2(
            reasoning="Task completed successfully",
            completed_steps=["Step 1", "Step 2"],
            answer="Final answer to the research task",
            status=AgentStatesEnum.COMPLETED,
        ),
    ]

    counters = {"reasoning": 0, "action": 0}

    def mock_stream(**kwargs):
        tool_choice = kwargs.get("tool_choice")
        wants_reasoning = (
            isinstance(tool_choice, dict)
            and tool_choice.get("function", {}).get("name") == ReasoningTool.tool_name
        )

        if wants_reasoning:
            counters["reasoning"] += 1
            idx = counters["reasoning"]
            selected = reasoning_tools[idx - 1]
            call_id = f"reasoning_{idx}"
        else:
            tools_param = kwargs.get("tools")
            if tools_param is not None and not isinstance(tools_param, list):
                raise TypeError(
                    f"SGRToolCallingAgent._prepare_tools() must return a list, "
                    f"but got {type(tools_param).__name__}. "
                    f"Override _prepare_tools() to return list instead of NextStepToolStub."
                )
            counters["action"] += 1
            idx = counters["action"]
            selected = action_tools[idx - 1]
            call_id = f"action_{idx}"

        return MockStream(
            final_completion_data={
                "content": None,
                "role": "assistant",
                "tool_calls": [_create_tool_call(selected, call_id)],
            }
        )

    client.chat.completions.stream = Mock(side_effect=mock_stream)
    return client
231+
232+
233+
def _create_test_agent_config() -> AgentConfig:
    """Build a minimal AgentConfig for mocked-client test runs."""
    llm = LLMConfig(api_key="test-key", base_url="https://api.openai.com/v1", model="gpt-4o-mini")
    prompts = PromptsConfig(
        system_prompt_str="Test system prompt",
        initial_user_request_str="Test initial request",
        clarification_response_str="Test clarification response",
    )
    execution = ExecutionConfig(max_iterations=10, max_clarifications=3, max_searches=5)
    return AgentConfig(llm=llm, prompts=prompts, execution=execution)
243+
244+
245+
def _assert_agent_completed(agent, expected_result: str = "Final answer to the research task"):
    """Shared post-conditions for a successfully finished agent run."""
    ctx = agent._context
    assert ctx.state == AgentStatesEnum.COMPLETED
    assert ctx.execution_result == expected_result
    # The mocked flows always take at least two iterations (plan + answer).
    assert ctx.iteration >= 2
    assert agent.conversation
    assert agent.log
251+
252+
253+
@pytest.mark.asyncio
async def test_sgr_agent_full_execution_cycle():
    """SGRAgent completes a two-step cycle: adapt plan, then final answer."""
    mock_client = create_mock_openai_client_for_sgr_agent(AdaptPlanTool, FinalAnswerTool)
    agent = SGRAgent(
        task="Test research task",
        openai_client=mock_client,
        agent_config=_create_test_agent_config(),
        toolkit=[FinalAnswerTool, AdaptPlanTool],
    )

    # Freshly constructed agent: no iterations yet.
    assert agent._context.state == AgentStatesEnum.INITED
    assert agent._context.iteration == 0

    result = await agent.execute()

    assert result is not None
    _assert_agent_completed(agent)
269+
270+
271+
@pytest.mark.asyncio
async def test_tool_calling_agent_full_execution_cycle():
    """ToolCallingAgent completes a two-step native function-calling cycle."""
    mock_client = create_mock_openai_client_for_tool_calling_agent(AdaptPlanTool, FinalAnswerTool)
    agent = ToolCallingAgent(
        task="Test research task",
        openai_client=mock_client,
        agent_config=_create_test_agent_config(),
        toolkit=[FinalAnswerTool, AdaptPlanTool],
    )

    # Freshly constructed agent: no iterations yet.
    assert agent._context.state == AgentStatesEnum.INITED
    assert agent._context.iteration == 0

    result = await agent.execute()

    assert result is not None
    _assert_agent_completed(agent)
287+
288+
289+
@pytest.mark.asyncio
async def test_sgr_tool_calling_agent_full_execution_cycle():
    """Validates that SGRToolCallingAgent overrides _prepare_tools()
    correctly."""
    mock_client = create_mock_openai_client_for_sgr_tool_calling_agent(AdaptPlanTool, FinalAnswerTool)
    agent = SGRToolCallingAgent(
        task="Test research task",
        openai_client=mock_client,
        agent_config=_create_test_agent_config(),
        toolkit=[FinalAnswerTool, AdaptPlanTool],
    )

    # Freshly constructed agent: no iterations yet.
    assert agent._context.state == AgentStatesEnum.INITED
    assert agent._context.iteration == 0

    result = await agent.execute()

    assert result is not None
    _assert_agent_completed(agent)

0 commit comments

Comments
 (0)