Skip to content

Commit 9932d7e

Browse files
Merge pull request #171 from vamplabAI/some-obvious-thing
replaced less supported tool choice arg with more common one
2 parents 39f6eb3 + 0fe39ff commit 9932d7e

File tree

2 files changed

+77
-22
lines changed

2 files changed

+77
-22
lines changed

sgr_agent_core/agents/sgr_tool_calling_agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ async def _reasoning_phase(self) -> ReasoningTool:
4141
async with self.openai_client.chat.completions.stream(
4242
messages=await self._prepare_context(),
4343
tools=[pydantic_function_tool(ReasoningTool, name=ReasoningTool.tool_name)],
44-
tool_choice={"type": "function", "function": {"name": ReasoningTool.tool_name}},
44+
tool_choice=self.tool_choice,
4545
**self.config.llm.to_openai_client_kwargs(),
4646
) as stream:
4747
async for event in stream:
@@ -66,7 +66,7 @@ async def _reasoning_phase(self) -> ReasoningTool:
6666
],
6767
}
6868
)
69-
tool_call_result = await reasoning(self._context)
69+
tool_call_result = await reasoning(self._context, self.config)
7070
self.streaming_generator.add_tool_call(
7171
f"{self._context.iteration}-reasoning", reasoning.tool_name, tool_call_result
7272
)

tests/test_agent_e2e.py

Lines changed: 75 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,29 +16,69 @@
1616

1717

1818
class MockStream:
19-
"""Mock OpenAI stream object."""
19+
"""Mock OpenAI stream object that emulates OpenAI streaming API.
20+
21+
This mock properly handles:
22+
- Context manager protocol (async with)
23+
- Stream iteration (async for event in stream)
24+
- Final completion retrieval with parsed_arguments support
25+
"""
2026

2127
def __init__(self, final_completion_data: dict):
28+
"""Initialize mock stream with final completion data.
29+
30+
Args:
31+
final_completion_data: Dictionary containing:
32+
- content: Optional message content
33+
- tool_calls: List of tool call objects (already with parsed_arguments set)
34+
"""
2235
self._final_completion_data = final_completion_data
36+
self._iterated = False
2337

2438
async def __aenter__(self):
39+
"""Enter context manager."""
2540
return self
2641

2742
async def __aexit__(self, exc_type, exc_val, exc_tb):
43+
"""Exit context manager."""
2844
pass
2945

3046
def __aiter__(self):
47+
"""Return iterator for stream events."""
3148
return self
3249

3350
async def __anext__(self):
51+
"""Return next stream event (empty iterator for simplicity).
52+
53+
In real OpenAI API, this would yield chunk events. For testing,
54+
we return empty iterator since the code handles missing chunks
55+
gracefully.
56+
"""
57+
if self._iterated:
58+
raise StopAsyncIteration
59+
self._iterated = True
3460
raise StopAsyncIteration
3561

3662
async def get_final_completion(self) -> ChatCompletion:
63+
"""Get final completion with parsed tool call arguments or structured
64+
output.
65+
66+
Supports both formats:
67+
- Structured output: message.parsed (for SGRAgent)
68+
- Function calling: tool_calls[0].function.parsed_arguments (for SGRToolCallingAgent)
69+
70+
Returns:
71+
ChatCompletion object with appropriate parsed data
72+
"""
73+
tool_calls = self._final_completion_data.get("tool_calls", [])
74+
3775
message = ChatCompletionMessage(
3876
role="assistant",
3977
content=self._final_completion_data.get("content"),
40-
tool_calls=self._final_completion_data.get("tool_calls"),
78+
tool_calls=tool_calls if tool_calls else None,
4179
)
80+
81+
# Support structured output format (SGRAgent uses message.parsed)
4282
if "parsed" in self._final_completion_data:
4383
setattr(message, "parsed", self._final_completion_data["parsed"])
4484

@@ -155,6 +195,15 @@ def mock_stream(**kwargs):
155195

156196

157197
def create_mock_openai_client_for_sgr_tool_calling_agent(action_tool_1: Type, action_tool_2: Type) -> AsyncOpenAI:
198+
"""Create a mock OpenAI client for SGRToolCallingAgent tests.
199+
200+
Args:
201+
action_tool_1: First action tool to return (e.g., AdaptPlanTool)
202+
action_tool_2: Second action tool to return (e.g., FinalAnswerTool)
203+
204+
Returns:
205+
Mocked AsyncOpenAI client configured for SGRToolCallingAgent execution cycle
206+
"""
158207
client = Mock(spec=AsyncOpenAI)
159208

160209
reasoning_tools = [
@@ -196,33 +245,39 @@ def create_mock_openai_client_for_sgr_tool_calling_agent(action_tool_1: Type, ac
196245
action_count = {"count": 0}
197246

198247
def mock_stream(**kwargs):
199-
is_reasoning = (
200-
"tool_choice" in kwargs
201-
and isinstance(kwargs.get("tool_choice"), dict)
202-
and kwargs["tool_choice"].get("function", {}).get("name") == ReasoningTool.tool_name
203-
)
204-
205-
if is_reasoning:
248+
"""Mock stream function that returns appropriate tool based on tool
249+
name."""
250+
tools_param = kwargs.get("tools", [])
251+
252+
# Validate that tools is a list
253+
if not isinstance(tools_param, list):
254+
raise TypeError(
255+
f"SGRToolCallingAgent._prepare_tools() must return a list, "
256+
f"but got {type(tools_param).__name__}. "
257+
f"Override _prepare_tools() to return list instead of NextStepToolStub."
258+
)
259+
260+
# Get tool name from first tool in the list
261+
tool_name = None
262+
if tools_param:
263+
first_tool = tools_param[0]
264+
if isinstance(first_tool, dict):
265+
tool_name = first_tool.get("function", {}).get("name")
266+
267+
# Return appropriate tool based on name
268+
if tool_name == ReasoningTool.tool_name:
206269
reasoning_count["count"] += 1
207270
tool = reasoning_tools[reasoning_count["count"] - 1]
208-
call_id = f"reasoning_{reasoning_count['count']}"
209271
else:
210-
tools_param = kwargs.get("tools")
211-
if tools_param is not None and not isinstance(tools_param, list):
212-
raise TypeError(
213-
f"SGRToolCallingAgent._prepare_tools() must return a list, "
214-
f"but got {type(tools_param).__name__}. "
215-
f"Override _prepare_tools() to return list instead of NextStepToolStub."
216-
)
272+
# Action tool - use counter to select from action_tools list
217273
action_count["count"] += 1
218274
tool = action_tools[action_count["count"] - 1]
219-
call_id = f"action_{action_count['count']}"
220275

276+
# call_id is not used by agent, just needed for valid OpenAI API structure
221277
return MockStream(
222278
final_completion_data={
223279
"content": None,
224-
"role": "assistant",
225-
"tool_calls": [_create_tool_call(tool, call_id)],
280+
"tool_calls": [_create_tool_call(tool, "mock-call-id")],
226281
}
227282
)
228283

0 commit comments

Comments
 (0)