Skip to content

Commit 6ef6447

Browse files
authored
feat(multiagent): Add __call__ implementation to MultiAgentBase (#645)
1 parent 9397f58 commit 6ef6447

File tree

2 files changed

+39
-6
lines changed

2 files changed

+39
-6
lines changed

src/strands/multiagent/base.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
Provides minimal foundation for multi-agent patterns (Swarm, Graph).
44
"""
55

6+
import asyncio
67
from abc import ABC, abstractmethod
8+
from concurrent.futures import ThreadPoolExecutor
79
from dataclasses import dataclass, field
810
from enum import Enum
911
from typing import Any, Union
@@ -86,7 +88,12 @@ async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> M
8688
"""Invoke asynchronously."""
8789
raise NotImplementedError("invoke_async not implemented")
8890

89-
@abstractmethod
9091
def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> MultiAgentResult:
9192
"""Invoke synchronously."""
92-
raise NotImplementedError("__call__ not implemented")
93+
94+
def execute() -> MultiAgentResult:
95+
return asyncio.run(self.invoke_async(task, **kwargs))
96+
97+
with ThreadPoolExecutor() as executor:
98+
future = executor.submit(execute)
99+
return future.result()

tests/strands/multiagent/test_base.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,35 @@ class CompleteMultiAgent(MultiAgentBase):
141141
async def invoke_async(self, task: str) -> MultiAgentResult:
142142
return MultiAgentResult(results={})
143143

144-
def __call__(self, task: str) -> MultiAgentResult:
145-
return MultiAgentResult(results={})
146-
147-
# Should not raise an exception
144+
# Should not raise an exception - __call__ is provided by base class
148145
agent = CompleteMultiAgent()
149146
assert isinstance(agent, MultiAgentBase)
147+
148+
149+
def test_multi_agent_base_call_method():
150+
"""Test that __call__ method properly delegates to invoke_async."""
151+
152+
class TestMultiAgent(MultiAgentBase):
153+
def __init__(self):
154+
self.invoke_async_called = False
155+
self.received_task = None
156+
self.received_kwargs = None
157+
158+
async def invoke_async(self, task, **kwargs):
159+
self.invoke_async_called = True
160+
self.received_task = task
161+
self.received_kwargs = kwargs
162+
return MultiAgentResult(
163+
status=Status.COMPLETED, results={"test": NodeResult(result=Exception("test"), status=Status.COMPLETED)}
164+
)
165+
166+
agent = TestMultiAgent()
167+
168+
# Test with string task
169+
result = agent("test task", param1="value1", param2="value2")
170+
171+
assert agent.invoke_async_called
172+
assert agent.received_task == "test task"
173+
assert agent.received_kwargs == {"param1": "value1", "param2": "value2"}
174+
assert isinstance(result, MultiAgentResult)
175+
assert result.status == Status.COMPLETED

0 commit comments

Comments
 (0)