|
16 | 16 |
|
17 | 17 |
|
18 | 18 | class MockStream: |
19 | | - """Mock OpenAI stream object.""" |
| 19 | + """Mock OpenAI stream object that emulates OpenAI streaming API. |
| 20 | +
|
| 21 | + This mock properly handles: |
| 22 | + - Context manager protocol (async with) |
| 23 | + - Stream iteration (async for event in stream) |
| 24 | + - Final completion retrieval with parsed_arguments support |
| 25 | + """ |
20 | 26 |
|
21 | 27 | def __init__(self, final_completion_data: dict): |
| 28 | + """Initialize mock stream with final completion data. |
| 29 | +
|
| 30 | + Args: |
| 31 | + final_completion_data: Dictionary containing: |
| 32 | + - content: Optional message content |
| 33 | + - tool_calls: List of tool call objects (already with parsed_arguments set) |
| 34 | + """ |
22 | 35 | self._final_completion_data = final_completion_data |
| 36 | + self._iterated = False |
23 | 37 |
|
24 | 38 | async def __aenter__(self): |
| 39 | + """Enter context manager.""" |
25 | 40 | return self |
26 | 41 |
|
27 | 42 | async def __aexit__(self, exc_type, exc_val, exc_tb): |
| 43 | + """Exit context manager.""" |
28 | 44 | pass |
29 | 45 |
|
30 | 46 | def __aiter__(self): |
| 47 | + """Return iterator for stream events.""" |
31 | 48 | return self |
32 | 49 |
|
33 | 50 | async def __anext__(self): |
| 51 | + """Return next stream event (empty iterator for simplicity). |
| 52 | +
|
| 53 | + In the real OpenAI API, this would yield chunk events. For testing, |
| 54 | + we return an empty iterator since the code handles missing chunks |
| 55 | + gracefully. |
| 56 | + """ |
| 57 | + if self._iterated: |
| 58 | + raise StopAsyncIteration |
| 59 | + self._iterated = True |
34 | 60 | raise StopAsyncIteration |
35 | 61 |
|
36 | 62 | async def get_final_completion(self) -> ChatCompletion: |
| 63 | + """Get final completion with parsed tool call arguments or structured |
| 64 | + output. |
| 65 | +
|
| 66 | + Supports both formats: |
| 67 | + - Structured output: message.parsed (for SGRAgent) |
| 68 | + - Function calling: tool_calls[0].function.parsed_arguments (for SGRToolCallingAgent) |
| 69 | +
|
| 70 | + Returns: |
| 71 | + ChatCompletion object with appropriate parsed data |
| 72 | + """ |
| 73 | + tool_calls = self._final_completion_data.get("tool_calls", []) |
| 74 | + |
37 | 75 | message = ChatCompletionMessage( |
38 | 76 | role="assistant", |
39 | 77 | content=self._final_completion_data.get("content"), |
40 | | - tool_calls=self._final_completion_data.get("tool_calls"), |
| 78 | + tool_calls=tool_calls if tool_calls else None, |
41 | 79 | ) |
| 80 | + |
| 81 | + # Support structured output format (SGRAgent uses message.parsed) |
42 | 82 | if "parsed" in self._final_completion_data: |
43 | 83 | setattr(message, "parsed", self._final_completion_data["parsed"]) |
44 | 84 |
|
@@ -155,6 +195,15 @@ def mock_stream(**kwargs): |
155 | 195 |
|
156 | 196 |
|
157 | 197 | def create_mock_openai_client_for_sgr_tool_calling_agent(action_tool_1: Type, action_tool_2: Type) -> AsyncOpenAI: |
| 198 | + """Create a mock OpenAI client for SGRToolCallingAgent tests. |
| 199 | +
|
| 200 | + Args: |
| 201 | + action_tool_1: First action tool to return (e.g., AdaptPlanTool) |
| 202 | + action_tool_2: Second action tool to return (e.g., FinalAnswerTool) |
| 203 | +
|
| 204 | + Returns: |
| 205 | + Mocked AsyncOpenAI client configured for SGRToolCallingAgent execution cycle |
| 206 | + """ |
158 | 207 | client = Mock(spec=AsyncOpenAI) |
159 | 208 |
|
160 | 209 | reasoning_tools = [ |
@@ -196,33 +245,39 @@ def create_mock_openai_client_for_sgr_tool_calling_agent(action_tool_1: Type, ac |
196 | 245 | action_count = {"count": 0} |
197 | 246 |
|
198 | 247 | def mock_stream(**kwargs): |
199 | | - is_reasoning = ( |
200 | | - "tool_choice" in kwargs |
201 | | - and isinstance(kwargs.get("tool_choice"), dict) |
202 | | - and kwargs["tool_choice"].get("function", {}).get("name") == ReasoningTool.tool_name |
203 | | - ) |
204 | | - |
205 | | - if is_reasoning: |
| 248 | + """Mock stream function that returns the appropriate tool based on |
| 249 | + the tool name.""" |
| 250 | + tools_param = kwargs.get("tools", []) |
| 251 | + |
| 252 | + # Validate that tools is a list |
| 253 | + if not isinstance(tools_param, list): |
| 254 | + raise TypeError( |
| 255 | + f"SGRToolCallingAgent._prepare_tools() must return a list, " |
| 256 | + f"but got {type(tools_param).__name__}. " |
| 257 | + f"Override _prepare_tools() to return list instead of NextStepToolStub." |
| 258 | + ) |
| 259 | + |
| 260 | + # Get tool name from first tool in the list |
| 261 | + tool_name = None |
| 262 | + if tools_param: |
| 263 | + first_tool = tools_param[0] |
| 264 | + if isinstance(first_tool, dict): |
| 265 | + tool_name = first_tool.get("function", {}).get("name") |
| 266 | + |
| 267 | + # Return appropriate tool based on name |
| 268 | + if tool_name == ReasoningTool.tool_name: |
206 | 269 | reasoning_count["count"] += 1 |
207 | 270 | tool = reasoning_tools[reasoning_count["count"] - 1] |
208 | | - call_id = f"reasoning_{reasoning_count['count']}" |
209 | 271 | else: |
210 | | - tools_param = kwargs.get("tools") |
211 | | - if tools_param is not None and not isinstance(tools_param, list): |
212 | | - raise TypeError( |
213 | | - f"SGRToolCallingAgent._prepare_tools() must return a list, " |
214 | | - f"but got {type(tools_param).__name__}. " |
215 | | - f"Override _prepare_tools() to return list instead of NextStepToolStub." |
216 | | - ) |
| 272 | + # Action tool - use counter to select from action_tools list |
217 | 273 | action_count["count"] += 1 |
218 | 274 | tool = action_tools[action_count["count"] - 1] |
219 | | - call_id = f"action_{action_count['count']}" |
220 | 275 |
|
| 276 | + # call_id is not used by agent, just needed for valid OpenAI API structure |
221 | 277 | return MockStream( |
222 | 278 | final_completion_data={ |
223 | 279 | "content": None, |
224 | | - "role": "assistant", |
225 | | - "tool_calls": [_create_tool_call(tool, call_id)], |
| 280 | + "tool_calls": [_create_tool_call(tool, "mock-call-id")], |
226 | 281 | } |
227 | 282 | ) |
228 | 283 |
|
|
0 commit comments