Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions sgr_agent_core/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def _log_reasoning(self, result: ReasoningTool) -> None:
"step_number": self._context.iteration,
"timestamp": datetime.now().isoformat(),
"step_type": "reasoning",
"agent_reasoning": result.model_dump(),
"agent_reasoning": result.model_dump(mode="json"),
}
)

Expand All @@ -108,7 +108,7 @@ def _log_tool_execution(self, tool: BaseTool, result: str):
"timestamp": datetime.now().isoformat(),
"step_type": "tool_execution",
"tool_name": tool.tool_name,
"agent_tool_context": tool.model_dump(),
"agent_tool_context": tool.model_dump(mode="json"),
"agent_tool_execution_result": result,
}
)
Expand All @@ -127,7 +127,7 @@ def _save_agent_log(self):
agent_log = {
"id": self.id,
"model_config": self.config.llm.model_dump(
exclude={"api_key", "proxy"}
exclude={"api_key", "proxy"}, mode="json"
), # Sensitive data excluded by default
"task": self.task,
"toolkit": [tool.tool_name for tool in self.toolkit],
Expand Down
2 changes: 1 addition & 1 deletion sgr_agent_core/base_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class MCPBaseTool(BaseTool):

async def __call__(self, context: AgentContext, config: AgentConfig, **kwargs) -> str:
config = GlobalConfig()
payload = self.model_dump()
payload = self.model_dump(mode="json")
try:
async with self._client:
result = await self._client.call_tool(self.tool_name, payload)
Expand Down
83 changes: 83 additions & 0 deletions tests/test_base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,89 @@ def test_log_tool_execution_contains_result(self):
log_entry = agent.log[0]
assert log_entry["agent_tool_execution_result"] == result

def test_log_tool_execution_with_enum_serializes_correctly(self):
"""Test that tool with enum field serializes correctly to JSON."""
import json
from enum import Enum

class TestPriority(Enum):
LOW = 1
MEDIUM = 2
HIGH = 3

class TestToolWithEnum(BaseTool):
name: str = "test_tool"
priority: TestPriority

agent = create_test_agent(BaseAgent, task="Test")
agent._context.iteration = 1

tool = TestToolWithEnum(priority=TestPriority.HIGH)

agent._log_tool_execution(tool, "Tool result")

log_entry = agent.log[0]
assert "agent_tool_context" in log_entry

tool_context = log_entry["agent_tool_context"]
assert isinstance(tool_context["priority"], int)
assert tool_context["priority"] == 3

json_str = json.dumps(log_entry, ensure_ascii=False)
assert "3" in json_str
assert '"priority": 3' in json_str

def test_save_agent_log_with_enum_serializes_correctly(self, tmp_path):
"""Test that agent log with enum serializes correctly to JSON file."""
import json
import os
from enum import Enum
from unittest.mock import patch

from sgr_agent_core.agent_definition import ExecutionConfig

class TestStatus(Enum):
PENDING = "pending"
PROCESSING = "processing"
DONE = "done"

class TestToolWithEnum(BaseTool):
name: str = "test_tool"
status: TestStatus

logs_dir = str(tmp_path / "logs")

agent = create_test_agent(
BaseAgent,
task="Test",
execution_config=ExecutionConfig(
max_iterations=20,
max_clarifications=3,
logs_dir=logs_dir,
),
)

tool = TestToolWithEnum(status=TestStatus.DONE)
agent._log_tool_execution(tool, "Result")

mock_config = Mock()
mock_config.execution.logs_dir = logs_dir

with patch("sgr_agent_core.agent_config.GlobalConfig", return_value=mock_config):
agent._save_agent_log()

assert os.path.exists(logs_dir)
log_files = list(os.listdir(logs_dir))
assert len(log_files) == 1

log_file_path = os.path.join(logs_dir, log_files[0])
with open(log_file_path, "r", encoding="utf-8") as f:
log_data = json.load(f)

tool_context = log_data["log"][0]["agent_tool_context"]
assert isinstance(tool_context["status"], str)
assert tool_context["status"] == "done"


class TestBaseAgentAbstractMethods:
"""Tests for abstract methods that must be implemented by subclasses."""
Expand Down