Skip to content

Commit 1b60ecc

Browse files
committed
Add s3 source for doc, image, video
1 parent e8fc991 commit 1b60ecc

File tree

14 files changed

+457
-27
lines changed

14 files changed

+457
-27
lines changed

src/strands/models/bedrock.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -489,9 +489,17 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
489489
if "format" in document:
490490
result["format"] = document["format"]
491491

492-
# Handle source
492+
# Handle source - supports bytes or s3Location
493493
if "source" in document:
494-
result["source"] = {"bytes": document["source"]["bytes"]}
494+
source = document["source"]
495+
if "s3Location" in source:
496+
s3_loc = source["s3Location"]
497+
formatted_document_s3: dict[str, Any] = {"uri": s3_loc["uri"]}
498+
if "bucketOwner" in s3_loc:
499+
formatted_document_s3["bucketOwner"] = s3_loc["bucketOwner"]
500+
result["source"] = {"s3Location": formatted_document_s3}
501+
elif "bytes" in source:
502+
result["source"] = {"bytes": source["bytes"]}
495503

496504
# Handle optional fields
497505
if "citations" in document and document["citations"] is not None:
@@ -512,10 +520,16 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
512520
if "image" in content:
513521
image = content["image"]
514522
source = image["source"]
515-
formatted_source = {}
516-
if "bytes" in source:
517-
formatted_source = {"bytes": source["bytes"]}
518-
result = {"format": image["format"], "source": formatted_source}
523+
formatted_image_source: dict[str, Any] = {}
524+
if "s3Location" in source:
525+
s3_loc = source["s3Location"]
526+
formatted_image_s3: dict[str, Any] = {"uri": s3_loc["uri"]}
527+
if "bucketOwner" in s3_loc:
528+
formatted_image_s3["bucketOwner"] = s3_loc["bucketOwner"]
529+
formatted_image_source = {"s3Location": formatted_image_s3}
530+
elif "bytes" in source:
531+
formatted_image_source = {"bytes": source["bytes"]}
532+
result = {"format": image["format"], "source": formatted_image_source}
519533
return {"image": result}
520534

521535
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ReasoningContentBlock.html
@@ -577,10 +591,16 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
577591
if "video" in content:
578592
video = content["video"]
579593
source = video["source"]
580-
formatted_source = {}
581-
if "bytes" in source:
582-
formatted_source = {"bytes": source["bytes"]}
583-
result = {"format": video["format"], "source": formatted_source}
594+
formatted_video_source: dict[str, Any] = {}
595+
if "s3Location" in source:
596+
s3_loc = source["s3Location"]
597+
formatted_video_s3: dict[str, Any] = {"uri": s3_loc["uri"]}
598+
if "bucketOwner" in s3_loc:
599+
formatted_video_s3["bucketOwner"] = s3_loc["bucketOwner"]
600+
formatted_video_source = {"s3Location": formatted_video_s3}
601+
elif "bytes" in source:
602+
formatted_video_source = {"bytes": source["bytes"]}
603+
result = {"format": video["format"], "source": formatted_video_source}
584604
return {"video": result}
585605

586606
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CitationsContentBlock.html

src/strands/types/media.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,42 @@
77

88
from typing import Literal
99

10-
from typing_extensions import TypedDict
10+
from typing_extensions import Required, TypedDict
1111

1212
from .citations import CitationsConfig
1313

1414
DocumentFormat = Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"]
1515
"""Supported document formats."""
1616

1717

18-
class DocumentSource(TypedDict):
18+
class S3Location(TypedDict, total=False):
19+
"""A storage location in an Amazon S3 bucket.
20+
21+
Used by Bedrock to reference media files stored in S3 instead of passing raw bytes.
22+
23+
- Docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_S3Location.html
24+
25+
Attributes:
26+
uri: An object URI starting with `s3://`. Required.
27+
bucketOwner: If the bucket belongs to another AWS account, specify that account's ID. Optional.
28+
"""
29+
30+
uri: Required[str]
31+
bucketOwner: str
32+
33+
34+
class DocumentSource(TypedDict, total=False):
1935
"""Contains the content of a document.
2036
37+
Only one of `bytes` or `s3Location` should be specified.
38+
2139
Attributes:
2240
bytes: The binary content of the document.
41+
s3Location: S3 location of the document (Bedrock only).
2342
"""
2443

