Skip to content

Commit b34f588

Browse files
committed
Add s3 source for doc, image, video
1 parent e8fc991 commit b34f588

File tree

13 files changed: 422 additions & 26 deletions

src/strands/models/bedrock.py

Lines changed: 30 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -489,9 +489,17 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
489489
if "format" in document:
490490
result["format"] = document["format"]
491491

492-
# Handle source
492+
# Handle source - supports bytes or s3Location
493493
if "source" in document:
494-
result["source"] = {"bytes": document["source"]["bytes"]}
494+
source = document["source"]
495+
if "s3Location" in source:
496+
s3_loc = source["s3Location"]
497+
formatted_document_s3: dict[str, Any] = {"uri": s3_loc["uri"]}
498+
if "bucketOwner" in s3_loc:
499+
formatted_document_s3["bucketOwner"] = s3_loc["bucketOwner"]
500+
result["source"] = {"s3Location": formatted_document_s3}
501+
elif "bytes" in source:
502+
result["source"] = {"bytes": source["bytes"]}
495503

496504
# Handle optional fields
497505
if "citations" in document and document["citations"] is not None:
@@ -512,10 +520,16 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
512520
if "image" in content:
513521
image = content["image"]
514522
source = image["source"]
515-
formatted_source = {}
516-
if "bytes" in source:
517-
formatted_source = {"bytes": source["bytes"]}
518-
result = {"format": image["format"], "source": formatted_source}
523+
formatted_image_source: dict[str, Any] = {}
524+
if "s3Location" in source:
525+
s3_loc = source["s3Location"]
526+
formatted_image_s3: dict[str, Any] = {"uri": s3_loc["uri"]}
527+
if "bucketOwner" in s3_loc:
528+
formatted_image_s3["bucketOwner"] = s3_loc["bucketOwner"]
529+
formatted_image_source = {"s3Location": formatted_image_s3}
530+
elif "bytes" in source:
531+
formatted_image_source = {"bytes": source["bytes"]}
532+
result = {"format": image["format"], "source": formatted_image_source}
519533
return {"image": result}
520534

521535
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ReasoningContentBlock.html
@@ -577,10 +591,16 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
577591
if "video" in content:
578592
video = content["video"]
579593
source = video["source"]
580-
formatted_source = {}
581-
if "bytes" in source:
582-
formatted_source = {"bytes": source["bytes"]}
583-
result = {"format": video["format"], "source": formatted_source}
594+
formatted_video_source: dict[str, Any] = {}
595+
if "s3Location" in source:
596+
s3_loc = source["s3Location"]
597+
formatted_video_s3: dict[str, Any] = {"uri": s3_loc["uri"]}
598+
if "bucketOwner" in s3_loc:
599+
formatted_video_s3["bucketOwner"] = s3_loc["bucketOwner"]
600+
formatted_video_source = {"s3Location": formatted_video_s3}
601+
elif "bytes" in source:
602+
formatted_video_source = {"bytes": source["bytes"]}
603+
result = {"format": video["format"], "source": formatted_video_source}
584604
return {"video": result}
585605

586606
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CitationsContentBlock.html

src/strands/types/media.py

Lines changed: 32 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -7,22 +7,42 @@
77

88
from typing import Literal
99

10-
from typing_extensions import TypedDict
10+
from typing_extensions import Required, TypedDict
1111

1212
from .citations import CitationsConfig
1313

1414
DocumentFormat = Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"]
1515
"""Supported document formats."""
1616

1717

18-
class DocumentSource(TypedDict):
18+
class S3Location(TypedDict, total=False):
19+
"""A storage location in an Amazon S3 bucket.
20+
21+
Used by Bedrock to reference media files stored in S3 instead of passing raw bytes.
22+
23+
- Docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_S3Location.html
24+
25+
Attributes:
26+
uri: An object URI starting with `s3://`. Required.
27+
bucketOwner: If the bucket belongs to another AWS account, specify that account's ID. Optional.
28+
"""
29+
30+
uri: Required[str]
31+
bucketOwner: str
32+
33+
34+
class DocumentSource(TypedDict, total=False):
1935
"""Contains the content of a document.
2036
37+
Only one of `bytes` or `s3Location` should be specified.
38+
2139
Attributes:
2240
bytes: The binary content of the document.
41+
s3Location: S3 location of the document (Bedrock only).
2342
"""
2443

