Skip to content

Commit e4bd3bc

Browse files
authored
fix: a2a use artifact update event (#1401)
fix: update tests fix: simplify code by storing in class fix: remove uneeded code change fix: hide a2a artifact streaming under feature flag fix: use walrus operator fix: use star to signify end of unnamed fix: add check for walrus legacy fix: clarify enable_a2a_compliant_streaming parameter in StrandsA2AExecutor initialization fix: update tests refactor: streamline artifact addition logic in StrandsA2AExecutor
1 parent bce2464 commit e4bd3bc

File tree

3 files changed

+172
-13
lines changed

3 files changed

+172
-13
lines changed

src/strands/multiagent/a2a/executor.py

Lines changed: 72 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import json
1313
import logging
1414
import mimetypes
15+
import uuid
16+
import warnings
1517
from typing import Any, Literal
1618

1719
from a2a.server.agent_execution import AgentExecutor, RequestContext
@@ -49,13 +51,21 @@ class StrandsA2AExecutor(AgentExecutor):
4951
# Handle special cases where format differs from extension
5052
FORMAT_MAPPINGS = {"jpg": "jpeg", "htm": "html", "3gp": "three_gp", "3gpp": "three_gp", "3g2": "three_gp"}
5153

52-
def __init__(self, agent: SAAgent):
54+
# A2A-compliant streaming mode
55+
_current_artifact_id: str | None
56+
_is_first_chunk: bool
57+
58+
def __init__(self, agent: SAAgent, *, enable_a2a_compliant_streaming: bool = False):
5359
"""Initialize a StrandsA2AExecutor.
5460
5561
Args:
5662
agent: The Strands Agent instance to adapt to the A2A protocol.
63+
enable_a2a_compliant_streaming: If True, uses A2A-compliant streaming with
64+
artifact updates. If False, uses legacy status updates streaming behavior
65+
for backwards compatibility. Defaults to False.
5766
"""
5867
self.agent = agent
68+
self.enable_a2a_compliant_streaming = enable_a2a_compliant_streaming
5969

6070
async def execute(
6171
self,
@@ -104,12 +114,30 @@ async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater
104114
else:
105115
raise ValueError("No content blocks available")
106116

117+
if not self.enable_a2a_compliant_streaming:
118+
warnings.warn(
119+
"The default A2A response stream implemented in the strands sdk does not conform to "
120+
"what is expected in the A2A spec. Please set the `enable_a2a_compliant_streaming` "
121+
"boolean to `True` on your `A2AServer` class to properly conform to the spec. "
122+
"In the next major version release, this will be the default behavior.",
123+
UserWarning,
124+
stacklevel=3,
125+
)
126+
127+
if self.enable_a2a_compliant_streaming:
128+
self._current_artifact_id = str(uuid.uuid4())
129+
self._is_first_chunk = True
130+
107131
try:
108132
async for event in self.agent.stream_async(content_blocks):
109133
await self._handle_streaming_event(event, updater)
110134
except Exception:
111135
logger.exception("Error in streaming execution")
112136
raise
137+
finally:
138+
if self.enable_a2a_compliant_streaming:
139+
self._current_artifact_id = None
140+
self._is_first_chunk = True
113141

114142
async def _handle_streaming_event(self, event: dict[str, Any], updater: TaskUpdater) -> None:
115143
"""Handle a single streaming event from the Strands Agent.
@@ -125,28 +153,60 @@ async def _handle_streaming_event(self, event: dict[str, Any], updater: TaskUpda
125153
logger.debug("Streaming event: %s", event)
126154
if "data" in event:
127155
if text_content := event["data"]:
128-
await updater.update_status(
129-
TaskState.working,
130-
new_agent_text_message(
131-
text_content,
132-
updater.context_id,
133-
updater.task_id,
134-
),
135-
)
156+
if self.enable_a2a_compliant_streaming:
157+
await updater.add_artifact(
158+
[Part(root=TextPart(text=text_content))],
159+
artifact_id=self._current_artifact_id,
160+
name="agent_response",
161+
append=not self._is_first_chunk,
162+
)
163+
self._is_first_chunk = False
164+
else:
165+
# Legacy use update_status with agent message
166+
await updater.update_status(
167+
TaskState.working,
168+
new_agent_text_message(
169+
text_content,
170+
updater.context_id,
171+
updater.task_id,
172+
),
173+
)
136174
elif "result" in event:
137175
await self._handle_agent_result(event["result"], updater)
138176

139177
async def _handle_agent_result(self, result: SAAgentResult | None, updater: TaskUpdater) -> None:
140178
"""Handle the final result from the Strands Agent.
141179
142-
Processes the agent's final result, extracts text content from the response,
143-
and adds it as an artifact to the task before marking the task as complete.
180+
For A2A-compliant streaming: sends the final artifact chunk marker and marks
181+
the task as complete. If no data chunks were previously sent, includes the
182+
result content.
183+
184+
For legacy streaming: adds the final result as a simple artifact without
185+
artifact_id tracking.
144186
145187
Args:
146188
result: The agent result object containing the final response, or None if no result.
147189
updater: The task updater for managing task state and adding the final artifact.
148190
"""
149-
if final_content := str(result):
191+
if self.enable_a2a_compliant_streaming:
192+
if self._is_first_chunk:
193+
final_content = str(result) if result else ""
194+
parts = [Part(root=TextPart(text=final_content))] if final_content else []
195+
await updater.add_artifact(
196+
parts,
197+
artifact_id=self._current_artifact_id,
198+
name="agent_response",
199+
last_chunk=True,
200+
)
201+
else:
202+
await updater.add_artifact(
203+
[],
204+
artifact_id=self._current_artifact_id,
205+
name="agent_response",
206+
append=True,
207+
last_chunk=True,
208+
)
209+
elif final_content := str(result):
150210
await updater.add_artifact(
151211
[Part(root=TextPart(text=final_content))],
152212
name="agent_response",

src/strands/multiagent/a2a/server.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def __init__(
4242
queue_manager: QueueManager | None = None,
4343
push_config_store: PushNotificationConfigStore | None = None,
4444
push_sender: PushNotificationSender | None = None,
45+
enable_a2a_compliant_streaming: bool = False,
4546
):
4647
"""Initialize an A2A-compatible server from a Strands agent.
4748
@@ -66,6 +67,9 @@ def __init__(
6667
no push notification configuration is used.
6768
push_sender: Custom push notification sender implementation. If None,
6869
no push notifications are sent.
70+
enable_a2a_compliant_streaming: If True, uses A2A-compliant streaming with
71+
artifact updates. If False, uses legacy status updates streaming behavior
72+
for backwards compatibility. Defaults to False.
6973
"""
7074
self.host = host
7175
self.port = port
@@ -90,7 +94,9 @@ def __init__(
9094
self.description = self.strands_agent.description
9195
self.capabilities = AgentCapabilities(streaming=True)
9296
self.request_handler = DefaultRequestHandler(
93-
agent_executor=StrandsA2AExecutor(self.strands_agent),
97+
agent_executor=StrandsA2AExecutor(
98+
self.strands_agent, enable_a2a_compliant_streaming=enable_a2a_compliant_streaming
99+
),
94100
task_store=task_store or InMemoryTaskStore(),
95101
queue_manager=queue_manager,
96102
push_config_store=push_config_store,

tests/strands/multiagent/a2a/test_executor.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1020,3 +1020,96 @@ def test_default_formats_modularization():
10201020
assert executor._get_file_format_from_mime_type("", "document") == "txt"
10211021
assert executor._get_file_format_from_mime_type("", "image") == "png"
10221022
assert executor._get_file_format_from_mime_type("", "video") == "mp4"
1023+
1024+
1025+
# Tests for enable_a2a_compliant_streaming parameter
1026+
1027+
1028+
@pytest.mark.asyncio
1029+
async def test_legacy_mode_emits_deprecation_warning(mock_strands_agent, mock_request_context, mock_event_queue):
1030+
"""Test that legacy streaming (default) emits deprecation warning."""
1031+
from a2a.types import TextPart
1032+
1033+
executor = StrandsA2AExecutor(mock_strands_agent) # Default is False
1034+
1035+
# Mock stream_async
1036+
async def mock_stream(content_blocks):
1037+
yield {"result": None}
1038+
1039+
mock_strands_agent.stream_async = MagicMock(return_value=mock_stream([]))
1040+
1041+
# Mock task
1042+
mock_task = MagicMock()
1043+
mock_task.id = "test-task-id"
1044+
mock_task.context_id = "test-context-id"
1045+
mock_request_context.current_task = mock_task
1046+
1047+
# Mock message
1048+
mock_text_part = MagicMock(spec=TextPart)
1049+
mock_text_part.text = "test"
1050+
mock_part = MagicMock()
1051+
mock_part.root = mock_text_part
1052+
mock_message = MagicMock()
1053+
mock_message.parts = [mock_part]
1054+
mock_request_context.message = mock_message
1055+
1056+
with pytest.warns(UserWarning, match="does not conform to what is expected in the A2A spec"):
1057+
await executor.execute(mock_request_context, mock_event_queue)
1058+
1059+
1060+
@pytest.mark.asyncio
1061+
async def test_a2a_compliant_mode_no_warning(mock_strands_agent, mock_request_context, mock_event_queue):
1062+
"""Test that A2A-compliant mode does not emit warning."""
1063+
import warnings
1064+
1065+
from a2a.types import TextPart
1066+
1067+
executor = StrandsA2AExecutor(mock_strands_agent, enable_a2a_compliant_streaming=True)
1068+
1069+
# Mock stream_async
1070+
async def mock_stream(content_blocks):
1071+
yield {"result": None}
1072+
1073+
mock_strands_agent.stream_async = MagicMock(return_value=mock_stream([]))
1074+
1075+
# Mock task
1076+
mock_task = MagicMock()
1077+
mock_task.id = "test-task-id"
1078+
mock_task.context_id = "test-context-id"
1079+
mock_request_context.current_task = mock_task
1080+
1081+
# Mock message
1082+
mock_text_part = MagicMock(spec=TextPart)
1083+
mock_text_part.text = "test"
1084+
mock_part = MagicMock()
1085+
mock_part.root = mock_text_part
1086+
mock_message = MagicMock()
1087+
mock_message.parts = [mock_part]
1088+
mock_request_context.message = mock_message
1089+
1090+
with warnings.catch_warnings():
1091+
warnings.simplefilter("error")
1092+
try:
1093+
await executor.execute(mock_request_context, mock_event_queue)
1094+
except UserWarning:
1095+
pytest.fail("Should not emit warning")
1096+
1097+
1098+
@pytest.mark.asyncio
1099+
async def test_a2a_compliant_mode_uses_add_artifact(mock_strands_agent):
1100+
"""Test that A2A-compliant mode uses add_artifact with artifact_id."""
1101+
executor = StrandsA2AExecutor(mock_strands_agent, enable_a2a_compliant_streaming=True)
1102+
executor._current_artifact_id = "artifact-123"
1103+
executor._is_first_chunk = True
1104+
1105+
mock_updater = MagicMock()
1106+
mock_updater.add_artifact = AsyncMock()
1107+
mock_updater.update_status = AsyncMock()
1108+
1109+
event = {"data": "content"}
1110+
await executor._handle_streaming_event(event, mock_updater)
1111+
1112+
mock_updater.add_artifact.assert_called_once()
1113+
assert mock_updater.add_artifact.call_args[1]["artifact_id"] == "artifact-123"
1114+
assert mock_updater.add_artifact.call_args[1]["append"] is False
1115+
mock_updater.update_status.assert_not_called()

0 commit comments

Comments
 (0)