2544
bytes: bytes
45+
s3Location: S3Location
2646

2747

2848
class DocumentContent(TypedDict, total=False):
@@ -45,14 +65,18 @@ class DocumentContent(TypedDict, total=False):
4565
"""Supported image formats."""
4666

4767

48-
class ImageSource(TypedDict):
68+
class ImageSource(TypedDict, total=False):
4969
"""Contains the content of an image.
5070
71+
Only one of `bytes` or `s3Location` should be specified.
72+
5173
Attributes:
5274
bytes: The binary content of the image.
75+
s3Location: S3 location of the image (Bedrock only).
5376
"""
5477

5578
bytes: bytes
79+
s3Location: S3Location
5680

5781

5882
class ImageContent(TypedDict):
@@ -71,14 +95,18 @@ class ImageContent(TypedDict):
7195
"""Supported video formats."""
7296

7397

74-
class VideoSource(TypedDict):
98+
class VideoSource(TypedDict, total=False):
7599
"""Contains the content of a video.
76100
101+
Only one of `bytes` or `s3Location` should be specified.
102+
77103
Attributes:
78104
bytes: The binary content of the video.
105+
s3Location: S3 location of the video (Bedrock only).
79106
"""
80107

81108
bytes: bytes
109+
s3Location: S3Location
82110

83111

84112
class VideoContent(TypedDict):

