diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index c6b1af702..69578cb5d 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -3,7 +3,9 @@ Provides minimal foundation for multi-agent patterns (Swarm, Graph). """ +import asyncio from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from enum import Enum from typing import Any, Union @@ -86,7 +88,12 @@ async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> M """Invoke asynchronously.""" raise NotImplementedError("invoke_async not implemented") - @abstractmethod def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> MultiAgentResult: """Invoke synchronously.""" - raise NotImplementedError("__call__ not implemented") + + def execute() -> MultiAgentResult: + return asyncio.run(self.invoke_async(task, **kwargs)) + + with ThreadPoolExecutor() as executor: + future = executor.submit(execute) + return future.result() diff --git a/tests/strands/multiagent/test_base.py b/tests/strands/multiagent/test_base.py index 7aa76bb90..395d9275c 100644 --- a/tests/strands/multiagent/test_base.py +++ b/tests/strands/multiagent/test_base.py @@ -141,9 +141,35 @@ class CompleteMultiAgent(MultiAgentBase): async def invoke_async(self, task: str) -> MultiAgentResult: return MultiAgentResult(results={}) - def __call__(self, task: str) -> MultiAgentResult: - return MultiAgentResult(results={}) - - # Should not raise an exception + # Should not raise an exception - __call__ is provided by base class agent = CompleteMultiAgent() assert isinstance(agent, MultiAgentBase) + + +def test_multi_agent_base_call_method(): + """Test that __call__ method properly delegates to invoke_async.""" + + class TestMultiAgent(MultiAgentBase): + def __init__(self): + self.invoke_async_called = False + self.received_task = None + self.received_kwargs = None + + async def invoke_async(self, task, **kwargs): + self.invoke_async_called = True + self.received_task = task + self.received_kwargs = kwargs + return MultiAgentResult( + status=Status.COMPLETED, results={"test": NodeResult(result=Exception("test"), status=Status.COMPLETED)} + ) + + agent = TestMultiAgent() + + # Test with string task + result = agent("test task", param1="value1", param2="value2") + + assert agent.invoke_async_called + assert agent.received_task == "test task" + assert agent.received_kwargs == {"param1": "value1", "param2": "value2"} + assert isinstance(result, MultiAgentResult) + assert result.status == Status.COMPLETED