2544
bytes: bytes
45+
s3Location: S3Location
2646

2747

2848
class DocumentContent(TypedDict, total=False):
@@ -45,14 +65,18 @@ class DocumentContent(TypedDict, total=False):
4565
"""Supported image formats."""
4666

4767

48-
class ImageSource(TypedDict):
68+
class ImageSource(TypedDict, total=False):
4969
"""Contains the content of an image.
5070
71+
Only one of `bytes` or `s3Location` should be specified.
72+
5173
Attributes:
5274
bytes: The binary content of the image.
75+
s3Location: S3 location of the image (Bedrock only).
5376
"""
5477

5578
bytes: bytes
79+
s3Location: S3Location
5680

5781

5882
class ImageContent(TypedDict):
@@ -71,14 +95,18 @@ class ImageContent(TypedDict):
7195
"""Supported video formats."""
7296

7397

74-
class VideoSource(TypedDict):
98+
class VideoSource(TypedDict, total=False):
7599
"""Contains the content of a video.
76100
101+
Only one of `bytes` or `s3Location` should be specified.
102+
77103
Attributes:
78104
bytes: The binary content of the video.
105+
s3Location: S3 location of the video (Bedrock only).
79106
"""
80107

81108
bytes: bytes
109+
s3Location: S3Location
82110

83111

84112
class VideoContent(TypedDict):

tests/strands/models/test_bedrock.py

Lines changed: 74 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1787,8 +1787,8 @@ def test_format_request_filters_image_content_blocks(model, model_id):
17871787
assert "metadata" not in image_block
17881788

17891789

