Skip to content

Commit ef18a25

Browse files
jer96jer
andauthored
feat(a2a): support A2A FileParts and DataParts (#596)
Co-authored-by: jer <[email protected]>
1 parent 17ccdd2 commit ef18a25

File tree

2 files changed

+947
-25
lines changed

2 files changed

+947
-25
lines changed

src/strands/multiagent/a2a/executor.py

Lines changed: 180 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,29 @@
88
streamed requests to the A2AServer.
99
"""
1010

11+
import json
1112
import logging
12-
from typing import Any
13+
import mimetypes
14+
from typing import Any, Literal
1315

1416
from a2a.server.agent_execution import AgentExecutor, RequestContext
1517
from a2a.server.events import EventQueue
1618
from a2a.server.tasks import TaskUpdater
17-
from a2a.types import InternalError, Part, TaskState, TextPart, UnsupportedOperationError
19+
from a2a.types import DataPart, FilePart, InternalError, Part, TaskState, TextPart, UnsupportedOperationError
1820
from a2a.utils import new_agent_text_message, new_task
1921
from a2a.utils.errors import ServerError
2022

2123
from ...agent.agent import Agent as SAAgent
2224
from ...agent.agent import AgentResult as SAAgentResult
25+
from ...types.content import ContentBlock
26+
from ...types.media import (
27+
DocumentContent,
28+
DocumentSource,
29+
ImageContent,
30+
ImageSource,
31+
VideoContent,
32+
VideoSource,
33+
)
2334

2435
logger = logging.getLogger(__name__)
2536

@@ -31,6 +42,12 @@ class StrandsA2AExecutor(AgentExecutor):
3142
and converts Strands Agent responses to A2A protocol events.
3243
"""
3344

45+
# Default formats for each file type when MIME type is unavailable or unrecognized
46+
DEFAULT_FORMATS = {"document": "txt", "image": "png", "video": "mp4", "unknown": "txt"}
47+
48+
# Handle special cases where format differs from extension
49+
FORMAT_MAPPINGS = {"jpg": "jpeg", "htm": "html", "3gp": "three_gp", "3gpp": "three_gp", "3g2": "three_gp"}
50+
3451
def __init__(self, agent: SAAgent):
3552
"""Initialize a StrandsA2AExecutor.
3653
@@ -78,10 +95,16 @@ async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater
7895
context: The A2A request context, containing the user's input and other metadata.
7996
updater: The task updater for managing task state and sending updates.
8097
"""
81-
logger.info("Executing request in streaming mode")
82-
user_input = context.get_user_input()
98+
# Convert A2A message parts to Strands ContentBlocks
99+
if context.message and hasattr(context.message, "parts"):
100+
content_blocks = self._convert_a2a_parts_to_content_blocks(context.message.parts)
101+
if not content_blocks:
102+
raise ValueError("No content blocks available")
103+
else:
104+
raise ValueError("No content blocks available")
105+
83106
try:
84-
async for event in self.agent.stream_async(user_input):
107+
async for event in self.agent.stream_async(content_blocks):
85108
await self._handle_streaming_event(event, updater)
86109
except Exception:
87110
logger.exception("Error in streaming execution")
@@ -146,3 +169,155 @@ async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None
146169
"""
147170
logger.warning("Cancellation requested but not supported")
148171
raise ServerError(error=UnsupportedOperationError())
172+
173+
def _get_file_type_from_mime_type(self, mime_type: str | None) -> Literal["document", "image", "video", "unknown"]:
174+
"""Classify file type based on MIME type.
175+
176+
Args:
177+
mime_type: The MIME type of the file
178+
179+
Returns:
180+
The classified file type
181+
"""
182+
if not mime_type:
183+
return "unknown"
184+
185+
mime_type = mime_type.lower()
186+
187+
if mime_type.startswith("image/"):
188+
return "image"
189+
elif mime_type.startswith("video/"):
190+
return "video"
191+
elif (
192+
mime_type.startswith("text/")
193+
or mime_type.startswith("application/")
194+
or mime_type in ["application/pdf", "application/json", "application/xml"]
195+
):
196+
return "document"
197+
else:
198+
return "unknown"
199+
200+
def _get_file_format_from_mime_type(self, mime_type: str | None, file_type: str) -> str:
201+
"""Extract file format from MIME type using Python's mimetypes library.
202+
203+
Args:
204+
mime_type: The MIME type of the file
205+
file_type: The classified file type (image, video, document, txt)
206+
207+
Returns:
208+
The file format string
209+
"""
210+
if not mime_type:
211+
return self.DEFAULT_FORMATS.get(file_type, "txt")
212+
213+
mime_type = mime_type.lower()
214+
215+
# Extract subtype from MIME type and check existing format mappings
216+
if "/" in mime_type:
217+
subtype = mime_type.split("/")[-1]
218+
if subtype in self.FORMAT_MAPPINGS:
219+
return self.FORMAT_MAPPINGS[subtype]
220+
221+
# Use mimetypes library to find extensions for the MIME type
222+
extensions = mimetypes.guess_all_extensions(mime_type)
223+
224+
if extensions:
225+
extension = extensions[0][1:] # Remove the leading dot
226+
return self.FORMAT_MAPPINGS.get(extension, extension)
227+
228+
# Fallback to defaults for unknown MIME types
229+
return self.DEFAULT_FORMATS.get(file_type, "txt")
230+
231+
def _strip_file_extension(self, file_name: str) -> str:
232+
"""Strip the file extension from a file name.
233+
234+
Args:
235+
file_name: The original file name with extension
236+
237+
Returns:
238+
The file name without extension
239+
"""
240+
if "." in file_name:
241+
return file_name.rsplit(".", 1)[0]
242+
return file_name
243+
244+
def _convert_a2a_parts_to_content_blocks(self, parts: list[Part]) -> list[ContentBlock]:
245+
"""Convert A2A message parts to Strands ContentBlocks.
246+
247+
Args:
248+
parts: List of A2A Part objects
249+
250+
Returns:
251+
List of Strands ContentBlock objects
252+
"""
253+
content_blocks: list[ContentBlock] = []
254+
255+
for part in parts:
256+
try:
257+
part_root = part.root
258+
259+
if isinstance(part_root, TextPart):
260+
# Handle TextPart
261+
content_blocks.append(ContentBlock(text=part_root.text))
262+
263+
elif isinstance(part_root, FilePart):
264+
# Handle FilePart
265+
file_obj = part_root.file
266+
mime_type = getattr(file_obj, "mime_type", None)
267+
raw_file_name = getattr(file_obj, "name", "FileNameNotProvided")
268+
file_name = self._strip_file_extension(raw_file_name)
269+
file_type = self._get_file_type_from_mime_type(mime_type)
270+
file_format = self._get_file_format_from_mime_type(mime_type, file_type)
271+
272+
# Handle FileWithBytes vs FileWithUri
273+
bytes_data = getattr(file_obj, "bytes", None)
274+
uri_data = getattr(file_obj, "uri", None)
275+
276+
if bytes_data:
277+
if file_type == "image":
278+
content_blocks.append(
279+
ContentBlock(
280+
image=ImageContent(
281+
format=file_format, # type: ignore
282+
source=ImageSource(bytes=bytes_data),
283+
)
284+
)
285+
)
286+
elif file_type == "video":
287+
content_blocks.append(
288+
ContentBlock(
289+
video=VideoContent(
290+
format=file_format, # type: ignore
291+
source=VideoSource(bytes=bytes_data),
292+
)
293+
)
294+
)
295+
else: # document or unknown
296+
content_blocks.append(
297+
ContentBlock(
298+
document=DocumentContent(
299+
format=file_format, # type: ignore
300+
name=file_name,
301+
source=DocumentSource(bytes=bytes_data),
302+
)
303+
)
304+
)
305+
# Handle FileWithUri
306+
elif uri_data:
307+
# For URI files, create a text representation since Strands ContentBlocks expect bytes
308+
content_blocks.append(
309+
ContentBlock(
310+
text="[File: %s (%s)] - Referenced file at: %s" % (file_name, mime_type, uri_data)
311+
)
312+
)
313+
elif isinstance(part_root, DataPart):
314+
# Handle DataPart - convert structured data to JSON text
315+
try:
316+
data_text = json.dumps(part_root.data, indent=2)
317+
content_blocks.append(ContentBlock(text="[Structured Data]\n%s" % data_text))
318+
except Exception:
319+
logger.exception("Failed to serialize data part")
320+
except Exception:
321+
logger.exception("Error processing part")
322+
323+
return content_blocks

0 commit comments

Comments
 (0)