diff --git a/sgr_agent_core/agents/__init__.py b/sgr_agent_core/agents/__init__.py index ef00bc1b..3c8e9bfd 100644 --- a/sgr_agent_core/agents/__init__.py +++ b/sgr_agent_core/agents/__init__.py @@ -1,10 +1,12 @@ """Agents module for SGR Agent Core.""" +from sgr_agent_core.agents.iron_agent import IronAgent from sgr_agent_core.agents.sgr_agent import SGRAgent from sgr_agent_core.agents.sgr_tool_calling_agent import SGRToolCallingAgent from sgr_agent_core.agents.tool_calling_agent import ToolCallingAgent __all__ = [ + "IronAgent", "SGRAgent", "SGRToolCallingAgent", "ToolCallingAgent", diff --git a/sgr_agent_core/agents/iron_agent.py b/sgr_agent_core/agents/iron_agent.py new file mode 100644 index 00000000..4270ec1f --- /dev/null +++ b/sgr_agent_core/agents/iron_agent.py @@ -0,0 +1,229 @@ +"""Fuzzy Agent that doesn't rely on tool calling or structured output.""" + +from datetime import datetime +from typing import Type + +from openai import AsyncOpenAI +from pydantic import BaseModel + +from sgr_agent_core.agent_config import AgentConfig +from sgr_agent_core.base_agent import BaseAgent +from sgr_agent_core.next_step_tool import NextStepToolsBuilder +from sgr_agent_core.services.registry import ToolRegistry +from sgr_agent_core.services.tool_instantiator import ToolInstantiator +from sgr_agent_core.tools import BaseTool, ReasoningTool, ToolNameSelectorStub + + +class IronAgent(BaseAgent): + """Agent that uses flexible parsing of LLM text responses instead of tool + calling. + + This agent doesn't rely on: + - Tool calling (function calling) + - Structured output (response_format) + + Instead, it parses natural language responses from LLM to determine + which tool to use and with what parameters using ToolInstantiator. + """ + + name: str = "iron_agent" + + def __init__( + self, + task_messages: list, + openai_client: AsyncOpenAI, + agent_config: AgentConfig, + toolkit: list[Type[BaseTool]], + def_name: str | None = None, + **kwargs: dict, + ): + super().__init__( + task_messages=task_messages, + openai_client=openai_client, + agent_config=agent_config, + toolkit=toolkit, + def_name=def_name, + **kwargs, + ) + + def _log_tool_instantiator( + self, + instantiator: ToolInstantiator, + attempt: int, + max_retries: int, + ): + """Log tool generation attempt by LLM using data from ToolInstantiator. + + Args: + instantiator: ToolInstantiator instance with attempt data + attempt: Current attempt number (1-based) + max_retries: Maximum number of retry attempts + """ + success = instantiator.instance is not None + errors_formatted = instantiator.input_content + "\n" + "\n".join(instantiator.errors) + self.logger.info( + f""" +############################################### +TOOL GENERATION DEBUG + {"✅" if success else "❌"} ATTEMPT {attempt}/{max_retries} {"SUCCESS" if success else "FAILED"}: + + Tool: {instantiator.tool_class.tool_name} - Class: {instantiator.tool_class.__name__} + +{errors_formatted if not success else ""} +###############################################""" + ) + + self.log.append( + { + "step_number": self._context.iteration, + "timestamp": datetime.now().isoformat(), + "step_type": "tool_generation_attempt", + "tool_class": instantiator.tool_class.__name__, + "attempt": attempt, + "max_retries": max_retries, + "success": success, + "llm_content": instantiator.input_content, + "errors": instantiator.errors.copy(), + } + ) + + async def _generate_tool( + self, + tool_class: Type[BaseTool], + messages: list[dict], + max_retries: int = 5, + ) -> BaseTool | BaseModel: + """Generate tool instance from LLM response using ToolInstantiator. + + Universal method for calling LLM with parsing through ToolInstantiator. + Handles retries with error accumulation. + + Args: + tool_class: Tool class or model class to instantiate + messages: Context messages for LLM + max_retries: Maximum number of retry attempts + + Returns: + Instance of tool_class + + Raises: + ValueError: If parsing fails after max_retries attempts + """ + instantiator = ToolInstantiator(tool_class) + + for attempt in range(max_retries): + async with self.openai_client.chat.completions.stream( + messages=messages + [{"role": "user", "content": instantiator.generate_format_prompt()}], + **self.config.llm.to_openai_client_kwargs(), + ) as stream: + async for event in stream: + if event.type == "chunk": + self.streaming_generator.add_chunk(event.chunk) + + completion = await stream.get_final_completion() + content = completion.choices[0].message.content + try: + tool_instance = instantiator.build_model(content) + return tool_instance + except ValueError: + continue + finally: + self._log_tool_instantiator( + instantiator=instantiator, + attempt=attempt + 1, + max_retries=max_retries, + ) + + raise ValueError( + f"Failed to parse {tool_class.__name__} after {max_retries} attempts. " + f"Try to simplify tool schema or provide more detailed instructions." + ) + + async def _prepare_tools(self) -> Type[ToolNameSelectorStub]: + """Prepare available tools for the current agent state and progress.""" + if self._context.iteration >= self.config.execution.max_iterations: + raise RuntimeError("Max iterations reached") + return NextStepToolsBuilder.build_NextStepToolSelector(self.toolkit) + + async def _reasoning_phase(self) -> ReasoningTool: + """Call LLM to get ReasoningTool with selected tool name.""" + messages = await self._prepare_context() + + tool_selector_model = await self._prepare_tools() + reasoning = await self._generate_tool(tool_selector_model, messages) + + if not isinstance(reasoning, ReasoningTool): + raise ValueError("Expected ReasoningTool instance") + + # Log reasoning + self._log_reasoning(reasoning) + + # Add to streaming + self.streaming_generator.add_tool_call( + f"{self._context.iteration}-reasoning", + reasoning.tool_name, + reasoning.model_dump_json(exclude={"function_name_choice"}), + ) + + return reasoning + + async def _select_action_phase(self, reasoning: ReasoningTool) -> BaseTool: + """Select tool based on reasoning phase result.""" + messages = await self._prepare_context() + + tool_name = reasoning.function_name_choice # type: ignore + + # Find tool class by name + tool_class: Type[BaseTool] | None = None + + # Try ToolRegistry first + tool_class = ToolRegistry.get(tool_name) + + # If not found, search in toolkit + if tool_class is None: + for tool in self.toolkit: + if tool.tool_name == tool_name: + tool_class = tool + break + + if tool_class is None: + raise ValueError(f"Tool '{tool_name}' not found in toolkit") + + # Generate tool parameters + tool = await self._generate_tool(tool_class, messages) + + if not isinstance(tool, BaseTool): + raise ValueError("Selected tool is not a valid BaseTool instance") + + # Add to conversation + self.conversation.append( + { + "role": "assistant", + "content": reasoning.remaining_steps[0] if reasoning.remaining_steps else "Completing", + "tool_calls": [ + { + "type": "function", + "id": f"{self._context.iteration}-action", + "function": { + "name": tool.tool_name, + "arguments": tool.model_dump_json(), + }, + } + ], + } + ) + self.streaming_generator.add_tool_call( + f"{self._context.iteration}-action", tool.tool_name, tool.model_dump_json() + ) + + return tool + + async def _action_phase(self, tool: BaseTool) -> str: + """Execute selected tool.""" + result = await tool(self._context, self.config) + self.conversation.append( + {"role": "tool", "content": result, "tool_call_id": f"{self._context.iteration}-action"} + ) + self.streaming_generator.add_chunk_from_str(f"{result}\n") + self._log_tool_execution(tool, result) + return result diff --git a/sgr_agent_core/next_step_tool.py b/sgr_agent_core/next_step_tool.py index af855a9d..3790b8f3 100644 --- a/sgr_agent_core/next_step_tool.py +++ b/sgr_agent_core/next_step_tool.py @@ -23,6 +23,17 @@ class NextStepToolStub(ReasoningTool, ABC): function: T = Field(description="Select the appropriate tool for the next step") +class ToolNameSelectorStub(ReasoningTool, ABC): + """Stub class for tool name selection that inherits from ReasoningTool. + + Used by IronAgent to select tool name as part of reasoning phase. + (!) Stub class for correct autocomplete. Use + NextStepToolsBuilder.build_NextStepToolSelector + """ + + function_name_choice: str = Field(description="Select the name of the tool to use") + + class DiscriminantToolMixin(BaseModel): tool_name_discriminator: str = Field(..., description="Tool name discriminator") @@ -60,8 +71,35 @@ def _create_tool_types_union(cls, tools_list: list[Type[T]]) -> Type: @classmethod def build_NextStepTools(cls, tools_list: list[Type[T]]) -> Type[NextStepToolStub]: # noqa + """Build a model with all NextStepTool args.""" return create_model( "NextStepTools", __base__=NextStepToolStub, - function=(cls._create_tool_types_union(tools_list), Field()), + function=( + cls._create_tool_types_union(tools_list), + Field(description="Select and fill parameters of the appropriate tool for the next step"), + ), + ) + + @classmethod + def build_NextStepToolSelector(cls, tools_list: list[Type[T]]) -> Type[ToolNameSelectorStub]: + """Build a model for selecting tool name.""" + # Extract tool names and descriptions + tool_names = [tool.tool_name for tool in tools_list] + + if len(tool_names) == 1: + literal_type = Literal[tool_names[0]] + else: + # Create union of individual Literal types using operator.or_ + # Literal["a"] | Literal["b"] is equivalent to Literal["a", "b"] + literal_types = [Literal[name] for name in tool_names] + literal_type = reduce(operator.or_, literal_types) + + # Create model dynamically, inheriting from ToolNameSelectorStub (which inherits from ReasoningTool) + model_class = create_model( + "NextStepToolSelector", + __base__=ToolNameSelectorStub, + function_name_choice=(literal_type, Field(description="Choose the name for the best tool to use")), ) + model_class.tool_name = "nextsteptoolselector" # type: ignore + return model_class diff --git a/sgr_agent_core/services/__init__.py b/sgr_agent_core/services/__init__.py index 7ff92bf0..081b3680 100644 --- a/sgr_agent_core/services/__init__.py +++ b/sgr_agent_core/services/__init__.py @@ -4,6 +4,7 @@ from sgr_agent_core.services.prompt_loader import PromptLoader from sgr_agent_core.services.registry import AgentRegistry, ToolRegistry from sgr_agent_core.services.tavily_search import TavilySearchService +from sgr_agent_core.services.tool_instantiator import ToolInstantiator __all__ = [ "TavilySearchService", @@ -11,4 +12,5 @@ "ToolRegistry", "AgentRegistry", "PromptLoader", + "ToolInstantiator", ] diff --git a/sgr_agent_core/services/tool_instantiator.py b/sgr_agent_core/services/tool_instantiator.py new file mode 100644 index 00000000..25bfb243 --- /dev/null +++ b/sgr_agent_core/services/tool_instantiator.py @@ -0,0 +1,404 @@ +import json +from json import JSONDecodeError +from typing import TYPE_CHECKING + +from pydantic import ValidationError + +if TYPE_CHECKING: + from sgr_agent_core.base_tool import BaseTool + + +class SchemaSimplifier: + """Converts JSON schemas to structured text format.""" + + @classmethod + def simplify(cls, schema: dict, indent: int = 0) -> str: + """Convert JSON schema to structured text format (variant 3 - compact inline). + + Format: - field_name (required/optional, тип, ограничения): описание + + Args: + schema: JSON schema dictionary from model_json_schema() + indent: Current indentation level for nested structures + + Returns: + Formatted text representation of schema + """ + properties = schema.get("properties", {}) + required = schema.get("required", []) + defs = schema.get("$defs", {}) + + if not properties: + return "" + + required_fields = [(name, props) for name, props in properties.items() if name in required] + optional_fields = [(name, props) for name, props in properties.items() if name not in required] + sorted_fields = required_fields + optional_fields + + result_lines = [] + indent_str = " " * indent + + for field_name, field_schema in sorted_fields: + is_required = field_name in required + req_text = "required" if is_required else "optional" + + field_type = cls._extract_type(field_schema, defs) + constraints = cls._extract_constraints(field_schema) + description = field_schema.get("description", "") + + # Format line: - field_name (required/optional, type, restrictions): description + type_constraints = ", ".join([field_type, *constraints]) + line = f"{indent_str}- {field_name} ({req_text}, {type_constraints}): {description}" + + result_lines.append(line) + + # Handle anyOf with $ref to $defs - expand nested schemas + if "anyOf" in field_schema and defs: + for variant in field_schema["anyOf"]: + if "$ref" in variant: + ref_path = variant["$ref"] + if ref_path.startswith("#/$defs/"): + def_name = ref_path.replace("#/$defs/", "") + if def_name in defs: + def_schema = defs[def_name] + # Get schema name from title or def_name + schema_name = def_schema.get("title", def_name) + # Add schema name as header + nested_indent_str = " " * (indent + 1) + result_lines.append(f"{nested_indent_str}Variant: {schema_name}") + # Recursively simplify the schema from $defs + nested = cls.simplify(def_schema, indent + 2) + if nested: + result_lines.append(nested) + + # Handle nested objects recursively + if field_schema.get("type") == "object" and "properties" in field_schema: + nested = cls.simplify(field_schema, indent + 1) + if nested: + result_lines.append(nested) + + return "\n".join(result_lines) + + @classmethod + def _extract_type(cls, field_schema: dict, defs: dict | None = None) -> str: + """Extract type string from field schema. + + Args: + field_schema: Field schema dictionary + defs: Definitions dictionary for resolving $ref (optional) + + Returns: + Type string representation + """ + if defs is None: + defs = {} + # Handle const + if "const" in field_schema: + base_type = field_schema.get("type", "string") + const_value = field_schema["const"] + return f"{base_type} (const: {json.dumps(const_value)})" + + # Handle enum (Literal) + if "enum" in field_schema: + enum_values = field_schema["enum"] + enum_str = ", ".join(json.dumps(v) for v in enum_values) + return f"Literal[{enum_str}]" + + # Handle anyOf (Union) + if "anyOf" in field_schema: + types = [] + const_values = [] + for variant in field_schema["anyOf"]: + if "$ref" in variant: + # Resolve $ref + ref_path = variant["$ref"] + if ref_path.startswith("#/$defs/") and defs: + def_name = ref_path.replace("#/$defs/", "") + if def_name in defs: + # Use title or def_name from schema + def_schema = defs[def_name] + type_name = def_schema.get("title", def_name) + types.append(type_name) + else: + # Use def_name as-is if not in defs + types.append(def_name) + else: + types.append(ref_path) + elif "const" in variant: + # Collect const values + const_values.append(variant["const"]) + else: + # Extract type from variant + variant_type = cls._extract_type(variant, defs) + types.append(variant_type) + + # If all variants are const values, format as const(value1 OR value2 OR value3) + if const_values and not types: + const_str = " OR ".join(json.dumps(v) for v in const_values) + return f"const({const_str})" + + # Wrap Union types in Literal format + if types: + types_str = ", ".join(types) + return f"Literal[{types_str}]" + + return "unknown" + + # Handle array + if field_schema.get("type") == "array": + items = field_schema.get("items", {}) + if isinstance(items, dict): + element_type = cls._extract_type(items, defs) + return f"list[{element_type}]" + return "list" + + # Handle object + if field_schema.get("type") == "object": + if "properties" in field_schema: + return "object" + return "object" + + # Simple types + field_type = field_schema.get("type", "unknown") + return field_type + + @classmethod + def _extract_constraints(cls, field_schema: dict) -> list[str]: + """Extract constraints string from field schema. + + Args: + field_schema: Field schema dictionary + + Returns: + Constraints string (empty if no constraints) + """ + constraints = [] + + # Default value + if "default" in field_schema: + default_val = field_schema["default"] + constraints.append(f"default: {json.dumps(default_val)}") + + # Numeric range + if "minimum" in field_schema and "maximum" in field_schema: + min_val = field_schema["minimum"] + max_val = field_schema["maximum"] + constraints.append(f"range: {min_val}-{max_val}") + elif "minimum" in field_schema: + constraints.append(f"min: {field_schema['minimum']}") + elif "maximum" in field_schema: + constraints.append(f"max: {field_schema['maximum']}") + + # String length + if "minLength" in field_schema and "maxLength" in field_schema: + min_len = field_schema["minLength"] + max_len = field_schema["maxLength"] + constraints.append(f"length: {min_len}-{max_len}") + elif "minLength" in field_schema: + constraints.append(f"min length: {field_schema['minLength']}") + elif "maxLength" in field_schema: + constraints.append(f"max length: {field_schema['maxLength']}") + + # Array items + if "minItems" in field_schema and "maxItems" in field_schema: + min_items = field_schema["minItems"] + max_items = field_schema["maxItems"] + constraints.append(f"{min_items}-{max_items} items") + elif "minItems" in field_schema: + constraints.append(f"min {field_schema['minItems']} items") + elif "maxItems" in field_schema: + constraints.append(f"max {field_schema['maxItems']} items") + + return constraints + + +# YESS im trying in SO inside a framework +class ToolInstantiator: + """Structured Output like Service for formatting, parsing raw LLM context + and building tool instances.""" + + _FORMAT_PROMPT_TEMPLATE = ( + "CRITICAL: Generate a JSON OBJECT with actual data values, NOT the schema definition.\n\n" + "================================================================================\n" + "HOW TO READ THE SCHEMA (TEXT FORMAT):\n" + "================================================================================\n\n" + " FORMAT: - field_name (required/optional, type, constraints): description\n\n" + " EXAMPLES:\n" + " - name (required, string, max length: 100): User's full name\n" + " - age (optional, integer, range: 18-120, default: 25): User's age\n" + " - tags (required, list[string], 1-5 items): List of tags\n" + ' - status (required, const("active" OR "inactive")): Account status\n' + " - tool (required, Literal[ToolA, ToolB, ToolC]): Select one tool\n" + " Variant: ToolA\n" + " - param1 (required, string): Parameter 1\n" + " - param2 (optional, integer): Parameter 2\n" + " Variant: ToolB\n" + " - param3 (required, boolean): Parameter 3\n\n" + " KEY CONCEPTS:\n" + " - required/optional: Whether the field must be provided\n" + " - type: Data type (string, integer, boolean, list[type], Literal[...], etc.)\n" + " - constraints: Limits like 'range: 1-10', 'max length: 300', '2-3 items', 'default: value'\n" + " - Literal[...]: Union type - choose ONE of the listed options\n" + " - Variant: sections: For nested unions, select appropriate variant and fill its fields\n" + " - const(value): Field must have exactly this value (no other options)\n\n" + "HOW TO GENERATE VALUES:\n" + "- Use the conversation context above to understand what the user wants\n" + "- Use the tool description to understand what this tool does\n" + "- Fill each field with appropriate values based on its schema description, the context and tool purpose\n" + "- Match the types and constraints shown in the schema\n" + "- For Union types, choose the appropriate variant based on context\n\n" + "REQUIREMENTS:\n" + "- Output ONLY raw JSON object (no markdown, no code blocks, no explanations)\n" + "- Fill fields with actual values matching the schema types\n" + "- All strings must be quoted, arrays in [], objects in {}\n" + "- Respect all constraints (min/max values, lengths, item counts, etc.)\n\n" + "\n" + '{"name": "John Doe", "age": 30, "tags": ["tag1", "tag2"], ' + '"status": "active", "tool": "ToolA", "param1": "value1", "param2": 42}\n' + "\n\n" + "\n" + '{"name": {"type": "string", "description": "User\'s full name"}, ' + '"age": {"type": "integer", "range": "18-120"}}\n' + "\n\n" + "\n" + '```json\n{"name": "John Doe", "age": 30}\n```\n' + "\n\n" + "The schema below shows the STRUCTURE in text format - provide VALUES in JSON that match it.\n\n" + ) + + def __init__(self, tool_class: type["BaseTool"]): + """Initialize tool instantiator. + + Args: + tool_class: Tool class to instantiate + """ + self.tool_class = tool_class + self.errors: list[str] = [] + self.instance: "BaseTool | None" = None + self.input_content: str = "" + + def _clearing_context(self, content: str) -> str: + """Extract JSON object from content by finding first { and last }. + + Args: + content: Raw content that may contain JSON mixed with other text + + Returns: + Extracted JSON string (from first { to last }) + """ + first_brace = content.find("{") + if first_brace == -1: + return content + + last_brace = content.rfind("}") + if last_brace == -1 or last_brace < first_brace: + return content + + return content[first_brace : last_brace + 1] + + def _format_json_error(self, error: JSONDecodeError, content: str) -> str: + """Format JSON decode error with context around the error position. + + Args: + error: JSONDecodeError exception + content: Content that failed to parse + + Returns: + Formatted error message with context + """ + error_pos = getattr(error, "pos", None) + if error_pos is None: + return f"Failed to parse JSON: {error}" + + # Calculate context window (±20 characters) + context_size = 20 + start = max(0, error_pos - context_size) + end = min(len(content), error_pos + context_size) + + context_before = content[start:error_pos] + context_after = content[error_pos:end] + + # Show position and context + error_msg = ( + f"JSON parse error at position {error_pos}: {error.msg}\n" + f"Context: ...{context_before}[Error here]{context_after}..." + ) + + return error_msg + + def generate_format_prompt( + self, + include_errors: bool = True, + ) -> str: + """Generate a prompt describing the expected format for LLM. + + Args: + mode: Which parameters to include: + - "all": Show all parameters (default) + - "required": Show only required parameters + - "unfilled": Show only parameters where value is PydanticUndefined + include_errors: If True, include error messages for parameters with validation errors + + Returns: + Format description string + """ + simplified_schema = SchemaSimplifier.simplify(self.tool_class.model_json_schema()) + + prompt = ( + f"\n" + f"Tool name: {self.tool_class.tool_name}\n" + f"Tool description: {self.tool_class.description}\n" + f"\n\n" + f"{self._FORMAT_PROMPT_TEMPLATE}" + f"\n" + f"{simplified_schema}\n" + f"\n\n" + ) + + # Show errors at the end if requested + if include_errors and self.errors: + prompt += "PREVIOUS FILLING ITERATION ERRORS. AVOID OR FIX IT:\n" + prompt += "fault content: \n" + self.input_content + "\n" + for error in self.errors: + prompt += f" - {error}\n" + + return prompt + + def build_model(self, content: str) -> "BaseTool": + """Build tool model instance from parsed parameters. + + Args: + content: Raw content from LLM to parse + + Returns: + Tool instance + + Raises: + ValueError: If required parameters are missing or model creation fails + """ + self.content = "" + self.errors.clear() + if not content: + self.errors.append("No content provided") + raise ValueError("No content provided") + self.input_content = content + + try: + cleaned_content = self._clearing_context(content) + self.content = cleaned_content # Save cleaned content for debugging + self.instance = self.tool_class(**json.loads(cleaned_content)) + return self.instance + except ValidationError as e: + for err in e.errors(): + self.errors.append( + f"pydantic validation error - type: {err['type']} - " f"field: {err['loc']} - {err['msg']}" + ) + raise ValueError("Failed to build model") from e + except JSONDecodeError as e: + error_content = cleaned_content if "cleaned_content" in locals() else self.content or content + error_msg = self._format_json_error(e, error_content) + self.errors.append(error_msg) + raise ValueError("Failed to build model") from e + except ValueError as e: + self.errors.append(f"Failed to parse JSON: {e}") + raise ValueError("Failed to build model") from e diff --git a/sgr_agent_core/tools/__init__.py b/sgr_agent_core/tools/__init__.py index ef859ffa..3f93988f 100644 --- a/sgr_agent_core/tools/__init__.py +++ b/sgr_agent_core/tools/__init__.py @@ -1,5 +1,9 @@ from sgr_agent_core.base_tool import BaseTool, MCPBaseTool -from sgr_agent_core.next_step_tool import NextStepToolsBuilder, NextStepToolStub +from sgr_agent_core.next_step_tool import ( + NextStepToolsBuilder, + NextStepToolStub, + ToolNameSelectorStub, +) from sgr_agent_core.tools.adapt_plan_tool import AdaptPlanTool from sgr_agent_core.tools.clarification_tool import ClarificationTool from sgr_agent_core.tools.create_report_tool import CreateReportTool @@ -14,6 +18,7 @@ "BaseTool", "MCPBaseTool", "NextStepToolStub", + "ToolNameSelectorStub", "NextStepToolsBuilder", # Individual tools "ClarificationTool", diff --git a/sgr_agent_core/tools/reasoning_tool.py b/sgr_agent_core/tools/reasoning_tool.py index 5296d672..80f0e837 100644 --- a/sgr_agent_core/tools/reasoning_tool.py +++ b/sgr_agent_core/tools/reasoning_tool.py @@ -36,8 +36,7 @@ class ReasoningTool(BaseTool): # Next step planning remaining_steps: list[str] = Field( - description="1-3 remaining steps (brief, action-oriented)", - min_length=1, + description="0-3 remaining steps (brief, action-oriented)", max_length=3, ) task_completed: bool = Field(description="Is the research task finished?") diff --git a/tests/test_iron_agent.py b/tests/test_iron_agent.py new file mode 100644 index 00000000..1adbfee2 --- /dev/null +++ b/tests/test_iron_agent.py @@ -0,0 +1,331 @@ +"""Tests for IronAgent.""" + +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from openai.types.chat import ChatCompletionChunk, ChatCompletionMessage + +from sgr_agent_core.agents.iron_agent import IronAgent +from sgr_agent_core.tools import ReasoningTool, WebSearchTool +from tests.conftest import create_test_agent + + +class TestIronAgent: + """Tests for IronAgent.""" + + def test_initialization(self): + """Test IronAgent initialization.""" + agent = create_test_agent( + IronAgent, + task_messages=[{"role": "user", "content": "Test task"}], + toolkit=[ReasoningTool, WebSearchTool], + ) + + assert agent.name == "iron_agent" + assert len(agent.toolkit) == 2 + assert ReasoningTool in agent.toolkit + assert WebSearchTool in agent.toolkit + + @pytest.mark.asyncio + async def test_generate_tool_success(self): + """Test _generate_tool with successful parsing.""" + agent = create_test_agent( + IronAgent, + toolkit=[ReasoningTool], + ) + + # Mock stream response + mock_chunk = Mock(spec=ChatCompletionChunk) + mock_chunk.type = "chunk" + mock_chunk.chunk = Mock() # Add chunk attribute + + mock_completion = Mock(spec=ChatCompletionMessage) + mock_completion.content = ( + '{"reasoning_steps": ["step1", "step2"], "current_situation": "test", ' + '"plan_status": "ok", "enough_data": false, "remaining_steps": ["next"], ' + '"task_completed": false}' + ) + + # Create async iterator for stream + async def async_iter(): + yield mock_chunk + + mock_stream = AsyncMock() + mock_stream.__aenter__ = AsyncMock(return_value=mock_stream) + mock_stream.__aexit__ = AsyncMock(return_value=None) + mock_stream.__aiter__ = lambda self: async_iter() + mock_final_completion = Mock() + mock_final_completion.choices = [Mock(message=mock_completion)] + mock_stream.get_final_completion = AsyncMock(return_value=mock_final_completion) + + agent.openai_client.chat.completions.stream = Mock(return_value=mock_stream) + + # Call method + result = await agent._generate_tool(ReasoningTool, [{"role": "user", "content": "test"}]) + + assert isinstance(result, ReasoningTool) + assert result.reasoning_steps == ["step1", "step2"] + + # Check that logging was called for successful attempt + log_entries = [entry for entry in agent.log if entry.get("step_type") == "tool_generation_attempt"] + assert len(log_entries) == 1 + assert log_entries[0]["success"] is True + assert log_entries[0]["tool_class"] == "ReasoningTool" + assert log_entries[0]["llm_content"] == mock_completion.content + + @pytest.mark.asyncio + async def test_generate_tool_retry_on_error(self): + """Test _generate_tool with retry on parsing error.""" + agent = create_test_agent( + IronAgent, + toolkit=[ReasoningTool], + ) + + # Mock stream response with invalid JSON first, then valid + mock_chunk = Mock(spec=ChatCompletionChunk) + mock_chunk.type = "chunk" + mock_chunk.chunk = Mock() # Add chunk attribute + + # First attempt - invalid JSON + mock_completion_invalid = Mock(spec=ChatCompletionMessage) + mock_completion_invalid.content = "invalid json" + + # Second attempt - valid JSON + mock_completion_valid = Mock(spec=ChatCompletionMessage) + mock_completion_valid.content = ( + '{"reasoning_steps": ["step1", "step2"], "current_situation": "test", ' + '"plan_status": "ok", "enough_data": false, "remaining_steps": ["next"], ' + '"task_completed": false}' + ) + + # Create async iterator for stream + async def async_iter(): + yield mock_chunk + + # First call returns invalid, second returns valid + call_count = {"count": 0} + + async def get_completion(): + call_count["count"] += 1 + if call_count["count"] == 1: + return Mock(choices=[Mock(message=mock_completion_invalid)]) + return Mock(choices=[Mock(message=mock_completion_valid)]) + + mock_stream = AsyncMock() + mock_stream.__aenter__ = AsyncMock(return_value=mock_stream) + mock_stream.__aexit__ = AsyncMock(return_value=None) + mock_stream.__aiter__ = lambda self: async_iter() + mock_stream.get_final_completion = AsyncMock(side_effect=get_completion) + + agent.openai_client.chat.completions.stream = Mock(return_value=mock_stream) + + # Call method - should succeed on second attempt + result = await agent._generate_tool(ReasoningTool, [{"role": "user", "content": "test"}], max_retries=5) + + assert isinstance(result, ReasoningTool) + assert call_count["count"] == 2 # Two attempts made + + # Check that logging was called for both attempts + # Note: due to finally block, each attempt logs twice (in except and finally) + log_entries = [entry for entry in agent.log if entry.get("step_type") == "tool_generation_attempt"] + # First attempt: failed (logged in except and finally) + # Second attempt: succeeded (logged in finally) + assert len(log_entries) >= 2 + # Find first failed attempt + failed_entries = [e for e in log_entries if e["success"] is False and e["attempt"] == 1] + assert len(failed_entries) >= 1 + assert failed_entries[0]["llm_content"] == "invalid json" + # Find successful attempt + success_entries = [e for e in log_entries if e["success"] is True and e["attempt"] == 2] + assert len(success_entries) >= 1 + assert success_entries[0]["llm_content"] == mock_completion_valid.content + + @pytest.mark.asyncio + async def test_prepare_tools(self): + """Test _prepare_tools method.""" + agent = create_test_agent( + IronAgent, + toolkit=[ReasoningTool, WebSearchTool], + ) + + # Call _prepare_tools + tool_selector_type = await agent._prepare_tools() + + # Check that it returns a type that is a subclass of ToolNameSelectorStub + assert isinstance(tool_selector_type, type) + # Check that it's a dynamically created model (has __name__) + assert hasattr(tool_selector_type, "__name__") + # Check that it can be instantiated with required ReasoningTool fields + instance = tool_selector_type( + reasoning_steps=["step1", "step2"], + current_situation="test", + plan_status="ok", + enough_data=False, + remaining_steps=["next"], + task_completed=False, + function_name_choice="reasoningtool", + ) + assert hasattr(instance, "function_name_choice") + + @pytest.mark.asyncio + async def test_reasoning_phase(self): + """Test _reasoning_phase.""" + agent = create_test_agent( + IronAgent, + toolkit=[ReasoningTool, WebSearchTool], + ) + + # Mock _generate_tool + # The result should be a ReasoningTool with tool_name field + from pydantic import create_model + + # Create a mock reasoning tool with function_name_choice field + mock_reasoning_class = create_model( + "MockReasoning", + __base__=ReasoningTool, + function_name_choice=(str, "websearchtool"), + ) + mock_reasoning = mock_reasoning_class( + reasoning_steps=["step1", "step2"], + current_situation="test", + plan_status="ok", + enough_data=False, + remaining_steps=["next"], + task_completed=False, + function_name_choice="websearchtool", + ) + + agent._generate_tool = AsyncMock(return_value=mock_reasoning) + + # Call reasoning phase + result = await agent._reasoning_phase() + + assert isinstance(result, ReasoningTool) + assert result.reasoning_steps == ["step1", "step2"] + assert hasattr(result, "function_name_choice") + agent._generate_tool.assert_called_once() + + @pytest.mark.asyncio + async def test_select_action_phase(self): + """Test _select_action_phase.""" + agent = create_test_agent( + IronAgent, + toolkit=[ReasoningTool, WebSearchTool], + ) + + # Mock reasoning with function_name_choice field (already selected in reasoning phase) + from pydantic import create_model + + mock_reasoning_class = create_model( + "MockReasoning", + __base__=ReasoningTool, + function_name_choice=(str, "websearchtool"), + ) + reasoning = mock_reasoning_class( + reasoning_steps=["step1", "step2"], + current_situation="test", + plan_status="ok", + enough_data=False, + remaining_steps=["search"], + task_completed=False, + function_name_choice="websearchtool", + ) + + # Mock tool instance generation + mock_tool = WebSearchTool(reasoning="test reasoning", query="test query") + agent._generate_tool = AsyncMock(return_value=mock_tool) + + # Call select action phase + result = await agent._select_action_phase(reasoning) + + assert isinstance(result, WebSearchTool) + assert result.query == "test query" + # Only one call - tool parameter generation (tool_name already selected in reasoning) + agent._generate_tool.assert_called_once() + + @pytest.mark.asyncio + async def test_action_phase(self): + """Test _action_phase.""" + + agent = create_test_agent( + IronAgent, + toolkit=[WebSearchTool], + ) + + # Mock tool execution - need to patch __call__ method + tool = WebSearchTool(reasoning="test reasoning", query="test query") + + # Patch WebSearchTool.__call__ to avoid TavilySearchService initialization + with patch( + "sgr_agent_core.tools.web_search_tool.WebSearchTool.__call__", + new_callable=AsyncMock, + return_value="Search results", + ) as mock_call: + # Call action phase + result = await agent._action_phase(tool) + + assert result == "Search results" + mock_call.assert_called_once_with(agent._context, agent.config) + + @pytest.mark.asyncio + async def test_log_tool_generation_attempt(self): + """Test _log_tool_instantiator logging.""" + from sgr_agent_core.services.tool_instantiator import ToolInstantiator + + agent = create_test_agent( + IronAgent, + toolkit=[ReasoningTool], + ) + + # Test successful attempt + instantiator = ToolInstantiator(ReasoningTool) + instantiator.input_content = ( + '{"reasoning_steps": ["step1", "step2"], "current_situation": "test", ' + '"plan_status": "ok", "enough_data": false, "remaining_steps": ["next"], ' + '"task_completed": false}' + ) + instantiator.instance = ReasoningTool( + reasoning_steps=["step1", "step2"], + current_situation="test", + plan_status="ok", + enough_data=False, + remaining_steps=["next"], + task_completed=False, + ) + + agent._log_tool_instantiator( + instantiator=instantiator, + attempt=1, + max_retries=5, + ) + + # Check log entry + assert len(agent.log) == 1 + log_entry = agent.log[0] + assert log_entry["step_type"] == "tool_generation_attempt" + assert log_entry["tool_class"] == "ReasoningTool" + assert log_entry["attempt"] == 1 + assert log_entry["max_retries"] == 5 + assert log_entry["success"] is True + assert instantiator.instance is not None + assert log_entry["llm_content"] == instantiator.input_content + assert log_entry["errors"] == [] + + # Test failed attempt + instantiator2 = ToolInstantiator(ReasoningTool) + instantiator2.input_content = "invalid json" + instantiator2.errors = ["JSON decode error", "Validation error"] + + agent._log_tool_instantiator( + instantiator=instantiator2, + attempt=2, + max_retries=5, + ) + + # Check second log entry + assert len(agent.log) == 2 + log_entry = agent.log[1] + assert log_entry["success"] is False + assert log_entry["llm_content"] == "invalid json" + assert log_entry["errors"] == ["JSON decode error", "Validation error"] + assert instantiator2.instance is None diff --git a/tests/test_next_step_tool.py b/tests/test_next_step_tool.py new file mode 100644 index 00000000..bd63429f --- /dev/null +++ b/tests/test_next_step_tool.py @@ -0,0 +1,92 @@ +"""Tests for NextStepToolsBuilder.""" + +import pytest +from pydantic import ValidationError + +from sgr_agent_core.next_step_tool import NextStepToolsBuilder +from sgr_agent_core.tools import ReasoningTool, WebSearchTool + + +class TestNextStepToolsBuilder: + """Tests for NextStepToolsBuilder.""" + + def test_build_ToolNameSelector_single_tool(self): + """Test building ToolNameSelector with single tool.""" + tools_list = [ReasoningTool] + selector_model = NextStepToolsBuilder.build_NextStepToolSelector(tools_list) + + # Create instance with valid tool name and all required ReasoningTool fields + instance = selector_model( + reasoning_steps=["step1", "step2"], + current_situation="test", + plan_status="ok", + enough_data=False, + remaining_steps=["next"], + task_completed=False, + function_name_choice="reasoningtool", + ) + assert instance.function_name_choice == "reasoningtool" + + # Try invalid tool name - should raise ValidationError + with pytest.raises(ValidationError): + selector_model( + reasoning_steps=["step1", "step2"], + current_situation="test", + plan_status="ok", + enough_data=False, + remaining_steps=["next"], + task_completed=False, + tool_name="invalid_tool", + ) + + def test_build_ToolNameSelector_multiple_tools(self): + """Test building ToolNameSelector with multiple tools.""" + tools_list = [ReasoningTool, WebSearchTool] + selector_model = NextStepToolsBuilder.build_NextStepToolSelector(tools_list) + + # Create instance with valid tool names and all required ReasoningTool fields + instance1 = selector_model( + reasoning_steps=["step1", "step2"], + current_situation="test", + plan_status="ok", + enough_data=False, + remaining_steps=["next"], + task_completed=False, + function_name_choice="reasoningtool", + ) + assert instance1.function_name_choice == "reasoningtool" + + instance2 = selector_model( + reasoning_steps=["step1", "step2"], + current_situation="test", + plan_status="ok", + enough_data=False, + remaining_steps=["next"], + task_completed=False, + function_name_choice="websearchtool", + ) + assert instance2.function_name_choice == "websearchtool" + + # Try invalid tool name - should raise ValidationError + with pytest.raises(ValidationError): + selector_model( + reasoning_steps=["step1", "step2"], + current_situation="test", + plan_status="ok", + enough_data=False, + remaining_steps=["next"], + task_completed=False, + tool_name="invalid_tool", + ) + + def test_build_ToolNameSelector_includes_descriptions(self): + """Test that ToolNameSelector has proper field description.""" + tools_list = [ReasoningTool, WebSearchTool] + selector_model = NextStepToolsBuilder.build_NextStepToolSelector(tools_list) + + # Check that field description exists and is meaningful + field_info = selector_model.model_fields["function_name_choice"] + description = field_info.description + + assert description is not None + assert len(description) > 0 diff --git a/tests/test_tool_instantiator.py b/tests/test_tool_instantiator.py new file mode 100644 index 00000000..fbc3eca3 --- /dev/null +++ b/tests/test_tool_instantiator.py @@ -0,0 +1,761 @@ +"""Tests for ToolInstantiator service.""" + +import json +from json import JSONDecodeError + +import pytest + +from sgr_agent_core.services.tool_instantiator import SchemaSimplifier, ToolInstantiator +from sgr_agent_core.tools import ReasoningTool, WebSearchTool + + +class TestToolInstantiator: + """Test suite for ToolInstantiator.""" + + def test_initialization(self): + """Test ToolInstantiator initialization.""" + instantiator = ToolInstantiator(ReasoningTool) + + assert instantiator.tool_class == ReasoningTool + assert instantiator.errors == [] + assert instantiator.instance is None + assert instantiator.input_content == "" + + def test_generate_format_prompt_with_errors(self): + """Test generate_format_prompt with errors included.""" + instantiator = ToolInstantiator(ReasoningTool) + instantiator.errors = ["Error 1", "Error 2"] + instantiator.input_content = "invalid json" + + prompt = instantiator.generate_format_prompt(include_errors=True) + + assert "PREVIOUS FILLING ITERATION ERRORS" in prompt + assert "Error 1" in prompt + assert "Error 2" in prompt + assert "invalid json" in prompt + + def test_generate_format_prompt_without_errors_when_errors_exist(self): + """Test generate_format_prompt with include_errors=False when errors + exist.""" + instantiator = ToolInstantiator(ReasoningTool) + instantiator.errors = ["Error 1"] + instantiator.input_content = "invalid json" + + prompt = instantiator.generate_format_prompt(include_errors=False) + + assert "PREVIOUS FILLING ITERATION ERRORS" not in prompt + assert "Error 1" not in prompt + + def test_build_model_success(self): + """Test build_model with valid JSON content.""" + instantiator = ToolInstantiator(ReasoningTool) + content = json.dumps( + { + "reasoning_steps": ["step1", "step2"], + "current_situation": "test", + "plan_status": "ok", + "enough_data": False, + "remaining_steps": ["next"], + "task_completed": False, + } + ) + + result = instantiator.build_model(content) + + assert isinstance(result, ReasoningTool) + assert result == instantiator.instance + assert instantiator.instance is not None + assert instantiator.input_content == content + assert len(instantiator.errors) == 0 + assert result.reasoning_steps == ["step1", "step2"] + assert result.current_situation == "test" + + def test_build_model_with_whitespace(self): + """Test build_model handles whitespace correctly.""" + instantiator = ToolInstantiator(ReasoningTool) + content = ( + ' \n{"reasoning_steps": ["step1", "step2"], "current_situation": "test", ' + '"plan_status": "ok", "enough_data": false, "remaining_steps": ["next"], ' + '"task_completed": false}\n ' + ) + + result = instantiator.build_model(content) + + assert isinstance(result, ReasoningTool) + assert instantiator.instance is not None + + def test_build_model_empty_content(self): + """Test build_model with empty content.""" + instantiator = ToolInstantiator(ReasoningTool) + + with pytest.raises(ValueError, match="No content provided"): + instantiator.build_model("") + + assert len(instantiator.errors) == 1 + assert "No content provided" in instantiator.errors + assert instantiator.instance is None + + def test_build_model_invalid_json(self): + """Test build_model with invalid JSON.""" + instantiator = ToolInstantiator(ReasoningTool) + content = "invalid json content" + + with pytest.raises(ValueError, match="Failed to build model"): + instantiator.build_model(content) + + assert len(instantiator.errors) > 0 + # Check for JSON parse error message (new format with context) + assert any("JSON parse error" in err or "Failed to parse JSON" in err for err in instantiator.errors) + assert instantiator.input_content == content + assert instantiator.instance is None + + def test_build_model_validation_error(self): + """Test build_model with Pydantic validation error.""" + instantiator = ToolInstantiator(ReasoningTool) + # Missing required fields + content = json.dumps({"reasoning_steps": ["step1"]}) + + with pytest.raises(ValueError, match="Failed to build model"): + instantiator.build_model(content) + + assert len(instantiator.errors) > 0 + assert any("pydantic validation error" in err for err in instantiator.errors) + assert instantiator.input_content == content + assert instantiator.instance is None + + def test_build_model_clears_errors_on_new_attempt(self): + """Test that build_model clears errors before new attempt.""" + instantiator = ToolInstantiator(ReasoningTool) + instantiator.errors = ["Previous error"] + + # First attempt fails + try: + instantiator.build_model("invalid") + except ValueError: + pass + + assert len(instantiator.errors) > 0 + + # Second attempt succeeds - errors should be cleared + valid_content = json.dumps( + { + "reasoning_steps": ["step1", "step2"], + "current_situation": "test", + "plan_status": "ok", + "enough_data": False, + "remaining_steps": ["next"], + "task_completed": False, + } + ) + result = instantiator.build_model(valid_content) + + assert isinstance(result, ReasoningTool) + assert instantiator.instance is not None + + def test_build_model_with_web_search_tool(self): + """Test build_model with different tool class.""" + instantiator = ToolInstantiator(WebSearchTool) + content = json.dumps({"reasoning": "test reasoning", "query": "test query"}) + + result = instantiator.build_model(content) + + assert isinstance(result, WebSearchTool) + assert result.reasoning == "test reasoning" + assert result.query == "test query" + assert instantiator.instance == result + + def test_generate_format_prompt_includes_schema(self): + """Test that generate_format_prompt includes simplified schema.""" + instantiator = ToolInstantiator(ReasoningTool) + prompt = instantiator.generate_format_prompt() + + # Check that simplified schema format is present + assert "" in prompt + assert "" in prompt + # Check for simplified format markers + assert "reasoning_steps (required" in prompt + assert "list[string]" in prompt + assert "enough_data (optional" in prompt + + def test_errors_accumulation(self): + """Test that errors accumulate across multiple failed attempts.""" + instantiator = ToolInstantiator(ReasoningTool) + + # First attempt - invalid JSON + try: + instantiator.build_model("invalid json 1") + except ValueError: + pass + + # Second attempt - invalid JSON + try: + instantiator.build_model("invalid json 2") + except ValueError: + pass + + # Note: build_model clears errors at start, so each attempt starts fresh + # But we can check that errors are added during each attempt + assert len(instantiator.errors) > 0 + + def test_clearing_context_extracts_json(self): + """Test _clearing_context extracts JSON from mixed content.""" + instantiator = ToolInstantiator(ReasoningTool) + + # Content with text before and after JSON + content = 'Some text before {"reasoning_steps": ["step1"], "current_situation": "test"} and after' + result = instantiator._clearing_context(content) + + assert result.startswith("{") + assert result.endswith("}") + assert "reasoning_steps" in result + assert "Some text before" not in result + assert "and after" not in result + + def test_clearing_context_no_braces(self): + """Test _clearing_context returns original content if no braces + found.""" + instantiator = ToolInstantiator(ReasoningTool) + content = "no json here" + + result = instantiator._clearing_context(content) + + assert result == content + + def test_clearing_context_only_opening_brace(self): + """Test _clearing_context handles case with only opening brace.""" + instantiator = ToolInstantiator(ReasoningTool) + content = "text { incomplete json" + + result = instantiator._clearing_context(content) + + assert result == content + + def test_build_model_with_text_around_json(self): + """Test build_model extracts JSON from text with surrounding + content.""" + instantiator = ToolInstantiator(ReasoningTool) + json_data = { + "reasoning_steps": ["step1", "step2"], + "current_situation": "test", + "plan_status": "ok", + "enough_data": False, + "remaining_steps": ["next"], + "task_completed": False, + } + content = f"Here is the JSON: {json.dumps(json_data)} and some text after" + + result = instantiator.build_model(content) + + assert isinstance(result, ReasoningTool) + assert result.reasoning_steps == ["step1", "step2"] + + def test_format_json_error_with_context(self): + """Test _format_json_error includes context around error position.""" + instantiator = ToolInstantiator(ReasoningTool) + # Create invalid JSON with extra data + content = '{"field": "value"} extra data here' + try: + json.loads(content) + except JSONDecodeError as e: + error_msg = instantiator._format_json_error(e, content) + + assert "JSON parse error" in error_msg + assert "position" in error_msg + assert "Context:" in error_msg + assert "extra data" in error_msg or "value" in error_msg + + def test_format_json_error_without_position(self): + """Test _format_json_error handles error without position attribute.""" + instantiator = ToolInstantiator(ReasoningTool) + + # Create error-like object without pos attribute + class ErrorWithoutPos: + def __init__(self): + self.msg = "test error" + + error = ErrorWithoutPos() + error_msg = instantiator._format_json_error(error, "test content") + + # Should fall back to generic error message when pos is missing + assert "Failed to parse JSON" in error_msg + + def test_format_json_error_at_start(self): + """Test _format_json_error handles error at start of content.""" + instantiator = ToolInstantiator(ReasoningTool) + content = 'invalid {"field": "value"}' + try: + json.loads(content) + except JSONDecodeError as e: + error_msg = instantiator._format_json_error(e, content) + assert "JSON parse error" in error_msg + assert "position" in error_msg + + def test_format_json_error_at_end(self): + """Test _format_json_error handles error at end of content.""" + instantiator = ToolInstantiator(ReasoningTool) + content = '{"field": "value"} invalid' + try: + json.loads(content) + except JSONDecodeError as e: + error_msg = instantiator._format_json_error(e, content) + assert "JSON parse error" in error_msg + assert "position" in error_msg + + def test_generate_format_prompt_includes_tool_info(self): + """Test that generate_format_prompt includes tool name.""" + instantiator = ToolInstantiator(ReasoningTool) + prompt = instantiator.generate_format_prompt() + + assert "" in prompt + assert "" in prompt + assert ReasoningTool.tool_name in prompt + + def test_generate_format_prompt_includes_format_template(self): + """Test that generate_format_prompt includes important format + sections.""" + instantiator = ToolInstantiator(ReasoningTool) + prompt = instantiator.generate_format_prompt() + + # Check that key sections are present (not checking exact strings to avoid hardcoding) + assert len(prompt) > 100 # Should be substantial + assert "" in prompt # Schema section is critical + assert "" in prompt # Tool info is critical + # Check that it contains instructions (not just checking exact strings) + assert any(keyword in prompt.lower() for keyword in ["json", "schema", "format", "required"]) + + def test_clearing_context_multiple_braces(self): + """Test _clearing_context extracts JSON when multiple braces exist.""" + instantiator = ToolInstantiator(ReasoningTool) + content = 'text { {"reasoning_steps": ["step1"], "current_situation": "test"} } more text' + result = instantiator._clearing_context(content) + + assert result.startswith("{") + assert result.endswith("}") + assert "reasoning_steps" in result + + def test_clearing_context_nested_objects(self): + """Test _clearing_context handles nested JSON objects.""" + instantiator = ToolInstantiator(ReasoningTool) + nested_json = {"outer": {"inner": "value"}} + content = f"Text {json.dumps(nested_json)} more text" + result = instantiator._clearing_context(content) + + assert result.startswith("{") + assert result.endswith("}") + assert "outer" in result + assert "inner" in result + + def test_build_model_clears_content_attribute(self): + """Test that build_model clears and sets content attribute + correctly.""" + instantiator = ToolInstantiator(ReasoningTool) + instantiator.content = "previous content" + + valid_content = json.dumps( + { + "reasoning_steps": ["step1", "step2"], + "current_situation": "test", + "plan_status": "ok", + "enough_data": False, + "remaining_steps": ["next"], + "task_completed": False, + } + ) + instantiator.build_model(valid_content) + + # Content should be set to cleaned content (same as input in this case) + assert instantiator.content == valid_content + assert instantiator.content != "previous content" + + def test_build_model_with_markdown_code_block(self): + """Test build_model extracts JSON from markdown code block.""" + instantiator = ToolInstantiator(ReasoningTool) + json_data = { + "reasoning_steps": ["step1", "step2"], + "current_situation": "test", + "plan_status": "ok", + "enough_data": False, + "remaining_steps": ["next"], + "task_completed": False, + } + content = f"```json\n{json.dumps(json_data)}\n```" + + result = instantiator.build_model(content) + + assert isinstance(result, ReasoningTool) + assert result.reasoning_steps == ["step1", "step2"] + + +class TestSchemaSimplifier: + """Test suite for SchemaSimplifier.""" + + def test_simplify_simple_string_field(self): + """Test simplifying schema with simple string field.""" + schema = { + "properties": { + "name": { + "type": "string", + "description": "User name", + } + }, + "required": ["name"], + } + + result = SchemaSimplifier.simplify(schema) + + assert "name (required, string): User name" in result + + def test_simplify_optional_field(self): + """Test simplifying schema with optional field.""" + schema = { + "properties": { + "name": {"type": "string", "description": "User name"}, + "age": {"type": "integer", "description": "User age"}, + }, + "required": ["name"], + } + + result = SchemaSimplifier.simplify(schema) + + assert "name (required" in result + assert "age (optional" in result + + def test_simplify_array_field(self): + """Test simplifying schema with array field.""" + schema = { + "properties": { + "tags": { + "type": "array", + "items": {"type": "string"}, + "description": "List of tags", + } + }, + "required": ["tags"], + } + + result = SchemaSimplifier.simplify(schema) + + assert "tags (required" in result + assert "list[string]" in result + + def test_simplify_array_with_constraints(self): + """Test simplifying schema with array constraints.""" + schema = { + "properties": { + "items": { + "type": "array", + "items": {"type": "string"}, + "minItems": 1, + "maxItems": 5, + "description": "List of items", + } + }, + "required": ["items"], + } + + result = SchemaSimplifier.simplify(schema) + + assert "items (required" in result + assert "1-5 items" in result + + def test_simplify_string_with_length_constraints(self): + """Test simplifying schema with string length constraints.""" + schema = { + "properties": { + "text": { + "type": "string", + "minLength": 10, + "maxLength": 100, + "description": "Text field", + } + }, + "required": ["text"], + } + + result = SchemaSimplifier.simplify(schema) + + assert "text (required" in result + assert "length: 10-100" in result + + def test_simplify_integer_with_range_constraints(self): + """Test simplifying schema with integer range constraints.""" + schema = { + "properties": { + "age": { + "type": "integer", + "minimum": 18, + "maximum": 120, + "description": "User age", + } + }, + "required": ["age"], + } + + result = SchemaSimplifier.simplify(schema) + + assert "age (required" in result + assert "range: 18-120" in result + + def test_simplify_field_with_default(self): + """Test simplifying schema with default value.""" + schema = { + "properties": { + "status": { + "type": "string", + "default": "active", + "description": "Status", + } + }, + "required": [], + } + + result = SchemaSimplifier.simplify(schema) + + assert "status (optional" in result + assert 'default: "active"' in result + + def test_simplify_enum_field(self): + """Test simplifying schema with enum (Literal).""" + schema = { + "properties": { + "status": { + "type": "string", + "enum": ["active", "inactive", "pending"], + "description": "Status", + } + }, + "required": ["status"], + } + + result = SchemaSimplifier.simplify(schema) + + assert "status (required" in result + assert "Literal[" in result + assert '"active"' in result or "'active'" in result + + def test_simplify_const_field(self): + """Test simplifying schema with const value.""" + schema = { + "properties": { + "type": { + "type": "string", + "const": "user", + "description": "Type", + } + }, + "required": ["type"], + } + + result = SchemaSimplifier.simplify(schema) + + assert "type (required" in result + assert "const:" in result + + def test_simplify_nested_object(self): + """Test simplifying schema with nested object.""" + schema = { + "properties": { + "address": { + "type": "object", + "properties": { + "street": {"type": "string", "description": "Street"}, + "city": {"type": "string", "description": "City"}, + }, + "required": ["street"], + "description": "Address", + } + }, + "required": ["address"], + } + + result = SchemaSimplifier.simplify(schema) + + assert "address (required" in result + assert "street (required" in result + assert "city (optional" in result + + def test_simplify_anyof_with_ref(self): + """Test simplifying schema with anyOf containing $ref.""" + schema = { + "properties": { + "tool": { + "anyOf": [ + {"$ref": "#/$defs/ToolA"}, + {"$ref": "#/$defs/ToolB"}, + ], + "description": "Select tool", + } + }, + "required": ["tool"], + "$defs": { + "ToolA": { + "title": "ToolA", + "properties": { + "param1": {"type": "string", "description": "Param 1"}, + }, + "required": ["param1"], + }, + "ToolB": { + "title": "ToolB", + "properties": { + "param2": {"type": "integer", "description": "Param 2"}, + }, + "required": ["param2"], + }, + }, + } + + result = SchemaSimplifier.simplify(schema) + + assert "tool (required" in result + assert "Literal[" in result + assert "ToolA" in result + assert "ToolB" in result + assert "Variant: ToolA" in result + assert "param1 (required" in result + assert "Variant: ToolB" in result + assert "param2 (required" in result + + def test_simplify_anyof_with_const(self): + """Test simplifying schema with anyOf containing const values.""" + schema = { + "properties": { + "status": { + "anyOf": [ + {"const": "active"}, + {"const": "inactive"}, + ], + "description": "Status", + } + }, + "required": ["status"], + } + + result = SchemaSimplifier.simplify(schema) + + assert "status (required" in result + assert "const(" in result + assert "active" in result + assert "inactive" in result + + def test_simplify_empty_properties(self): + """Test simplifying schema with no properties.""" + schema = {"properties": {}, "required": []} + + result = SchemaSimplifier.simplify(schema) + + assert result == "" + + def test_simplify_no_properties_key(self): + """Test simplifying schema without properties key.""" + schema = {"required": []} + + result = SchemaSimplifier.simplify(schema) + + assert result == "" + + def test_simplify_fields_sorted_required_first(self): + """Test that required fields come before optional fields.""" + schema = { + "properties": { + "optional_field": {"type": "string", "description": "Optional"}, + "required_field": {"type": "string", "description": "Required"}, + }, + "required": ["required_field"], + } + + result = SchemaSimplifier.simplify(schema) + lines = result.split("\n") + + required_index = next(i for i, line in enumerate(lines) if "required_field" in line) + optional_index = next(i for i, line in enumerate(lines) if "optional_field" in line) + + assert required_index < optional_index + + def test_simplify_with_indent(self): + """Test simplifying schema with indentation.""" + schema = { + "properties": { + "nested": { + "type": "object", + "properties": { + "field": {"type": "string", "description": "Field"}, + }, + "required": ["field"], + "description": "Nested", + } + }, + "required": ["nested"], + } + + result = SchemaSimplifier.simplify(schema, indent=1) + + lines = result.split("\n") + nested_line = next(line for line in lines if "nested (required" in line) + nested_field_line = next(line for line in lines if "field" in line and "Field" in line) + # Check that nested field has more indentation than parent + assert len(nested_field_line) - len(nested_field_line.lstrip()) > len(nested_line) - len(nested_line.lstrip()) + + def test_extract_type_unknown(self): + """Test _extract_type returns 'unknown' for unknown type.""" + schema = {} + + result = SchemaSimplifier._extract_type(schema) + + assert result == "unknown" + + def test_extract_constraints_empty(self): + """Test _extract_constraints returns empty list when no constraints.""" + schema = {"type": "string"} + + result = SchemaSimplifier._extract_constraints(schema) + + assert result == [] + + def test_extract_constraints_minimum_only(self): + """Test _extract_constraints with only minimum.""" + schema = {"type": "integer", "minimum": 0} + + result = SchemaSimplifier._extract_constraints(schema) + + assert "min: 0" in result + + def test_extract_constraints_maximum_only(self): + """Test _extract_constraints with only maximum.""" + schema = {"type": "integer", "maximum": 100} + + result = SchemaSimplifier._extract_constraints(schema) + + assert "max: 100" in result + + def test_extract_constraints_min_length_only(self): + """Test _extract_constraints with only minLength.""" + schema = {"type": "string", "minLength": 5} + + result = SchemaSimplifier._extract_constraints(schema) + + assert "min length: 5" in result + + def test_extract_constraints_max_length_only(self): + """Test _extract_constraints with only maxLength.""" + schema = {"type": "string", "maxLength": 100} + + result = SchemaSimplifier._extract_constraints(schema) + + assert "max length: 100" in result + + def test_extract_constraints_min_items_only(self): + """Test _extract_constraints with only minItems.""" + schema = {"type": "array", "minItems": 1} + + result = SchemaSimplifier._extract_constraints(schema) + + assert "min 1 items" in result + + def test_extract_constraints_max_items_only(self): + """Test _extract_constraints with only maxItems.""" + schema = {"type": "array", "maxItems": 10} + + result = SchemaSimplifier._extract_constraints(schema) + + assert "max 10 items" in result