diff --git a/sgr_agent_core/agents/__init__.py b/sgr_agent_core/agents/__init__.py
index ef00bc1b..3c8e9bfd 100644
--- a/sgr_agent_core/agents/__init__.py
+++ b/sgr_agent_core/agents/__init__.py
@@ -1,10 +1,12 @@
"""Agents module for SGR Agent Core."""
+from sgr_agent_core.agents.iron_agent import IronAgent
from sgr_agent_core.agents.sgr_agent import SGRAgent
from sgr_agent_core.agents.sgr_tool_calling_agent import SGRToolCallingAgent
from sgr_agent_core.agents.tool_calling_agent import ToolCallingAgent
__all__ = [
+ "IronAgent",
"SGRAgent",
"SGRToolCallingAgent",
"ToolCallingAgent",
diff --git a/sgr_agent_core/agents/iron_agent.py b/sgr_agent_core/agents/iron_agent.py
new file mode 100644
index 00000000..4270ec1f
--- /dev/null
+++ b/sgr_agent_core/agents/iron_agent.py
@@ -0,0 +1,229 @@
+"""Fuzzy Agent that doesn't rely on tool calling or structured output."""
+
+from datetime import datetime
+from typing import Type
+
+from openai import AsyncOpenAI
+from pydantic import BaseModel
+
+from sgr_agent_core.agent_config import AgentConfig
+from sgr_agent_core.base_agent import BaseAgent
+from sgr_agent_core.next_step_tool import NextStepToolsBuilder
+from sgr_agent_core.services.registry import ToolRegistry
+from sgr_agent_core.services.tool_instantiator import ToolInstantiator
+from sgr_agent_core.tools import BaseTool, ReasoningTool, ToolNameSelectorStub
+
+
+class IronAgent(BaseAgent):
+ """Agent that uses flexible parsing of LLM text responses instead of tool
+ calling.
+
+ This agent doesn't rely on:
+ - Tool calling (function calling)
+ - Structured output (response_format)
+
+ Instead, it parses natural language responses from LLM to determine
+ which tool to use and with what parameters using ToolInstantiator.
+ """
+
+ name: str = "iron_agent"
+
+ def __init__(
+ self,
+ task_messages: list,
+ openai_client: AsyncOpenAI,
+ agent_config: AgentConfig,
+ toolkit: list[Type[BaseTool]],
+ def_name: str | None = None,
+ **kwargs: dict,
+ ):
+ super().__init__(
+ task_messages=task_messages,
+ openai_client=openai_client,
+ agent_config=agent_config,
+ toolkit=toolkit,
+ def_name=def_name,
+ **kwargs,
+ )
+
+ def _log_tool_instantiator(
+ self,
+ instantiator: ToolInstantiator,
+ attempt: int,
+ max_retries: int,
+ ):
+ """Log tool generation attempt by LLM using data from ToolInstantiator.
+
+ Args:
+ instantiator: ToolInstantiator instance with attempt data
+ attempt: Current attempt number (1-based)
+ max_retries: Maximum number of retry attempts
+ """
+ success = instantiator.instance is not None
+ errors_formatted = instantiator.input_content + "\n" + "\n".join(instantiator.errors)
+ self.logger.info(
+ f"""
+###############################################
+TOOL GENERATION DEBUG
+ {"✅" if success else "❌"} ATTEMPT {attempt}/{max_retries} {"SUCCESS" if success else "FAILED"}:
+
+ Tool: {instantiator.tool_class.tool_name} - Class: {instantiator.tool_class.__name__}
+
+{errors_formatted if not success else ""}
+###############################################"""
+ )
+
+ self.log.append(
+ {
+ "step_number": self._context.iteration,
+ "timestamp": datetime.now().isoformat(),
+ "step_type": "tool_generation_attempt",
+ "tool_class": instantiator.tool_class.__name__,
+ "attempt": attempt,
+ "max_retries": max_retries,
+ "success": success,
+ "llm_content": instantiator.input_content,
+ "errors": instantiator.errors.copy(),
+ }
+ )
+
+ async def _generate_tool(
+ self,
+ tool_class: Type[BaseTool],
+ messages: list[dict],
+ max_retries: int = 5,
+ ) -> BaseTool | BaseModel:
+ """Generate tool instance from LLM response using ToolInstantiator.
+
+ Universal method for calling LLM with parsing through ToolInstantiator.
+ Handles retries with error accumulation.
+
+ Args:
+ tool_class: Tool class or model class to instantiate
+ messages: Context messages for LLM
+ max_retries: Maximum number of retry attempts
+
+ Returns:
+ Instance of tool_class
+
+ Raises:
+ ValueError: If parsing fails after max_retries attempts
+ """
+ instantiator = ToolInstantiator(tool_class)
+
+ for attempt in range(max_retries):
+ async with self.openai_client.chat.completions.stream(
+ messages=messages + [{"role": "user", "content": instantiator.generate_format_prompt()}],
+ **self.config.llm.to_openai_client_kwargs(),
+ ) as stream:
+ async for event in stream:
+ if event.type == "chunk":
+ self.streaming_generator.add_chunk(event.chunk)
+
+ completion = await stream.get_final_completion()
+ content = completion.choices[0].message.content
+ try:
+ tool_instance = instantiator.build_model(content)
+ return tool_instance
+ except ValueError:
+ continue
+ finally:
+ self._log_tool_instantiator(
+ instantiator=instantiator,
+ attempt=attempt + 1,
+ max_retries=max_retries,
+ )
+
+ raise ValueError(
+ f"Failed to parse {tool_class.__name__} after {max_retries} attempts. "
+ f"Try to simplify tool schema or provide more detailed instructions."
+ )
+
+ async def _prepare_tools(self) -> Type[ToolNameSelectorStub]:
+ """Prepare available tools for the current agent state and progress."""
+ if self._context.iteration >= self.config.execution.max_iterations:
+ raise RuntimeError("Max iterations reached")
+ return NextStepToolsBuilder.build_NextStepToolSelector(self.toolkit)
+
+ async def _reasoning_phase(self) -> ReasoningTool:
+ """Call LLM to get ReasoningTool with selected tool name."""
+ messages = await self._prepare_context()
+
+ tool_selector_model = await self._prepare_tools()
+ reasoning = await self._generate_tool(tool_selector_model, messages)
+
+ if not isinstance(reasoning, ReasoningTool):
+ raise ValueError("Expected ReasoningTool instance")
+
+ # Log reasoning
+ self._log_reasoning(reasoning)
+
+ # Add to streaming
+ self.streaming_generator.add_tool_call(
+ f"{self._context.iteration}-reasoning",
+ reasoning.tool_name,
+ reasoning.model_dump_json(exclude={"function_name_choice"}),
+ )
+
+ return reasoning
+
+ async def _select_action_phase(self, reasoning: ReasoningTool) -> BaseTool:
+ """Select tool based on reasoning phase result."""
+ messages = await self._prepare_context()
+
+ tool_name = reasoning.function_name_choice # type: ignore
+
+ # Find tool class by name
+ tool_class: Type[BaseTool] | None = None
+
+ # Try ToolRegistry first
+ tool_class = ToolRegistry.get(tool_name)
+
+ # If not found, search in toolkit
+ if tool_class is None:
+ for tool in self.toolkit:
+ if tool.tool_name == tool_name:
+ tool_class = tool
+ break
+
+ if tool_class is None:
+ raise ValueError(f"Tool '{tool_name}' not found in toolkit")
+
+ # Generate tool parameters
+ tool = await self._generate_tool(tool_class, messages)
+
+ if not isinstance(tool, BaseTool):
+ raise ValueError("Selected tool is not a valid BaseTool instance")
+
+ # Add to conversation
+ self.conversation.append(
+ {
+ "role": "assistant",
+ "content": reasoning.remaining_steps[0] if reasoning.remaining_steps else "Completing",
+ "tool_calls": [
+ {
+ "type": "function",
+ "id": f"{self._context.iteration}-action",
+ "function": {
+ "name": tool.tool_name,
+ "arguments": tool.model_dump_json(),
+ },
+ }
+ ],
+ }
+ )
+ self.streaming_generator.add_tool_call(
+ f"{self._context.iteration}-action", tool.tool_name, tool.model_dump_json()
+ )
+
+ return tool
+
+ async def _action_phase(self, tool: BaseTool) -> str:
+ """Execute selected tool."""
+ result = await tool(self._context, self.config)
+ self.conversation.append(
+ {"role": "tool", "content": result, "tool_call_id": f"{self._context.iteration}-action"}
+ )
+ self.streaming_generator.add_chunk_from_str(f"{result}\n")
+ self._log_tool_execution(tool, result)
+ return result
diff --git a/sgr_agent_core/next_step_tool.py b/sgr_agent_core/next_step_tool.py
index af855a9d..3790b8f3 100644
--- a/sgr_agent_core/next_step_tool.py
+++ b/sgr_agent_core/next_step_tool.py
@@ -23,6 +23,17 @@ class NextStepToolStub(ReasoningTool, ABC):
function: T = Field(description="Select the appropriate tool for the next step")
+class ToolNameSelectorStub(ReasoningTool, ABC):
+ """Stub class for tool name selection that inherits from ReasoningTool.
+
+ Used by IronAgent to select tool name as part of reasoning phase.
+ (!) Stub class for correct autocomplete. Use
+ NextStepToolsBuilder.build_NextStepToolSelector
+ """
+
+ function_name_choice: str = Field(description="Select the name of the tool to use")
+
+
class DiscriminantToolMixin(BaseModel):
tool_name_discriminator: str = Field(..., description="Tool name discriminator")
@@ -60,8 +71,35 @@ def _create_tool_types_union(cls, tools_list: list[Type[T]]) -> Type:
@classmethod
def build_NextStepTools(cls, tools_list: list[Type[T]]) -> Type[NextStepToolStub]: # noqa
+ """Build a model with all NextStepTool args."""
return create_model(
"NextStepTools",
__base__=NextStepToolStub,
- function=(cls._create_tool_types_union(tools_list), Field()),
+ function=(
+ cls._create_tool_types_union(tools_list),
+ Field(description="Select and fill parameters of the appropriate tool for the next step"),
+ ),
+ )
+
+ @classmethod
+ def build_NextStepToolSelector(cls, tools_list: list[Type[T]]) -> Type[ToolNameSelectorStub]:
+ """Build a model for selecting tool name."""
+ # Extract tool names and descriptions
+ tool_names = [tool.tool_name for tool in tools_list]
+
+ if len(tool_names) == 1:
+ literal_type = Literal[tool_names[0]]
+ else:
+ # Create union of individual Literal types using operator.or_
+ # Literal["a"] | Literal["b"] is equivalent to Literal["a", "b"]
+ literal_types = [Literal[name] for name in tool_names]
+ literal_type = reduce(operator.or_, literal_types)
+
+ # Create model dynamically, inheriting from ToolNameSelectorStub (which inherits from ReasoningTool)
+ model_class = create_model(
+ "NextStepToolSelector",
+ __base__=ToolNameSelectorStub,
+ function_name_choice=(literal_type, Field(description="Choose the name for the best tool to use")),
)
+ model_class.tool_name = "nextsteptoolselector" # type: ignore
+ return model_class
diff --git a/sgr_agent_core/services/__init__.py b/sgr_agent_core/services/__init__.py
index 7ff92bf0..081b3680 100644
--- a/sgr_agent_core/services/__init__.py
+++ b/sgr_agent_core/services/__init__.py
@@ -4,6 +4,7 @@
from sgr_agent_core.services.prompt_loader import PromptLoader
from sgr_agent_core.services.registry import AgentRegistry, ToolRegistry
from sgr_agent_core.services.tavily_search import TavilySearchService
+from sgr_agent_core.services.tool_instantiator import ToolInstantiator
__all__ = [
"TavilySearchService",
@@ -11,4 +12,5 @@
"ToolRegistry",
"AgentRegistry",
"PromptLoader",
+ "ToolInstantiator",
]
diff --git a/sgr_agent_core/services/tool_instantiator.py b/sgr_agent_core/services/tool_instantiator.py
new file mode 100644
index 00000000..25bfb243
--- /dev/null
+++ b/sgr_agent_core/services/tool_instantiator.py
@@ -0,0 +1,404 @@
+import json
+from json import JSONDecodeError
+from typing import TYPE_CHECKING
+
+from pydantic import ValidationError
+
+if TYPE_CHECKING:
+ from sgr_agent_core.base_tool import BaseTool
+
+
+class SchemaSimplifier:
+ """Converts JSON schemas to structured text format."""
+
+ @classmethod
+ def simplify(cls, schema: dict, indent: int = 0) -> str:
+ """Convert JSON schema to structured text format (variant 3 - compact inline).
+
+ Format: - field_name (required/optional, тип, ограничения): описание
+
+ Args:
+ schema: JSON schema dictionary from model_json_schema()
+ indent: Current indentation level for nested structures
+
+ Returns:
+ Formatted text representation of schema
+ """
+ properties = schema.get("properties", {})
+ required = schema.get("required", [])
+ defs = schema.get("$defs", {})
+
+ if not properties:
+ return ""
+
+ required_fields = [(name, props) for name, props in properties.items() if name in required]
+ optional_fields = [(name, props) for name, props in properties.items() if name not in required]
+ sorted_fields = required_fields + optional_fields
+
+ result_lines = []
+ indent_str = " " * indent
+
+ for field_name, field_schema in sorted_fields:
+ is_required = field_name in required
+ req_text = "required" if is_required else "optional"
+
+ field_type = cls._extract_type(field_schema, defs)
+ constraints = cls._extract_constraints(field_schema)
+ description = field_schema.get("description", "")
+
+ # Format line: - field_name (required/optional, type, restrictions): description
+ type_constraints = ", ".join([field_type, *constraints])
+ line = f"{indent_str}- {field_name} ({req_text}, {type_constraints}): {description}"
+
+ result_lines.append(line)
+
+ # Handle anyOf with $ref to $defs - expand nested schemas
+ if "anyOf" in field_schema and defs:
+ for variant in field_schema["anyOf"]:
+ if "$ref" in variant:
+ ref_path = variant["$ref"]
+ if ref_path.startswith("#/$defs/"):
+ def_name = ref_path.replace("#/$defs/", "")
+ if def_name in defs:
+ def_schema = defs[def_name]
+ # Get schema name from title or def_name
+ schema_name = def_schema.get("title", def_name)
+ # Add schema name as header
+ nested_indent_str = " " * (indent + 1)
+ result_lines.append(f"{nested_indent_str}Variant: {schema_name}")
+ # Recursively simplify the schema from $defs
+ nested = cls.simplify(def_schema, indent + 2)
+ if nested:
+ result_lines.append(nested)
+
+ # Handle nested objects recursively
+ if field_schema.get("type") == "object" and "properties" in field_schema:
+ nested = cls.simplify(field_schema, indent + 1)
+ if nested:
+ result_lines.append(nested)
+
+ return "\n".join(result_lines)
+
+ @classmethod
+ def _extract_type(cls, field_schema: dict, defs: dict | None = None) -> str:
+ """Extract type string from field schema.
+
+ Args:
+ field_schema: Field schema dictionary
+ defs: Definitions dictionary for resolving $ref (optional)
+
+ Returns:
+ Type string representation
+ """
+ if defs is None:
+ defs = {}
+ # Handle const
+ if "const" in field_schema:
+ base_type = field_schema.get("type", "string")
+ const_value = field_schema["const"]
+ return f"{base_type} (const: {json.dumps(const_value)})"
+
+ # Handle enum (Literal)
+ if "enum" in field_schema:
+ enum_values = field_schema["enum"]
+ enum_str = ", ".join(json.dumps(v) for v in enum_values)
+ return f"Literal[{enum_str}]"
+
+ # Handle anyOf (Union)
+ if "anyOf" in field_schema:
+ types = []
+ const_values = []
+ for variant in field_schema["anyOf"]:
+ if "$ref" in variant:
+ # Resolve $ref
+ ref_path = variant["$ref"]
+ if ref_path.startswith("#/$defs/") and defs:
+ def_name = ref_path.replace("#/$defs/", "")
+ if def_name in defs:
+ # Use title or def_name from schema
+ def_schema = defs[def_name]
+ type_name = def_schema.get("title", def_name)
+ types.append(type_name)
+ else:
+ # Use def_name as-is if not in defs
+ types.append(def_name)
+ else:
+ types.append(ref_path)
+ elif "const" in variant:
+ # Collect const values
+ const_values.append(variant["const"])
+ else:
+ # Extract type from variant
+ variant_type = cls._extract_type(variant, defs)
+ types.append(variant_type)
+
+ # If all variants are const values, format as const(value1 OR value2 OR value3)
+ if const_values and not types:
+ const_str = " OR ".join(json.dumps(v) for v in const_values)
+ return f"const({const_str})"
+
+ # Wrap Union types in Literal format
+ if types:
+ types_str = ", ".join(types)
+ return f"Literal[{types_str}]"
+
+ return "unknown"
+
+ # Handle array
+ if field_schema.get("type") == "array":
+ items = field_schema.get("items", {})
+ if isinstance(items, dict):
+ element_type = cls._extract_type(items, defs)
+ return f"list[{element_type}]"
+ return "list"
+
+ # Handle object
+ if field_schema.get("type") == "object":
+ if "properties" in field_schema:
+ return "object"
+ return "object"
+
+ # Simple types
+ field_type = field_schema.get("type", "unknown")
+ return field_type
+
+ @classmethod
+ def _extract_constraints(cls, field_schema: dict) -> list[str]:
+ """Extract constraints string from field schema.
+
+ Args:
+ field_schema: Field schema dictionary
+
+ Returns:
+ Constraints string (empty if no constraints)
+ """
+ constraints = []
+
+ # Default value
+ if "default" in field_schema:
+ default_val = field_schema["default"]
+ constraints.append(f"default: {json.dumps(default_val)}")
+
+ # Numeric range
+ if "minimum" in field_schema and "maximum" in field_schema:
+ min_val = field_schema["minimum"]
+ max_val = field_schema["maximum"]
+ constraints.append(f"range: {min_val}-{max_val}")
+ elif "minimum" in field_schema:
+ constraints.append(f"min: {field_schema['minimum']}")
+ elif "maximum" in field_schema:
+ constraints.append(f"max: {field_schema['maximum']}")
+
+ # String length
+ if "minLength" in field_schema and "maxLength" in field_schema:
+ min_len = field_schema["minLength"]
+ max_len = field_schema["maxLength"]
+ constraints.append(f"length: {min_len}-{max_len}")
+ elif "minLength" in field_schema:
+ constraints.append(f"min length: {field_schema['minLength']}")
+ elif "maxLength" in field_schema:
+ constraints.append(f"max length: {field_schema['maxLength']}")
+
+ # Array items
+ if "minItems" in field_schema and "maxItems" in field_schema:
+ min_items = field_schema["minItems"]
+ max_items = field_schema["maxItems"]
+ constraints.append(f"{min_items}-{max_items} items")
+ elif "minItems" in field_schema:
+ constraints.append(f"min {field_schema['minItems']} items")
+ elif "maxItems" in field_schema:
+ constraints.append(f"max {field_schema['maxItems']} items")
+
+ return constraints
+
+
+# YESS im trying in SO inside a framework
+class ToolInstantiator:
+ """Structured Output like Service for formatting, parsing raw LLM context
+ and building tool instances."""
+
+ _FORMAT_PROMPT_TEMPLATE = (
+ "CRITICAL: Generate a JSON OBJECT with actual data values, NOT the schema definition.\n\n"
+ "================================================================================\n"
+ "HOW TO READ THE SCHEMA (TEXT FORMAT):\n"
+ "================================================================================\n\n"
+ " FORMAT: - field_name (required/optional, type, constraints): description\n\n"
+ " EXAMPLES:\n"
+ " - name (required, string, max length: 100): User's full name\n"
+ " - age (optional, integer, range: 18-120, default: 25): User's age\n"
+ " - tags (required, list[string], 1-5 items): List of tags\n"
+ ' - status (required, const("active" OR "inactive")): Account status\n'
+ " - tool (required, Literal[ToolA, ToolB, ToolC]): Select one tool\n"
+ " Variant: ToolA\n"
+ " - param1 (required, string): Parameter 1\n"
+ " - param2 (optional, integer): Parameter 2\n"
+ " Variant: ToolB\n"
+ " - param3 (required, boolean): Parameter 3\n\n"
+ " KEY CONCEPTS:\n"
+ " - required/optional: Whether the field must be provided\n"
+ " - type: Data type (string, integer, boolean, list[type], Literal[...], etc.)\n"
+ " - constraints: Limits like 'range: 1-10', 'max length: 300', '2-3 items', 'default: value'\n"
+ " - Literal[...]: Union type - choose ONE of the listed options\n"
+ " - Variant: sections: For nested unions, select appropriate variant and fill its fields\n"
+ " - const(value): Field must have exactly this value (no other options)\n\n"
+ "HOW TO GENERATE VALUES:\n"
+ "- Use the conversation context above to understand what the user wants\n"
+ "- Use the tool description to understand what this tool does\n"
+ "- Fill each field with appropriate values based on its schema description, the context and tool purpose\n"
+ "- Match the types and constraints shown in the schema\n"
+ "- For Union types, choose the appropriate variant based on context\n\n"
+ "REQUIREMENTS:\n"
+ "- Output ONLY raw JSON object (no markdown, no code blocks, no explanations)\n"
+ "- Fill fields with actual values matching the schema types\n"
+ "- All strings must be quoted, arrays in [], objects in {}\n"
+ "- Respect all constraints (min/max values, lengths, item counts, etc.)\n\n"
+ "\n"
+ '{"name": "John Doe", "age": 30, "tags": ["tag1", "tag2"], '
+ '"status": "active", "tool": "ToolA", "param1": "value1", "param2": 42}\n'
+ "\n\n"
+ "\n"
+ '{"name": {"type": "string", "description": "User\'s full name"}, '
+ '"age": {"type": "integer", "range": "18-120"}}\n'
+ "\n\n"
+ "\n"
+ '```json\n{"name": "John Doe", "age": 30}\n```\n'
+ "\n\n"
+ "The schema below shows the STRUCTURE in text format - provide VALUES in JSON that match it.\n\n"
+ )
+
+ def __init__(self, tool_class: type["BaseTool"]):
+ """Initialize tool instantiator.
+
+ Args:
+ tool_class: Tool class to instantiate
+ """
+ self.tool_class = tool_class
+ self.errors: list[str] = []
+ self.instance: "BaseTool | None" = None
+ self.input_content: str = ""
+
+ def _clearing_context(self, content: str) -> str:
+ """Extract JSON object from content by finding first { and last }.
+
+ Args:
+ content: Raw content that may contain JSON mixed with other text
+
+ Returns:
+ Extracted JSON string (from first { to last })
+ """
+ first_brace = content.find("{")
+ if first_brace == -1:
+ return content
+
+ last_brace = content.rfind("}")
+ if last_brace == -1 or last_brace < first_brace:
+ return content
+
+ return content[first_brace : last_brace + 1]
+
+ def _format_json_error(self, error: JSONDecodeError, content: str) -> str:
+ """Format JSON decode error with context around the error position.
+
+ Args:
+ error: JSONDecodeError exception
+ content: Content that failed to parse
+
+ Returns:
+ Formatted error message with context
+ """
+ error_pos = getattr(error, "pos", None)
+ if error_pos is None:
+ return f"Failed to parse JSON: {error}"
+
+ # Calculate context window (±20 characters)
+ context_size = 20
+ start = max(0, error_pos - context_size)
+ end = min(len(content), error_pos + context_size)
+
+ context_before = content[start:error_pos]
+ context_after = content[error_pos:end]
+
+ # Show position and context
+ error_msg = (
+ f"JSON parse error at position {error_pos}: {error.msg}\n"
+ f"Context: ...{context_before}[Error here]{context_after}..."
+ )
+
+ return error_msg
+
+ def generate_format_prompt(
+ self,
+ include_errors: bool = True,
+ ) -> str:
+ """Generate a prompt describing the expected format for LLM.
+
+ Args:
+ mode: Which parameters to include:
+ - "all": Show all parameters (default)
+ - "required": Show only required parameters
+ - "unfilled": Show only parameters where value is PydanticUndefined
+ include_errors: If True, include error messages for parameters with validation errors
+
+ Returns:
+ Format description string
+ """
+ simplified_schema = SchemaSimplifier.simplify(self.tool_class.model_json_schema())
+
+ prompt = (
+ f"\n"
+ f"Tool name: {self.tool_class.tool_name}\n"
+ f"Tool description: {self.tool_class.description}\n"
+ f"\n\n"
+ f"{self._FORMAT_PROMPT_TEMPLATE}"
+ f"\n"
+ f"{simplified_schema}\n"
+ f"\n\n"
+ )
+
+ # Show errors at the end if requested
+ if include_errors and self.errors:
+ prompt += "PREVIOUS FILLING ITERATION ERRORS. AVOID OR FIX IT:\n"
+ prompt += "fault content: \n" + self.input_content + "\n"
+ for error in self.errors:
+ prompt += f" - {error}\n"
+
+ return prompt
+
+ def build_model(self, content: str) -> "BaseTool":
+ """Build tool model instance from parsed parameters.
+
+ Args:
+ content: Raw content from LLM to parse
+
+ Returns:
+ Tool instance
+
+ Raises:
+ ValueError: If required parameters are missing or model creation fails
+ """
+ self.content = ""
+ self.errors.clear()
+ if not content:
+ self.errors.append("No content provided")
+ raise ValueError("No content provided")
+ self.input_content = content
+
+ try:
+ cleaned_content = self._clearing_context(content)
+ self.content = cleaned_content # Save cleaned content for debugging
+ self.instance = self.tool_class(**json.loads(cleaned_content))
+ return self.instance
+ except ValidationError as e:
+ for err in e.errors():
+ self.errors.append(
+ f"pydantic validation error - type: {err['type']} - " f"field: {err['loc']} - {err['msg']}"
+ )
+ raise ValueError("Failed to build model") from e
+ except JSONDecodeError as e:
+ error_content = cleaned_content if "cleaned_content" in locals() else self.content or content
+ error_msg = self._format_json_error(e, error_content)
+ self.errors.append(error_msg)
+ raise ValueError("Failed to build model") from e
+ except ValueError as e:
+ self.errors.append(f"Failed to parse JSON: {e}")
+ raise ValueError("Failed to build model") from e
diff --git a/sgr_agent_core/tools/__init__.py b/sgr_agent_core/tools/__init__.py
index ef859ffa..3f93988f 100644
--- a/sgr_agent_core/tools/__init__.py
+++ b/sgr_agent_core/tools/__init__.py
@@ -1,5 +1,9 @@
from sgr_agent_core.base_tool import BaseTool, MCPBaseTool
-from sgr_agent_core.next_step_tool import NextStepToolsBuilder, NextStepToolStub
+from sgr_agent_core.next_step_tool import (
+ NextStepToolsBuilder,
+ NextStepToolStub,
+ ToolNameSelectorStub,
+)
from sgr_agent_core.tools.adapt_plan_tool import AdaptPlanTool
from sgr_agent_core.tools.clarification_tool import ClarificationTool
from sgr_agent_core.tools.create_report_tool import CreateReportTool
@@ -14,6 +18,7 @@
"BaseTool",
"MCPBaseTool",
"NextStepToolStub",
+ "ToolNameSelectorStub",
"NextStepToolsBuilder",
# Individual tools
"ClarificationTool",
diff --git a/sgr_agent_core/tools/reasoning_tool.py b/sgr_agent_core/tools/reasoning_tool.py
index 5296d672..80f0e837 100644
--- a/sgr_agent_core/tools/reasoning_tool.py
+++ b/sgr_agent_core/tools/reasoning_tool.py
@@ -36,8 +36,7 @@ class ReasoningTool(BaseTool):
# Next step planning
remaining_steps: list[str] = Field(
- description="1-3 remaining steps (brief, action-oriented)",
- min_length=1,
+ description="0-3 remaining steps (brief, action-oriented)",
max_length=3,
)
task_completed: bool = Field(description="Is the research task finished?")
diff --git a/tests/test_iron_agent.py b/tests/test_iron_agent.py
new file mode 100644
index 00000000..1adbfee2
--- /dev/null
+++ b/tests/test_iron_agent.py
@@ -0,0 +1,331 @@
+"""Tests for IronAgent."""
+
+from unittest.mock import AsyncMock, Mock, patch
+
+import pytest
+from openai.types.chat import ChatCompletionChunk, ChatCompletionMessage
+
+from sgr_agent_core.agents.iron_agent import IronAgent
+from sgr_agent_core.tools import ReasoningTool, WebSearchTool
+from tests.conftest import create_test_agent
+
+
+class TestIronAgent:
+ """Tests for IronAgent."""
+
+ def test_initialization(self):
+ """Test IronAgent initialization."""
+ agent = create_test_agent(
+ IronAgent,
+ task_messages=[{"role": "user", "content": "Test task"}],
+ toolkit=[ReasoningTool, WebSearchTool],
+ )
+
+ assert agent.name == "iron_agent"
+ assert len(agent.toolkit) == 2
+ assert ReasoningTool in agent.toolkit
+ assert WebSearchTool in agent.toolkit
+
+ @pytest.mark.asyncio
+ async def test_generate_tool_success(self):
+ """Test _generate_tool with successful parsing."""
+ agent = create_test_agent(
+ IronAgent,
+ toolkit=[ReasoningTool],
+ )
+
+ # Mock stream response
+ mock_chunk = Mock(spec=ChatCompletionChunk)
+ mock_chunk.type = "chunk"
+ mock_chunk.chunk = Mock() # Add chunk attribute
+
+ mock_completion = Mock(spec=ChatCompletionMessage)
+ mock_completion.content = (
+ '{"reasoning_steps": ["step1", "step2"], "current_situation": "test", '
+ '"plan_status": "ok", "enough_data": false, "remaining_steps": ["next"], '
+ '"task_completed": false}'
+ )
+
+ # Create async iterator for stream
+ async def async_iter():
+ yield mock_chunk
+
+ mock_stream = AsyncMock()
+ mock_stream.__aenter__ = AsyncMock(return_value=mock_stream)
+ mock_stream.__aexit__ = AsyncMock(return_value=None)
+ mock_stream.__aiter__ = lambda self: async_iter()
+ mock_final_completion = Mock()
+ mock_final_completion.choices = [Mock(message=mock_completion)]
+ mock_stream.get_final_completion = AsyncMock(return_value=mock_final_completion)
+
+ agent.openai_client.chat.completions.stream = Mock(return_value=mock_stream)
+
+ # Call method
+ result = await agent._generate_tool(ReasoningTool, [{"role": "user", "content": "test"}])
+
+ assert isinstance(result, ReasoningTool)
+ assert result.reasoning_steps == ["step1", "step2"]
+
+ # Check that logging was called for successful attempt
+ log_entries = [entry for entry in agent.log if entry.get("step_type") == "tool_generation_attempt"]
+ assert len(log_entries) == 1
+ assert log_entries[0]["success"] is True
+ assert log_entries[0]["tool_class"] == "ReasoningTool"
+ assert log_entries[0]["llm_content"] == mock_completion.content
+
+ @pytest.mark.asyncio
+ async def test_generate_tool_retry_on_error(self):
+ """Test _generate_tool with retry on parsing error."""
+ agent = create_test_agent(
+ IronAgent,
+ toolkit=[ReasoningTool],
+ )
+
+ # Mock stream response with invalid JSON first, then valid
+ mock_chunk = Mock(spec=ChatCompletionChunk)
+ mock_chunk.type = "chunk"
+ mock_chunk.chunk = Mock() # Add chunk attribute
+
+ # First attempt - invalid JSON
+ mock_completion_invalid = Mock(spec=ChatCompletionMessage)
+ mock_completion_invalid.content = "invalid json"
+
+ # Second attempt - valid JSON
+ mock_completion_valid = Mock(spec=ChatCompletionMessage)
+ mock_completion_valid.content = (
+ '{"reasoning_steps": ["step1", "step2"], "current_situation": "test", '
+ '"plan_status": "ok", "enough_data": false, "remaining_steps": ["next"], '
+ '"task_completed": false}'
+ )
+
+ # Create async iterator for stream
+ async def async_iter():
+ yield mock_chunk
+
+ # First call returns invalid, second returns valid
+ call_count = {"count": 0}
+
+ async def get_completion():
+ call_count["count"] += 1
+ if call_count["count"] == 1:
+ return Mock(choices=[Mock(message=mock_completion_invalid)])
+ return Mock(choices=[Mock(message=mock_completion_valid)])
+
+ mock_stream = AsyncMock()
+ mock_stream.__aenter__ = AsyncMock(return_value=mock_stream)
+ mock_stream.__aexit__ = AsyncMock(return_value=None)
+ mock_stream.__aiter__ = lambda self: async_iter()
+ mock_stream.get_final_completion = AsyncMock(side_effect=get_completion)
+
+ agent.openai_client.chat.completions.stream = Mock(return_value=mock_stream)
+
+ # Call method - should succeed on second attempt
+ result = await agent._generate_tool(ReasoningTool, [{"role": "user", "content": "test"}], max_retries=5)
+
+ assert isinstance(result, ReasoningTool)
+ assert call_count["count"] == 2 # Two attempts made
+
+ # Check that logging was called for both attempts
+ # Note: due to finally block, each attempt logs twice (in except and finally)
+ log_entries = [entry for entry in agent.log if entry.get("step_type") == "tool_generation_attempt"]
+ # First attempt: failed (logged in except and finally)
+ # Second attempt: succeeded (logged in finally)
+ assert len(log_entries) >= 2
+ # Find first failed attempt
+ failed_entries = [e for e in log_entries if e["success"] is False and e["attempt"] == 1]
+ assert len(failed_entries) >= 1
+ assert failed_entries[0]["llm_content"] == "invalid json"
+ # Find successful attempt
+ success_entries = [e for e in log_entries if e["success"] is True and e["attempt"] == 2]
+ assert len(success_entries) >= 1
+ assert success_entries[0]["llm_content"] == mock_completion_valid.content
+
+ @pytest.mark.asyncio
+ async def test_prepare_tools(self):
+ """Test _prepare_tools method."""
+ agent = create_test_agent(
+ IronAgent,
+ toolkit=[ReasoningTool, WebSearchTool],
+ )
+
+ # Call _prepare_tools
+ tool_selector_type = await agent._prepare_tools()
+
+ # Check that it returns a type that is a subclass of ToolNameSelectorStub
+ assert isinstance(tool_selector_type, type)
+ # Check that it's a dynamically created model (has __name__)
+ assert hasattr(tool_selector_type, "__name__")
+ # Check that it can be instantiated with required ReasoningTool fields
+ instance = tool_selector_type(
+ reasoning_steps=["step1", "step2"],
+ current_situation="test",
+ plan_status="ok",
+ enough_data=False,
+ remaining_steps=["next"],
+ task_completed=False,
+ function_name_choice="reasoningtool",
+ )
+ assert hasattr(instance, "function_name_choice")
+
+ @pytest.mark.asyncio
+ async def test_reasoning_phase(self):
+ """Test _reasoning_phase."""
+ agent = create_test_agent(
+ IronAgent,
+ toolkit=[ReasoningTool, WebSearchTool],
+ )
+
+ # Mock _generate_tool
+ # The result should be a ReasoningTool with tool_name field
+ from pydantic import create_model
+
+ # Create a mock reasoning tool with function_name_choice field
+ mock_reasoning_class = create_model(
+ "MockReasoning",
+ __base__=ReasoningTool,
+ function_name_choice=(str, "websearchtool"),
+ )
+ mock_reasoning = mock_reasoning_class(
+ reasoning_steps=["step1", "step2"],
+ current_situation="test",
+ plan_status="ok",
+ enough_data=False,
+ remaining_steps=["next"],
+ task_completed=False,
+ function_name_choice="websearchtool",
+ )
+
+ agent._generate_tool = AsyncMock(return_value=mock_reasoning)
+
+ # Call reasoning phase
+ result = await agent._reasoning_phase()
+
+ assert isinstance(result, ReasoningTool)
+ assert result.reasoning_steps == ["step1", "step2"]
+ assert hasattr(result, "function_name_choice")
+ agent._generate_tool.assert_called_once()
+
+ @pytest.mark.asyncio
+ async def test_select_action_phase(self):
+ """Test _select_action_phase."""
+ agent = create_test_agent(
+ IronAgent,
+ toolkit=[ReasoningTool, WebSearchTool],
+ )
+
+ # Mock reasoning with function_name_choice field (already selected in reasoning phase)
+ from pydantic import create_model
+
+ mock_reasoning_class = create_model(
+ "MockReasoning",
+ __base__=ReasoningTool,
+ function_name_choice=(str, "websearchtool"),
+ )
+ reasoning = mock_reasoning_class(
+ reasoning_steps=["step1", "step2"],
+ current_situation="test",
+ plan_status="ok",
+ enough_data=False,
+ remaining_steps=["search"],
+ task_completed=False,
+ function_name_choice="websearchtool",
+ )
+
+ # Mock tool instance generation
+ mock_tool = WebSearchTool(reasoning="test reasoning", query="test query")
+ agent._generate_tool = AsyncMock(return_value=mock_tool)
+
+ # Call select action phase
+ result = await agent._select_action_phase(reasoning)
+
+ assert isinstance(result, WebSearchTool)
+ assert result.query == "test query"
+ # Only one call - tool parameter generation (tool_name already selected in reasoning)
+ agent._generate_tool.assert_called_once()
+
+ @pytest.mark.asyncio
+ async def test_action_phase(self):
+ """Test _action_phase."""
+
+ agent = create_test_agent(
+ IronAgent,
+ toolkit=[WebSearchTool],
+ )
+
+ # Mock tool execution - need to patch __call__ method
+ tool = WebSearchTool(reasoning="test reasoning", query="test query")
+
+ # Patch WebSearchTool.__call__ to avoid TavilySearchService initialization
+ with patch(
+ "sgr_agent_core.tools.web_search_tool.WebSearchTool.__call__",
+ new_callable=AsyncMock,
+ return_value="Search results",
+ ) as mock_call:
+ # Call action phase
+ result = await agent._action_phase(tool)
+
+ assert result == "Search results"
+ mock_call.assert_called_once_with(agent._context, agent.config)
+
+ @pytest.mark.asyncio
+ async def test_log_tool_generation_attempt(self):
+ """Test _log_tool_instantiator logging."""
+ from sgr_agent_core.services.tool_instantiator import ToolInstantiator
+
+ agent = create_test_agent(
+ IronAgent,
+ toolkit=[ReasoningTool],
+ )
+
+ # Test successful attempt
+ instantiator = ToolInstantiator(ReasoningTool)
+ instantiator.input_content = (
+ '{"reasoning_steps": ["step1", "step2"], "current_situation": "test", '
+ '"plan_status": "ok", "enough_data": false, "remaining_steps": ["next"], '
+ '"task_completed": false}'
+ )
+ instantiator.instance = ReasoningTool(
+ reasoning_steps=["step1", "step2"],
+ current_situation="test",
+ plan_status="ok",
+ enough_data=False,
+ remaining_steps=["next"],
+ task_completed=False,
+ )
+
+ agent._log_tool_instantiator(
+ instantiator=instantiator,
+ attempt=1,
+ max_retries=5,
+ )
+
+ # Check log entry
+ assert len(agent.log) == 1
+ log_entry = agent.log[0]
+ assert log_entry["step_type"] == "tool_generation_attempt"
+ assert log_entry["tool_class"] == "ReasoningTool"
+ assert log_entry["attempt"] == 1
+ assert log_entry["max_retries"] == 5
+ assert log_entry["success"] is True
+ assert instantiator.instance is not None
+ assert log_entry["llm_content"] == instantiator.input_content
+ assert log_entry["errors"] == []
+
+ # Test failed attempt
+ instantiator2 = ToolInstantiator(ReasoningTool)
+ instantiator2.input_content = "invalid json"
+ instantiator2.errors = ["JSON decode error", "Validation error"]
+
+ agent._log_tool_instantiator(
+ instantiator=instantiator2,
+ attempt=2,
+ max_retries=5,
+ )
+
+ # Check second log entry
+ assert len(agent.log) == 2
+ log_entry = agent.log[1]
+ assert log_entry["success"] is False
+ assert log_entry["llm_content"] == "invalid json"
+ assert log_entry["errors"] == ["JSON decode error", "Validation error"]
+ assert instantiator2.instance is None
diff --git a/tests/test_next_step_tool.py b/tests/test_next_step_tool.py
new file mode 100644
index 00000000..bd63429f
--- /dev/null
+++ b/tests/test_next_step_tool.py
@@ -0,0 +1,92 @@
+"""Tests for NextStepToolsBuilder."""
+
+import pytest
+from pydantic import ValidationError
+
+from sgr_agent_core.next_step_tool import NextStepToolsBuilder
+from sgr_agent_core.tools import ReasoningTool, WebSearchTool
+
+
+class TestNextStepToolsBuilder:
+ """Tests for NextStepToolsBuilder."""
+
+ def test_build_ToolNameSelector_single_tool(self):
+ """Test building ToolNameSelector with single tool."""
+ tools_list = [ReasoningTool]
+ selector_model = NextStepToolsBuilder.build_NextStepToolSelector(tools_list)
+
+ # Create instance with valid tool name and all required ReasoningTool fields
+ instance = selector_model(
+ reasoning_steps=["step1", "step2"],
+ current_situation="test",
+ plan_status="ok",
+ enough_data=False,
+ remaining_steps=["next"],
+ task_completed=False,
+ function_name_choice="reasoningtool",
+ )
+ assert instance.function_name_choice == "reasoningtool"
+
+ # Try invalid tool name - should raise ValidationError
+ with pytest.raises(ValidationError):
+ selector_model(
+ reasoning_steps=["step1", "step2"],
+ current_situation="test",
+ plan_status="ok",
+ enough_data=False,
+ remaining_steps=["next"],
+ task_completed=False,
+ tool_name="invalid_tool",
+ )
+
+ def test_build_ToolNameSelector_multiple_tools(self):
+ """Test building ToolNameSelector with multiple tools."""
+ tools_list = [ReasoningTool, WebSearchTool]
+ selector_model = NextStepToolsBuilder.build_NextStepToolSelector(tools_list)
+
+ # Create instance with valid tool names and all required ReasoningTool fields
+ instance1 = selector_model(
+ reasoning_steps=["step1", "step2"],
+ current_situation="test",
+ plan_status="ok",
+ enough_data=False,
+ remaining_steps=["next"],
+ task_completed=False,
+ function_name_choice="reasoningtool",
+ )
+ assert instance1.function_name_choice == "reasoningtool"
+
+ instance2 = selector_model(
+ reasoning_steps=["step1", "step2"],
+ current_situation="test",
+ plan_status="ok",
+ enough_data=False,
+ remaining_steps=["next"],
+ task_completed=False,
+ function_name_choice="websearchtool",
+ )
+ assert instance2.function_name_choice == "websearchtool"
+
+ # Try invalid tool name - should raise ValidationError
+ with pytest.raises(ValidationError):
+ selector_model(
+ reasoning_steps=["step1", "step2"],
+ current_situation="test",
+ plan_status="ok",
+ enough_data=False,
+ remaining_steps=["next"],
+ task_completed=False,
+ tool_name="invalid_tool",
+ )
+
+ def test_build_ToolNameSelector_includes_descriptions(self):
+ """Test that ToolNameSelector has proper field description."""
+ tools_list = [ReasoningTool, WebSearchTool]
+ selector_model = NextStepToolsBuilder.build_NextStepToolSelector(tools_list)
+
+ # Check that field description exists and is meaningful
+ field_info = selector_model.model_fields["function_name_choice"]
+ description = field_info.description
+
+ assert description is not None
+ assert len(description) > 0
diff --git a/tests/test_tool_instantiator.py b/tests/test_tool_instantiator.py
new file mode 100644
index 00000000..fbc3eca3
--- /dev/null
+++ b/tests/test_tool_instantiator.py
@@ -0,0 +1,761 @@
+"""Tests for ToolInstantiator service."""
+
+import json
+from json import JSONDecodeError
+
+import pytest
+
+from sgr_agent_core.services.tool_instantiator import SchemaSimplifier, ToolInstantiator
+from sgr_agent_core.tools import ReasoningTool, WebSearchTool
+
+
+class TestToolInstantiator:
+ """Test suite for ToolInstantiator."""
+
+ def test_initialization(self):
+ """Test ToolInstantiator initialization."""
+ instantiator = ToolInstantiator(ReasoningTool)
+
+ assert instantiator.tool_class == ReasoningTool
+ assert instantiator.errors == []
+ assert instantiator.instance is None
+ assert instantiator.input_content == ""
+
+ def test_generate_format_prompt_with_errors(self):
+ """Test generate_format_prompt with errors included."""
+ instantiator = ToolInstantiator(ReasoningTool)
+ instantiator.errors = ["Error 1", "Error 2"]
+ instantiator.input_content = "invalid json"
+
+ prompt = instantiator.generate_format_prompt(include_errors=True)
+
+ assert "PREVIOUS FILLING ITERATION ERRORS" in prompt
+ assert "Error 1" in prompt
+ assert "Error 2" in prompt
+ assert "invalid json" in prompt
+
+ def test_generate_format_prompt_without_errors_when_errors_exist(self):
+ """Test generate_format_prompt with include_errors=False when errors
+ exist."""
+ instantiator = ToolInstantiator(ReasoningTool)
+ instantiator.errors = ["Error 1"]
+ instantiator.input_content = "invalid json"
+
+ prompt = instantiator.generate_format_prompt(include_errors=False)
+
+ assert "PREVIOUS FILLING ITERATION ERRORS" not in prompt
+ assert "Error 1" not in prompt
+
+ def test_build_model_success(self):
+ """Test build_model with valid JSON content."""
+ instantiator = ToolInstantiator(ReasoningTool)
+ content = json.dumps(
+ {
+ "reasoning_steps": ["step1", "step2"],
+ "current_situation": "test",
+ "plan_status": "ok",
+ "enough_data": False,
+ "remaining_steps": ["next"],
+ "task_completed": False,
+ }
+ )
+
+ result = instantiator.build_model(content)
+
+ assert isinstance(result, ReasoningTool)
+ assert result == instantiator.instance
+ assert instantiator.instance is not None
+ assert instantiator.input_content == content
+ assert len(instantiator.errors) == 0
+ assert result.reasoning_steps == ["step1", "step2"]
+ assert result.current_situation == "test"
+
+ def test_build_model_with_whitespace(self):
+ """Test build_model handles whitespace correctly."""
+ instantiator = ToolInstantiator(ReasoningTool)
+ content = (
+ ' \n{"reasoning_steps": ["step1", "step2"], "current_situation": "test", '
+ '"plan_status": "ok", "enough_data": false, "remaining_steps": ["next"], '
+ '"task_completed": false}\n '
+ )
+
+ result = instantiator.build_model(content)
+
+ assert isinstance(result, ReasoningTool)
+ assert instantiator.instance is not None
+
+ def test_build_model_empty_content(self):
+ """Test build_model with empty content."""
+ instantiator = ToolInstantiator(ReasoningTool)
+
+ with pytest.raises(ValueError, match="No content provided"):
+ instantiator.build_model("")
+
+ assert len(instantiator.errors) == 1
+ assert "No content provided" in instantiator.errors
+ assert instantiator.instance is None
+
+ def test_build_model_invalid_json(self):
+ """Test build_model with invalid JSON."""
+ instantiator = ToolInstantiator(ReasoningTool)
+ content = "invalid json content"
+
+ with pytest.raises(ValueError, match="Failed to build model"):
+ instantiator.build_model(content)
+
+ assert len(instantiator.errors) > 0
+ # Check for JSON parse error message (new format with context)
+ assert any("JSON parse error" in err or "Failed to parse JSON" in err for err in instantiator.errors)
+ assert instantiator.input_content == content
+ assert instantiator.instance is None
+
+ def test_build_model_validation_error(self):
+ """Test build_model with Pydantic validation error."""
+ instantiator = ToolInstantiator(ReasoningTool)
+ # Missing required fields
+ content = json.dumps({"reasoning_steps": ["step1"]})
+
+ with pytest.raises(ValueError, match="Failed to build model"):
+ instantiator.build_model(content)
+
+ assert len(instantiator.errors) > 0
+ assert any("pydantic validation error" in err for err in instantiator.errors)
+ assert instantiator.input_content == content
+ assert instantiator.instance is None
+
+ def test_build_model_clears_errors_on_new_attempt(self):
+ """Test that build_model clears errors before new attempt."""
+ instantiator = ToolInstantiator(ReasoningTool)
+ instantiator.errors = ["Previous error"]
+
+ # First attempt fails
+ try:
+ instantiator.build_model("invalid")
+ except ValueError:
+ pass
+
+ assert len(instantiator.errors) > 0
+
+ # Second attempt succeeds - errors should be cleared
+ valid_content = json.dumps(
+ {
+ "reasoning_steps": ["step1", "step2"],
+ "current_situation": "test",
+ "plan_status": "ok",
+ "enough_data": False,
+ "remaining_steps": ["next"],
+ "task_completed": False,
+ }
+ )
+ result = instantiator.build_model(valid_content)
+
+ assert isinstance(result, ReasoningTool)
+ assert instantiator.instance is not None
+
+ def test_build_model_with_web_search_tool(self):
+ """Test build_model with different tool class."""
+ instantiator = ToolInstantiator(WebSearchTool)
+ content = json.dumps({"reasoning": "test reasoning", "query": "test query"})
+
+ result = instantiator.build_model(content)
+
+ assert isinstance(result, WebSearchTool)
+ assert result.reasoning == "test reasoning"
+ assert result.query == "test query"
+ assert instantiator.instance == result
+
+ def test_generate_format_prompt_includes_schema(self):
+ """Test that generate_format_prompt includes simplified schema."""
+ instantiator = ToolInstantiator(ReasoningTool)
+ prompt = instantiator.generate_format_prompt()
+
+ # Check that simplified schema format is present
+ assert "" in prompt
+ assert "" in prompt
+ # Check for simplified format markers
+ assert "reasoning_steps (required" in prompt
+ assert "list[string]" in prompt
+ assert "enough_data (optional" in prompt
+
+ def test_errors_accumulation(self):
+ """Test that errors accumulate across multiple failed attempts."""
+ instantiator = ToolInstantiator(ReasoningTool)
+
+ # First attempt - invalid JSON
+ try:
+ instantiator.build_model("invalid json 1")
+ except ValueError:
+ pass
+
+ # Second attempt - invalid JSON
+ try:
+ instantiator.build_model("invalid json 2")
+ except ValueError:
+ pass
+
+ # Note: build_model clears errors at start, so each attempt starts fresh
+ # But we can check that errors are added during each attempt
+ assert len(instantiator.errors) > 0
+
+ def test_clearing_context_extracts_json(self):
+ """Test _clearing_context extracts JSON from mixed content."""
+ instantiator = ToolInstantiator(ReasoningTool)
+
+ # Content with text before and after JSON
+ content = 'Some text before {"reasoning_steps": ["step1"], "current_situation": "test"} and after'
+ result = instantiator._clearing_context(content)
+
+ assert result.startswith("{")
+ assert result.endswith("}")
+ assert "reasoning_steps" in result
+ assert "Some text before" not in result
+ assert "and after" not in result
+
+ def test_clearing_context_no_braces(self):
+ """Test _clearing_context returns original content if no braces
+ found."""
+ instantiator = ToolInstantiator(ReasoningTool)
+ content = "no json here"
+
+ result = instantiator._clearing_context(content)
+
+ assert result == content
+
+ def test_clearing_context_only_opening_brace(self):
+ """Test _clearing_context handles case with only opening brace."""
+ instantiator = ToolInstantiator(ReasoningTool)
+ content = "text { incomplete json"
+
+ result = instantiator._clearing_context(content)
+
+ assert result == content
+
+ def test_build_model_with_text_around_json(self):
+ """Test build_model extracts JSON from text with surrounding
+ content."""
+ instantiator = ToolInstantiator(ReasoningTool)
+ json_data = {
+ "reasoning_steps": ["step1", "step2"],
+ "current_situation": "test",
+ "plan_status": "ok",
+ "enough_data": False,
+ "remaining_steps": ["next"],
+ "task_completed": False,
+ }
+ content = f"Here is the JSON: {json.dumps(json_data)} and some text after"
+
+ result = instantiator.build_model(content)
+
+ assert isinstance(result, ReasoningTool)
+ assert result.reasoning_steps == ["step1", "step2"]
+
+ def test_format_json_error_with_context(self):
+ """Test _format_json_error includes context around error position."""
+ instantiator = ToolInstantiator(ReasoningTool)
+ # Create invalid JSON with extra data
+ content = '{"field": "value"} extra data here'
+ try:
+ json.loads(content)
+ except JSONDecodeError as e:
+ error_msg = instantiator._format_json_error(e, content)
+
+ assert "JSON parse error" in error_msg
+ assert "position" in error_msg
+ assert "Context:" in error_msg
+ assert "extra data" in error_msg or "value" in error_msg
+
+ def test_format_json_error_without_position(self):
+ """Test _format_json_error handles error without position attribute."""
+ instantiator = ToolInstantiator(ReasoningTool)
+
+ # Create error-like object without pos attribute
+ class ErrorWithoutPos:
+ def __init__(self):
+ self.msg = "test error"
+
+ error = ErrorWithoutPos()
+ error_msg = instantiator._format_json_error(error, "test content")
+
+ # Should fall back to generic error message when pos is missing
+ assert "Failed to parse JSON" in error_msg
+
+ def test_format_json_error_at_start(self):
+ """Test _format_json_error handles error at start of content."""
+ instantiator = ToolInstantiator(ReasoningTool)
+ content = 'invalid {"field": "value"}'
+ try:
+ json.loads(content)
+ except JSONDecodeError as e:
+ error_msg = instantiator._format_json_error(e, content)
+ assert "JSON parse error" in error_msg
+ assert "position" in error_msg
+
+ def test_format_json_error_at_end(self):
+ """Test _format_json_error handles error at end of content."""
+ instantiator = ToolInstantiator(ReasoningTool)
+ content = '{"field": "value"} invalid'
+ try:
+ json.loads(content)
+ except JSONDecodeError as e:
+ error_msg = instantiator._format_json_error(e, content)
+ assert "JSON parse error" in error_msg
+ assert "position" in error_msg
+
+ def test_generate_format_prompt_includes_tool_info(self):
+ """Test that generate_format_prompt includes tool name."""
+ instantiator = ToolInstantiator(ReasoningTool)
+ prompt = instantiator.generate_format_prompt()
+
+ assert "" in prompt
+ assert "" in prompt
+ assert ReasoningTool.tool_name in prompt
+
+ def test_generate_format_prompt_includes_format_template(self):
+ """Test that generate_format_prompt includes important format
+ sections."""
+ instantiator = ToolInstantiator(ReasoningTool)
+ prompt = instantiator.generate_format_prompt()
+
+ # Check that key sections are present (not checking exact strings to avoid hardcoding)
+ assert len(prompt) > 100 # Should be substantial
+ assert "" in prompt # Schema section is critical
+ assert "" in prompt # Tool info is critical
+ # Check that it contains instructions (not just checking exact strings)
+ assert any(keyword in prompt.lower() for keyword in ["json", "schema", "format", "required"])
+
+ def test_clearing_context_multiple_braces(self):
+ """Test _clearing_context extracts JSON when multiple braces exist."""
+ instantiator = ToolInstantiator(ReasoningTool)
+ content = 'text { {"reasoning_steps": ["step1"], "current_situation": "test"} } more text'
+ result = instantiator._clearing_context(content)
+
+ assert result.startswith("{")
+ assert result.endswith("}")
+ assert "reasoning_steps" in result
+
+ def test_clearing_context_nested_objects(self):
+ """Test _clearing_context handles nested JSON objects."""
+ instantiator = ToolInstantiator(ReasoningTool)
+ nested_json = {"outer": {"inner": "value"}}
+ content = f"Text {json.dumps(nested_json)} more text"
+ result = instantiator._clearing_context(content)
+
+ assert result.startswith("{")
+ assert result.endswith("}")
+ assert "outer" in result
+ assert "inner" in result
+
+ def test_build_model_clears_content_attribute(self):
+ """Test that build_model clears and sets content attribute
+ correctly."""
+ instantiator = ToolInstantiator(ReasoningTool)
+ instantiator.content = "previous content"
+
+ valid_content = json.dumps(
+ {
+ "reasoning_steps": ["step1", "step2"],
+ "current_situation": "test",
+ "plan_status": "ok",
+ "enough_data": False,
+ "remaining_steps": ["next"],
+ "task_completed": False,
+ }
+ )
+ instantiator.build_model(valid_content)
+
+ # Content should be set to cleaned content (same as input in this case)
+ assert instantiator.content == valid_content
+ assert instantiator.content != "previous content"
+
+ def test_build_model_with_markdown_code_block(self):
+ """Test build_model extracts JSON from markdown code block."""
+ instantiator = ToolInstantiator(ReasoningTool)
+ json_data = {
+ "reasoning_steps": ["step1", "step2"],
+ "current_situation": "test",
+ "plan_status": "ok",
+ "enough_data": False,
+ "remaining_steps": ["next"],
+ "task_completed": False,
+ }
+ content = f"```json\n{json.dumps(json_data)}\n```"
+
+ result = instantiator.build_model(content)
+
+ assert isinstance(result, ReasoningTool)
+ assert result.reasoning_steps == ["step1", "step2"]
+
+
+class TestSchemaSimplifier:
+ """Test suite for SchemaSimplifier."""
+
+ def test_simplify_simple_string_field(self):
+ """Test simplifying schema with simple string field."""
+ schema = {
+ "properties": {
+ "name": {
+ "type": "string",
+ "description": "User name",
+ }
+ },
+ "required": ["name"],
+ }
+
+ result = SchemaSimplifier.simplify(schema)
+
+ assert "name (required, string): User name" in result
+
+ def test_simplify_optional_field(self):
+ """Test simplifying schema with optional field."""
+ schema = {
+ "properties": {
+ "name": {"type": "string", "description": "User name"},
+ "age": {"type": "integer", "description": "User age"},
+ },
+ "required": ["name"],
+ }
+
+ result = SchemaSimplifier.simplify(schema)
+
+ assert "name (required" in result
+ assert "age (optional" in result
+
+ def test_simplify_array_field(self):
+ """Test simplifying schema with array field."""
+ schema = {
+ "properties": {
+ "tags": {
+ "type": "array",
+ "items": {"type": "string"},
+ "description": "List of tags",
+ }
+ },
+ "required": ["tags"],
+ }
+
+ result = SchemaSimplifier.simplify(schema)
+
+ assert "tags (required" in result
+ assert "list[string]" in result
+
+ def test_simplify_array_with_constraints(self):
+ """Test simplifying schema with array constraints."""
+ schema = {
+ "properties": {
+ "items": {
+ "type": "array",
+ "items": {"type": "string"},
+ "minItems": 1,
+ "maxItems": 5,
+ "description": "List of items",
+ }
+ },
+ "required": ["items"],
+ }
+
+ result = SchemaSimplifier.simplify(schema)
+
+ assert "items (required" in result
+ assert "1-5 items" in result
+
+ def test_simplify_string_with_length_constraints(self):
+ """Test simplifying schema with string length constraints."""
+ schema = {
+ "properties": {
+ "text": {
+ "type": "string",
+ "minLength": 10,
+ "maxLength": 100,
+ "description": "Text field",
+ }
+ },
+ "required": ["text"],
+ }
+
+ result = SchemaSimplifier.simplify(schema)
+
+ assert "text (required" in result
+ assert "length: 10-100" in result
+
+ def test_simplify_integer_with_range_constraints(self):
+ """Test simplifying schema with integer range constraints."""
+ schema = {
+ "properties": {
+ "age": {
+ "type": "integer",
+ "minimum": 18,
+ "maximum": 120,
+ "description": "User age",
+ }
+ },
+ "required": ["age"],
+ }
+
+ result = SchemaSimplifier.simplify(schema)
+
+ assert "age (required" in result
+ assert "range: 18-120" in result
+
+ def test_simplify_field_with_default(self):
+ """Test simplifying schema with default value."""
+ schema = {
+ "properties": {
+ "status": {
+ "type": "string",
+ "default": "active",
+ "description": "Status",
+ }
+ },
+ "required": [],
+ }
+
+ result = SchemaSimplifier.simplify(schema)
+
+ assert "status (optional" in result
+ assert 'default: "active"' in result
+
+ def test_simplify_enum_field(self):
+ """Test simplifying schema with enum (Literal)."""
+ schema = {
+ "properties": {
+ "status": {
+ "type": "string",
+ "enum": ["active", "inactive", "pending"],
+ "description": "Status",
+ }
+ },
+ "required": ["status"],
+ }
+
+ result = SchemaSimplifier.simplify(schema)
+
+ assert "status (required" in result
+ assert "Literal[" in result
+ assert '"active"' in result or "'active'" in result
+
+ def test_simplify_const_field(self):
+ """Test simplifying schema with const value."""
+ schema = {
+ "properties": {
+ "type": {
+ "type": "string",
+ "const": "user",
+ "description": "Type",
+ }
+ },
+ "required": ["type"],
+ }
+
+ result = SchemaSimplifier.simplify(schema)
+
+ assert "type (required" in result
+ assert "const:" in result
+
+ def test_simplify_nested_object(self):
+ """Test simplifying schema with nested object."""
+ schema = {
+ "properties": {
+ "address": {
+ "type": "object",
+ "properties": {
+ "street": {"type": "string", "description": "Street"},
+ "city": {"type": "string", "description": "City"},
+ },
+ "required": ["street"],
+ "description": "Address",
+ }
+ },
+ "required": ["address"],
+ }
+
+ result = SchemaSimplifier.simplify(schema)
+
+ assert "address (required" in result
+ assert "street (required" in result
+ assert "city (optional" in result
+
+ def test_simplify_anyof_with_ref(self):
+ """Test simplifying schema with anyOf containing $ref."""
+ schema = {
+ "properties": {
+ "tool": {
+ "anyOf": [
+ {"$ref": "#/$defs/ToolA"},
+ {"$ref": "#/$defs/ToolB"},
+ ],
+ "description": "Select tool",
+ }
+ },
+ "required": ["tool"],
+ "$defs": {
+ "ToolA": {
+ "title": "ToolA",
+ "properties": {
+ "param1": {"type": "string", "description": "Param 1"},
+ },
+ "required": ["param1"],
+ },
+ "ToolB": {
+ "title": "ToolB",
+ "properties": {
+ "param2": {"type": "integer", "description": "Param 2"},
+ },
+ "required": ["param2"],
+ },
+ },
+ }
+
+ result = SchemaSimplifier.simplify(schema)
+
+ assert "tool (required" in result
+ assert "Literal[" in result
+ assert "ToolA" in result
+ assert "ToolB" in result
+ assert "Variant: ToolA" in result
+ assert "param1 (required" in result
+ assert "Variant: ToolB" in result
+ assert "param2 (required" in result
+
+ def test_simplify_anyof_with_const(self):
+ """Test simplifying schema with anyOf containing const values."""
+ schema = {
+ "properties": {
+ "status": {
+ "anyOf": [
+ {"const": "active"},
+ {"const": "inactive"},
+ ],
+ "description": "Status",
+ }
+ },
+ "required": ["status"],
+ }
+
+ result = SchemaSimplifier.simplify(schema)
+
+ assert "status (required" in result
+ assert "const(" in result
+ assert "active" in result
+ assert "inactive" in result
+
+ def test_simplify_empty_properties(self):
+ """Test simplifying schema with no properties."""
+ schema = {"properties": {}, "required": []}
+
+ result = SchemaSimplifier.simplify(schema)
+
+ assert result == ""
+
+ def test_simplify_no_properties_key(self):
+ """Test simplifying schema without properties key."""
+ schema = {"required": []}
+
+ result = SchemaSimplifier.simplify(schema)
+
+ assert result == ""
+
+ def test_simplify_fields_sorted_required_first(self):
+ """Test that required fields come before optional fields."""
+ schema = {
+ "properties": {
+ "optional_field": {"type": "string", "description": "Optional"},
+ "required_field": {"type": "string", "description": "Required"},
+ },
+ "required": ["required_field"],
+ }
+
+ result = SchemaSimplifier.simplify(schema)
+ lines = result.split("\n")
+
+ required_index = next(i for i, line in enumerate(lines) if "required_field" in line)
+ optional_index = next(i for i, line in enumerate(lines) if "optional_field" in line)
+
+ assert required_index < optional_index
+
+ def test_simplify_with_indent(self):
+ """Test simplifying schema with indentation."""
+ schema = {
+ "properties": {
+ "nested": {
+ "type": "object",
+ "properties": {
+ "field": {"type": "string", "description": "Field"},
+ },
+ "required": ["field"],
+ "description": "Nested",
+ }
+ },
+ "required": ["nested"],
+ }
+
+ result = SchemaSimplifier.simplify(schema, indent=1)
+
+ lines = result.split("\n")
+ nested_line = next(line for line in lines if "nested (required" in line)
+ nested_field_line = next(line for line in lines if "field" in line and "Field" in line)
+ # Check that nested field has more indentation than parent
+ assert len(nested_field_line) - len(nested_field_line.lstrip()) > len(nested_line) - len(nested_line.lstrip())
+
+ def test_extract_type_unknown(self):
+ """Test _extract_type returns 'unknown' for unknown type."""
+ schema = {}
+
+ result = SchemaSimplifier._extract_type(schema)
+
+ assert result == "unknown"
+
+ def test_extract_constraints_empty(self):
+ """Test _extract_constraints returns empty list when no constraints."""
+ schema = {"type": "string"}
+
+ result = SchemaSimplifier._extract_constraints(schema)
+
+ assert result == []
+
+ def test_extract_constraints_minimum_only(self):
+ """Test _extract_constraints with only minimum."""
+ schema = {"type": "integer", "minimum": 0}
+
+ result = SchemaSimplifier._extract_constraints(schema)
+
+ assert "min: 0" in result
+
+ def test_extract_constraints_maximum_only(self):
+ """Test _extract_constraints with only maximum."""
+ schema = {"type": "integer", "maximum": 100}
+
+ result = SchemaSimplifier._extract_constraints(schema)
+
+ assert "max: 100" in result
+
+ def test_extract_constraints_min_length_only(self):
+ """Test _extract_constraints with only minLength."""
+ schema = {"type": "string", "minLength": 5}
+
+ result = SchemaSimplifier._extract_constraints(schema)
+
+ assert "min length: 5" in result
+
+ def test_extract_constraints_max_length_only(self):
+ """Test _extract_constraints with only maxLength."""
+ schema = {"type": "string", "maxLength": 100}
+
+ result = SchemaSimplifier._extract_constraints(schema)
+
+ assert "max length: 100" in result
+
+ def test_extract_constraints_min_items_only(self):
+ """Test _extract_constraints with only minItems."""
+ schema = {"type": "array", "minItems": 1}
+
+ result = SchemaSimplifier._extract_constraints(schema)
+
+ assert "min 1 items" in result
+
+ def test_extract_constraints_max_items_only(self):
+ """Test _extract_constraints with only maxItems."""
+ schema = {"type": "array", "maxItems": 10}
+
+ result = SchemaSimplifier._extract_constraints(schema)
+
+ assert "max 10 items" in result