1790-
def test_format_request_filters_nested_image_s3_fields(model, model_id):
1791-
"""Test that s3Location is filtered out and only bytes source is preserved."""
1790+
def test_format_request_image_s3_location_only(model, model_id):
1791+
"""Test that image with only s3Location is properly formatted."""
17921792
messages = [
17931793
{
17941794
"role": "user",
@@ -1797,8 +1797,7 @@ def test_format_request_filters_nested_image_s3_fields(model, model_id):
17971797
"image": {
17981798
"format": "png",
17991799
"source": {
1800-
"bytes": b"image_data",
1801-
"s3Location": {"bucket": "my-bucket", "key": "image.png", "extraField": "filtered"},
1800+
"s3Location": {"uri": "s3://my-bucket/image.png"},
18021801
},
18031802
}
18041803
}
@@ -1809,8 +1808,78 @@ def test_format_request_filters_nested_image_s3_fields(model, model_id):
18091808
formatted_request = model._format_request(messages)
18101809
image_source = formatted_request["messages"][0]["content"][0]["image"]["source"]
18111810

1811+
assert image_source == {"s3Location": {"uri": "s3://my-bucket/image.png"}}
1812+
1813+
1814+
def test_format_request_image_bytes_only(model, model_id):
1815+
"""Test that image with only bytes source is properly formatted."""
1816+
messages = [
1817+
{
1818+
"role": "user",
1819+
"content": [
1820+
{
1821+
"image": {
1822+
"format": "png",
1823+
"source": {"bytes": b"image_data"},
1824+
}
1825+
}
1826+
],
1827+
}
1828+
]
1829+
1830+
formatted_request = model._format_request(messages)
1831+
image_source = formatted_request["messages"][0]["content"][0]["image"]["source"]
1832+
18121833
assert image_source == {"bytes": b"image_data"}
1813-
assert "s3Location" not in image_source
1834+
1835+
1836+
def test_format_request_document_s3_location(model, model_id):
1837+
"""Test that document with s3Location is properly formatted."""
1838+
messages = [
1839+
{
1840+
"role": "user",
1841+
"content": [
1842+
{
1843+
"document": {
1844+
"name": "report.pdf",
1845+
"format": "pdf",
1846+
"source": {
1847+
"s3Location": {"uri": "s3://my-bucket/report.pdf", "bucketOwner": "123456789012"},
1848+
},
1849+
}
1850+
}
1851+
],
1852+
}
1853+
]
1854+
1855+
formatted_request = model._format_request(messages)
1856+
document = formatted_request["messages"][0]["content"][0]["document"]
1857+
1858+
assert document["source"] == {"s3Location": {"uri": "s3://my-bucket/report.pdf", "bucketOwner": "123456789012"}}
1859+
1860+
1861+
def test_format_request_video_s3_location(model, model_id):
1862+
"""Test that video with s3Location is properly formatted."""
1863+
messages = [
1864+
{
1865+
"role": "user",
1866+
"content": [
1867+
{
1868+
"video": {
1869+
"format": "mp4",
1870+
"source": {
1871+
"s3Location": {"uri": "s3://my-bucket/video.mp4"},
1872+
},
1873+
}
1874+
}
1875+
],
1876+
}
1877+
]
1878+
1879+
formatted_request = model._format_request(messages)
1880+
video_source = formatted_request["messages"][0]["content"][0]["video"]["source"]
1881+
1882+
assert video_source == {"s3Location": {"uri": "s3://my-bucket/video.mp4"}}
18141883

18151884

18161885
def test_format_request_filters_document_content_blocks(model, model_id):

tests/strands/tools/mcp/test_mcp_client.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -632,7 +632,7 @@ def test_call_tool_sync_embedded_nested_base64_textual_mime(mock_transport, mock
632632
def test_call_tool_sync_embedded_image_blob(mock_transport, mock_session):
633633
"""EmbeddedResource.resource (blob with image MIME) should map to image content."""
634634
# Read yellow.png file
635-
with open("tests_integ/yellow.png", "rb") as image_file:
635+
with open("tests_integ/resources/yellow.png", "rb") as image_file:
636636
png_data = image_file.read()
637637
payload = base64.b64encode(png_data).decode()
638638

tests/strands/types/test_media.py

Lines changed: 99 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,99 @@
1+
"""Tests for media type definitions."""
2+
3+
from strands.types.media import (
4+
DocumentSource,
5+
ImageSource,
6+
S3Location,
7+
VideoSource,
8+
)
9+
10+
11+
class TestS3Location:
12+
"""Tests for S3Location TypedDict."""
13+
14+
def test_s3_location_with_uri_only(self):
15+
"""Test S3Location with only uri field."""
16+
s3_loc: S3Location = {"uri": "s3://my-bucket/path/to/file.pdf"}
17+
18+
assert s3_loc["uri"] == "s3://my-bucket/path/to/file.pdf"
19+
assert "bucketOwner" not in s3_loc
20+
21+
def test_s3_location_with_bucket_owner(self):
22+
"""Test S3Location with both uri and bucketOwner fields."""
23+
s3_loc: S3Location = {
24+
"uri": "s3://my-bucket/path/to/file.pdf",
25+
"bucketOwner": "123456789012",
26+
}
27+
28+
assert s3_loc["uri"] == "s3://my-bucket/path/to/file.pdf"
29+
assert s3_loc["bucketOwner"] == "123456789012"
30+
31+
32+
class TestDocumentSource:
33+
"""Tests for DocumentSource TypedDict."""
34+
35+
def test_document_source_with_bytes(self):
36+
"""Test DocumentSource with bytes content."""
37+
doc_source: DocumentSource = {"bytes": b"document content"}
38+
39+
assert doc_source["bytes"] == b"document content"
40+
assert "s3Location" not in doc_source
41+
42+
def test_document_source_with_s3_location(self):
43+
"""Test DocumentSource with s3Location."""
44+
doc_source: DocumentSource = {
45+
"s3Location": {
46+
"uri": "s3://my-bucket/docs/report.pdf",
47+
"bucketOwner": "123456789012",
48+
}
49+
}
50+
51+
assert "bytes" not in doc_source
52+
assert doc_source["s3Location"]["uri"] == "s3://my-bucket/docs/report.pdf"
53+
assert doc_source["s3Location"]["bucketOwner"] == "123456789012"
54+
55+
56+
class TestImageSource:
57+
"""Tests for ImageSource TypedDict."""
58+
59+
def test_image_source_with_bytes(self):
60+
"""Test ImageSource with bytes content."""
61+
img_source: ImageSource = {"bytes": b"image content"}
62+
63+
assert img_source["bytes"] == b"image content"
64+
assert "s3Location" not in img_source
65+
66+
def test_image_source_with_s3_location(self):
67+
"""Test ImageSource with s3Location."""
68+
img_source: ImageSource = {
69+
"s3Location": {
70+
"uri": "s3://my-bucket/images/photo.png",
71+
}
72+
}
73+
74+
assert "bytes" not in img_source
75+
assert img_source["s3Location"]["uri"] == "s3://my-bucket/images/photo.png"
76+
77+
78+
class TestVideoSource:
79+
"""Tests for VideoSource TypedDict."""
80+
81+
def test_video_source_with_bytes(self):
82+
"""Test VideoSource with bytes content."""
83+
vid_source: VideoSource = {"bytes": b"video content"}
84+
85+
assert vid_source["bytes"] == b"video content"
86+
assert "s3Location" not in vid_source
87+
88+
def test_video_source_with_s3_location(self):
89+
"""Test VideoSource with s3Location."""
90+
vid_source: VideoSource = {
91+
"s3Location": {
92+
"uri": "s3://my-bucket/videos/clip.mp4",
93+
"bucketOwner": "987654321098",
94+
}
95+
}
96+
97+
assert "bytes" not in vid_source
98+
assert vid_source["s3Location"]["uri"] == "s3://my-bucket/videos/clip.mp4"
99+
assert vid_source["s3Location"]["bucketOwner"] == "987654321098"

tests_integ/conftest.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -133,14 +133,21 @@ def pytest_sessionstart(session):
133133

134134
@pytest.fixture
135135
def yellow_img(pytestconfig):
136-
path = pytestconfig.rootdir / "tests_integ/yellow.png"
136+
path = pytestconfig.rootdir / "tests_integ/resources/yellow.png"
137137
with open(path, "rb") as fp:
138138
return fp.read()
139139

140140

141141
@pytest.fixture
142142
def letter_pdf(pytestconfig):
143-
path = pytestconfig.rootdir / "tests_integ/letter.pdf"
143+
path = pytestconfig.rootdir / "tests_integ/resources/letter.pdf"
144+
with open(path, "rb") as fp:
145+
return fp.read()
146+
147+
148+
@pytest.fixture
149+
def blue_video(pytestconfig):
150+
path = pytestconfig.rootdir / "tests_integ/resources/blue.mp4"
144151
with open(path, "rb") as fp:
145152
return fp.read()
146153

tests_integ/mcp/echo_server.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -90,7 +90,7 @@ def get_weather(location: Literal["New York", "London", "Tokyo"] = "New York"):
9090
]
9191
elif location.lower() == "tokyo":
9292
# Read yellow.png file for weather icon
93-
with open("tests_integ/yellow.png", "rb") as image_file:
93+
with open("tests_integ/resources/yellow.png", "rb") as image_file:
9494
png_data = image_file.read()
9595
return [
9696
EmbeddedResource(

tests_integ/mcp/test_mcp_client.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@ def calculator(x: int, y: int) -> int:
4343
@mcp.tool(description="Generates a custom image")
4444
def generate_custom_image() -> MCPImageContent:
4545
try:
46-
with open("tests_integ/yellow.png", "rb") as image_file:
46+
with open("tests_integ/resources/yellow.png", "rb") as image_file:
4747
encoded_image = base64.b64encode(image_file.read())
4848
return MCPImageContent(type="image", data=encoded_image, mimeType="image/png")
4949
except Exception as e:

tests_integ/resources/blue.mp4

5.08 KB
Binary file not shown.

0 commit comments

Comments
 (0)