tests/strands/agent/test_agent_result.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,4 +370,3 @@ def test__str__empty_interrupts_returns_agent_message(mock_metrics, simple_messa
370370

371371
# Empty list is falsy, should fall through to text content
372372
assert message_string == "Hello world!\n"
373-

tests/strands/models/test_bedrock.py

Lines changed: 108 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1787,8 +1787,8 @@ def test_format_request_filters_image_content_blocks(model, model_id):
17871787
assert "metadata" not in image_block
17881788

17891789

1790-
def test_format_request_filters_nested_image_s3_fields(model, model_id):
1791-
"""Test that s3Location is filtered out and only bytes source is preserved."""
1790+
def test_format_request_image_s3_location_only(model, model_id):
1791+
"""Test that image with only s3Location is properly formatted."""
17921792
messages = [
17931793
{
17941794
"role": "user",
@@ -1797,10 +1797,41 @@ def test_format_request_filters_nested_image_s3_fields(model, model_id):
17971797
"image": {
17981798
"format": "png",
17991799
"source": {
1800-
"bytes": b"image_data",
1801-
"s3Location": {"bucket": "my-bucket", "key": "image.png", "extraField": "filtered"},
1800+
"s3Location": {"uri": "s3://my-bucket/image.png"},
18021801
},
18031802
}
1803+
},
1804+
{
1805+
"image": {
1806+
"format": "png",
1807+
"source": {
1808+
"s3Location": {"uri": "s3://my-bucket/image.png", "bucketOwner": "12345"},
1809+
},
1810+
}
1811+
},
1812+
],
1813+
}
1814+
]
1815+
1816+
formatted_request = model._format_request(messages)
1817+
image_source = formatted_request["messages"][0]["content"][0]["image"]["source"]
1818+
image_source_with_bucket_owner = formatted_request["messages"][0]["content"][1]["image"]["source"]
1819+
1820+
assert image_source == {"s3Location": {"uri": "s3://my-bucket/image.png"}}
1821+
assert image_source_with_bucket_owner == {"s3Location": {"uri": "s3://my-bucket/image.png", "bucketOwner": "12345"}}
1822+
1823+
1824+
def test_format_request_image_bytes_only(model, model_id):
1825+
"""Test that image with only bytes source is properly formatted."""
1826+
messages = [
1827+
{
1828+
"role": "user",
1829+
"content": [
1830+
{
1831+
"image": {
1832+
"format": "png",
1833+
"source": {"bytes": b"image_data"},
1834+
}
18041835
}
18051836
],
18061837
}
@@ -1810,7 +1841,79 @@ def test_format_request_filters_nested_image_s3_fields(model, model_id):
18101841
image_source = formatted_request["messages"][0]["content"][0]["image"]["source"]
18111842

18121843
assert image_source == {"bytes": b"image_data"}
1813-
assert "s3Location" not in image_source
1844+
1845+
1846+
def test_format_request_document_s3_location(model, model_id):
1847+
"""Test that document with s3Location is properly formatted."""
1848+
messages = [
1849+
{
1850+
"role": "user",
1851+
"content": [
1852+
{
1853+
"document": {
1854+
"name": "report.pdf",
1855+
"format": "pdf",
1856+
"source": {
1857+
"s3Location": {"uri": "s3://my-bucket/report.pdf"},
1858+
},
1859+
}
1860+
},
1861+
{
1862+
"document": {
1863+
"name": "report.pdf",
1864+
"format": "pdf",
1865+
"source": {
1866+
"s3Location": {"uri": "s3://my-bucket/report.pdf", "bucketOwner": "123456789012"},
1867+
},
1868+
}
1869+
},
1870+
],
1871+
}
1872+
]
1873+
1874+
formatted_request = model._format_request(messages)
1875+
document = formatted_request["messages"][0]["content"][0]["document"]
1876+
document_with_bucket_owner = formatted_request["messages"][0]["content"][1]["document"]
1877+
1878+
assert document["source"] == {"s3Location": {"uri": "s3://my-bucket/report.pdf"}}
1879+
1880+
assert document_with_bucket_owner["source"] == {
1881+
"s3Location": {"uri": "s3://my-bucket/report.pdf", "bucketOwner": "123456789012"}
1882+
}
1883+
1884+
1885+
def test_format_request_video_s3_location(model, model_id):
1886+
"""Test that video with s3Location is properly formatted."""
1887+
messages = [
1888+
{
1889+
"role": "user",
1890+
"content": [
1891+
{
1892+
"video": {
1893+
"format": "mp4",
1894+
"source": {
1895+
"s3Location": {"uri": "s3://my-bucket/video.mp4"},
1896+
},
1897+
}
1898+
},
1899+
{
1900+
"video": {
1901+
"format": "mp4",
1902+
"source": {
1903+
"s3Location": {"uri": "s3://my-bucket/video.mp4", "bucketOwner": "12345"},
1904+
},
1905+
}
1906+
},
1907+
],
1908+
}
1909+
]
1910+
1911+
formatted_request = model._format_request(messages)
1912+
video_source = formatted_request["messages"][0]["content"][0]["video"]["source"]
1913+
video_source_with_bucket_owner = formatted_request["messages"][0]["content"][1]["video"]["source"]
1914+
1915+
assert video_source == {"s3Location": {"uri": "s3://my-bucket/video.mp4"}}
1916+
assert video_source_with_bucket_owner == {"s3Location": {"uri": "s3://my-bucket/video.mp4", "bucketOwner": "12345"}}
18141917

18151918

18161919
def test_format_request_filters_document_content_blocks(model, model_id):

tests/strands/tools/mcp/test_mcp_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -632,7 +632,7 @@ def test_call_tool_sync_embedded_nested_base64_textual_mime(mock_transport, mock
632632
def test_call_tool_sync_embedded_image_blob(mock_transport, mock_session):
633633
"""EmbeddedResource.resource (blob with image MIME) should map to image content."""
634634
# Read yellow.png file
635-
with open("tests_integ/yellow.png", "rb") as image_file:
635+
with open("tests_integ/resources/yellow.png", "rb") as image_file:
636636
png_data = image_file.read()
637637
payload = base64.b64encode(png_data).decode()
638638

tests/strands/types/test_media.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
"""Tests for media type definitions."""
2+
3+
from strands.types.media import (
4+
DocumentSource,
5+
ImageSource,
6+
S3Location,
7+
VideoSource,
8+
)
9+
10+
11+
class TestS3Location:
12+
"""Tests for S3Location TypedDict."""
13+
14+
def test_s3_location_with_uri_only(self):
15+
"""Test S3Location with only uri field."""
16+
s3_loc: S3Location = {"uri": "s3://my-bucket/path/to/file.pdf"}
17+
18+
assert s3_loc["uri"] == "s3://my-bucket/path/to/file.pdf"
19+
assert "bucketOwner" not in s3_loc
20+
21+
def test_s3_location_with_bucket_owner(self):
22+
"""Test S3Location with both uri and bucketOwner fields."""
23+
s3_loc: S3Location = {
24+
"uri": "s3://my-bucket/path/to/file.pdf",
25+
"bucketOwner": "123456789012",
26+
}
27+
28+
assert s3_loc["uri"] == "s3://my-bucket/path/to/file.pdf"
29+
assert s3_loc["bucketOwner"] == "123456789012"
30+
31+
32+
class TestDocumentSource:
33+
"""Tests for DocumentSource TypedDict."""
34+
35+
def test_document_source_with_bytes(self):
36+
"""Test DocumentSource with bytes content."""
37+
doc_source: DocumentSource = {"bytes": b"document content"}
38+
39+
assert doc_source["bytes"] == b"document content"
40+
assert "s3Location" not in doc_source
41+
42+
def test_document_source_with_s3_location(self):
43+
"""Test DocumentSource with s3Location."""
44+
doc_source: DocumentSource = {
45+
"s3Location": {
46+
"uri": "s3://my-bucket/docs/report.pdf",
47+
"bucketOwner": "123456789012",
48+
}
49+
}
50+
51+
assert "bytes" not in doc_source
52+
assert doc_source["s3Location"]["uri"] == "s3://my-bucket/docs/report.pdf"
53+
assert doc_source["s3Location"]["bucketOwner"] == "123456789012"
54+
55+
56+
class TestImageSource:
57+
"""Tests for ImageSource TypedDict."""
58+
59+
def test_image_source_with_bytes(self):
60+
"""Test ImageSource with bytes content."""
61+
img_source: ImageSource = {"bytes": b"image content"}
62+
63+
assert img_source["bytes"] == b"image content"
64+
assert "s3Location" not in img_source
65+
66+
def test_image_source_with_s3_location(self):
67+
"""Test ImageSource with s3Location."""
68+
img_source: ImageSource = {
69+
"s3Location": {
70+
"uri": "s3://my-bucket/images/photo.png",
71+
}
72+
}
73+
74+
assert "bytes" not in img_source
75+
assert img_source["s3Location"]["uri"] == "s3://my-bucket/images/photo.png"
76+
77+
78+
class TestVideoSource:
79+
"""Tests for VideoSource TypedDict."""
80+
81+
def test_video_source_with_bytes(self):
82+
"""Test VideoSource with bytes content."""
83+
vid_source: VideoSource = {"bytes": b"video content"}
84+
85+
assert vid_source["bytes"] == b"video content"
86+
assert "s3Location" not in vid_source
87+
88+
def test_video_source_with_s3_location(self):
89+
"""Test VideoSource with s3Location."""
90+
vid_source: VideoSource = {
91+
"s3Location": {
92+
"uri": "s3://my-bucket/videos/clip.mp4",
93+
"bucketOwner": "987654321098",
94+
}
95+
}
96+
97+
assert "bytes" not in vid_source
98+
assert vid_source["s3Location"]["uri"] == "s3://my-bucket/videos/clip.mp4"
99+
assert vid_source["s3Location"]["bucketOwner"] == "987654321098"

0 commit comments

Comments
 (0)