Skip to content

Commit 97f435d

Browse files
committed
Add s3 source for doc, image, video
1 parent e8fc991 commit 97f435d

File tree

14 files changed

+489
-26
lines changed

14 files changed

+489
-26
lines changed

src/strands/models/bedrock.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -489,9 +489,17 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
489489
if "format" in document:
490490
result["format"] = document["format"]
491491

492-
# Handle source
492+
# Handle source - supports bytes or s3Location
493493
if "source" in document:
494-
result["source"] = {"bytes": document["source"]["bytes"]}
494+
source = document["source"]
495+
if "s3Location" in source:
496+
s3_loc = source["s3Location"]
497+
formatted_document_s3: dict[str, Any] = {"uri": s3_loc["uri"]}
498+
if "bucketOwner" in s3_loc:
499+
formatted_document_s3["bucketOwner"] = s3_loc["bucketOwner"]
500+
result["source"] = {"s3Location": formatted_document_s3}
501+
elif "bytes" in source:
502+
result["source"] = {"bytes": source["bytes"]}
495503

496504
# Handle optional fields
497505
if "citations" in document and document["citations"] is not None:
@@ -512,10 +520,16 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
512520
if "image" in content:
513521
image = content["image"]
514522
source = image["source"]
515-
formatted_source = {}
516-
if "bytes" in source:
517-
formatted_source = {"bytes": source["bytes"]}
518-
result = {"format": image["format"], "source": formatted_source}
523+
formatted_image_source: dict[str, Any] = {}
524+
if "s3Location" in source:
525+
s3_loc = source["s3Location"]
526+
formatted_image_s3: dict[str, Any] = {"uri": s3_loc["uri"]}
527+
if "bucketOwner" in s3_loc:
528+
formatted_image_s3["bucketOwner"] = s3_loc["bucketOwner"]
529+
formatted_image_source = {"s3Location": formatted_image_s3}
530+
elif "bytes" in source:
531+
formatted_image_source = {"bytes": source["bytes"]}
532+
result = {"format": image["format"], "source": formatted_image_source}
519533
return {"image": result}
520534

521535
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ReasoningContentBlock.html
@@ -577,10 +591,16 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
577591
if "video" in content:
578592
video = content["video"]
579593
source = video["source"]
580-
formatted_source = {}
581-
if "bytes" in source:
582-
formatted_source = {"bytes": source["bytes"]}
583-
result = {"format": video["format"], "source": formatted_source}
594+
formatted_video_source: dict[str, Any] = {}
595+
if "s3Location" in source:
596+
s3_loc = source["s3Location"]
597+
formatted_video_s3: dict[str, Any] = {"uri": s3_loc["uri"]}
598+
if "bucketOwner" in s3_loc:
599+
formatted_video_s3["bucketOwner"] = s3_loc["bucketOwner"]
600+
formatted_video_source = {"s3Location": formatted_video_s3}
601+
elif "bytes" in source:
602+
formatted_video_source = {"bytes": source["bytes"]}
603+
result = {"format": video["format"], "source": formatted_video_source}
584604
return {"video": result}
585605

586606
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CitationsContentBlock.html

src/strands/types/media.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,42 @@
77

88
from typing import Literal
99

10-
from typing_extensions import TypedDict
10+
from typing_extensions import Required, TypedDict
1111

1212
from .citations import CitationsConfig
1313

1414
DocumentFormat = Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"]
1515
"""Supported document formats."""
1616

1717

18-
class DocumentSource(TypedDict):
18+
class S3Location(TypedDict, total=False):
19+
"""A storage location in an Amazon S3 bucket.
20+
21+
Used by Bedrock to reference media files stored in S3 instead of passing raw bytes.
22+
23+
- Docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_S3Location.html
24+
25+
Attributes:
26+
uri: An object URI starting with `s3://`. Required.
27+
bucketOwner: If the bucket belongs to another AWS account, specify that account's ID. Optional.
28+
"""
29+
30+
uri: Required[str]
31+
bucketOwner: str
32+
33+
34+
class DocumentSource(TypedDict, total=False):
1935
"""Contains the content of a document.
2036
37+
Only one of `bytes` or `s3Location` should be specified.
38+
2139
Attributes:
2240
bytes: The binary content of the document.
41+
s3Location: S3 location of the document (Bedrock only).
2342
"""
2443

2544
bytes: bytes
45+
s3Location: S3Location
2646

2747

2848
class DocumentContent(TypedDict, total=False):
@@ -45,14 +65,18 @@ class DocumentContent(TypedDict, total=False):
4565
"""Supported image formats."""
4666

4767

48-
class ImageSource(TypedDict):
68+
class ImageSource(TypedDict, total=False):
4969
"""Contains the content of an image.
5070
71+
Only one of `bytes` or `s3Location` should be specified.
72+
5173
Attributes:
5274
bytes: The binary content of the image.
75+
s3Location: S3 location of the image (Bedrock only).
5376
"""
5477

5578
bytes: bytes
79+
s3Location: S3Location
5680

5781

5882
class ImageContent(TypedDict):
@@ -71,14 +95,18 @@ class ImageContent(TypedDict):
7195
"""Supported video formats."""
7296

7397

74-
class VideoSource(TypedDict):
98+
class VideoSource(TypedDict, total=False):
7599
"""Contains the content of a video.
76100
101+
Only one of `bytes` or `s3Location` should be specified.
102+
77103
Attributes:
78104
bytes: The binary content of the video.
105+
s3Location: S3 location of the video (Bedrock only).
79106
"""
80107

81108
bytes: bytes
109+
s3Location: S3Location
82110

83111

84112
class VideoContent(TypedDict):
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""Tests for model validation helper functions."""
2+
3+
from strands.models._validation import _has_s3_source
4+
5+
6+
class TestHasS3Source:
7+
"""Tests for _has_s3_source helper function."""
8+
9+
def test_image_with_s3_source(self):
10+
"""Test detection of S3 source in image content."""
11+
content = {"image": {"source": {"s3Location": {"uri": "s3://bucket/key"}}}}
12+
assert _has_s3_source(content) is True
13+
14+
def test_image_with_bytes_source(self):
15+
"""Test that bytes source is not detected as S3."""
16+
content = {"image": {"source": {"bytes": b"data"}}}
17+
assert _has_s3_source(content) is False
18+
19+
def test_document_with_s3_source(self):
20+
"""Test detection of S3 source in document content."""
21+
content = {"document": {"source": {"s3Location": {"uri": "s3://bucket/key"}}}}
22+
assert _has_s3_source(content) is True
23+
24+
def test_document_with_bytes_source(self):
25+
"""Test that bytes source is not detected as S3."""
26+
content = {"document": {"source": {"bytes": b"data"}}}
27+
assert _has_s3_source(content) is False
28+
29+
def test_video_with_s3_source(self):
30+
"""Test detection of S3 source in video content."""
31+
content = {"video": {"source": {"s3Location": {"uri": "s3://bucket/key"}}}}
32+
assert _has_s3_source(content) is True
33+
34+
def test_video_with_bytes_source(self):
35+
"""Test that bytes source is not detected as S3."""
36+
content = {"video": {"source": {"bytes": b"data"}}}
37+
assert _has_s3_source(content) is False
38+
39+
def test_text_content(self):
40+
"""Test that text content is not detected as S3 source."""
41+
content = {"text": "hello"}
42+
assert _has_s3_source(content) is False
43+
44+
def test_tool_use_content(self):
45+
"""Test that toolUse content is not detected as S3 source."""
46+
content = {"toolUse": {"name": "test", "input": {}, "toolUseId": "123"}}
47+
assert _has_s3_source(content) is False
48+
49+
def test_tool_result_content(self):
50+
"""Test that toolResult content is not detected as S3 source."""
51+
content = {"toolResult": {"toolUseId": "123", "content": [{"text": "result"}]}}
52+
assert _has_s3_source(content) is False
53+
54+
def test_image_without_source(self):
55+
"""Test that image without source is not detected as S3."""
56+
content = {"image": {"format": "png"}}
57+
assert _has_s3_source(content) is False
58+
59+
def test_document_without_source(self):
60+
"""Test that document without source is not detected as S3."""
61+
content = {"document": {"format": "pdf", "name": "test.pdf"}}
62+
assert _has_s3_source(content) is False
63+
64+
def test_video_without_source(self):
65+
"""Test that video without source is not detected as S3."""
66+
content = {"video": {"format": "mp4"}}
67+
assert _has_s3_source(content) is False

tests/strands/models/test_bedrock.py

Lines changed: 74 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1787,8 +1787,8 @@ def test_format_request_filters_image_content_blocks(model, model_id):
17871787
assert "metadata" not in image_block
17881788

17891789

1790-
def test_format_request_filters_nested_image_s3_fields(model, model_id):
1791-
"""Test that s3Location is filtered out and only bytes source is preserved."""
1790+
def test_format_request_image_s3_location_only(model, model_id):
1791+
"""Test that image with only s3Location is properly formatted."""
17921792
messages = [
17931793
{
17941794
"role": "user",
@@ -1797,8 +1797,7 @@ def test_format_request_filters_nested_image_s3_fields(model, model_id):
17971797
"image": {
17981798
"format": "png",
17991799
"source": {
1800-
"bytes": b"image_data",
1801-
"s3Location": {"bucket": "my-bucket", "key": "image.png", "extraField": "filtered"},
1800+
"s3Location": {"uri": "s3://my-bucket/image.png"},
18021801
},
18031802
}
18041803
}
@@ -1809,8 +1808,78 @@ def test_format_request_filters_nested_image_s3_fields(model, model_id):
18091808
formatted_request = model._format_request(messages)
18101809
image_source = formatted_request["messages"][0]["content"][0]["image"]["source"]
18111810

1811+
assert image_source == {"s3Location": {"uri": "s3://my-bucket/image.png"}}
1812+
1813+
1814+
def test_format_request_image_bytes_only(model, model_id):
1815+
"""Test that image with only bytes source is properly formatted."""
1816+
messages = [
1817+
{
1818+
"role": "user",
1819+
"content": [
1820+
{
1821+
"image": {
1822+
"format": "png",
1823+
"source": {"bytes": b"image_data"},
1824+
}
1825+
}
1826+
],
1827+
}
1828+
]
1829+
1830+
formatted_request = model._format_request(messages)
1831+
image_source = formatted_request["messages"][0]["content"][0]["image"]["source"]
1832+
18121833
assert image_source == {"bytes": b"image_data"}
1813-
assert "s3Location" not in image_source
1834+
1835+
1836+
def test_format_request_document_s3_location(model, model_id):
1837+
"""Test that document with s3Location is properly formatted."""
1838+
messages = [
1839+
{
1840+
"role": "user",
1841+
"content": [
1842+
{
1843+
"document": {
1844+
"name": "report.pdf",
1845+
"format": "pdf",
1846+
"source": {
1847+
"s3Location": {"uri": "s3://my-bucket/report.pdf", "bucketOwner": "123456789012"},
1848+
},
1849+
}
1850+
}
1851+
],
1852+
}
1853+
]
1854+
1855+
formatted_request = model._format_request(messages)
1856+
document = formatted_request["messages"][0]["content"][0]["document"]
1857+
1858+
assert document["source"] == {"s3Location": {"uri": "s3://my-bucket/report.pdf", "bucketOwner": "123456789012"}}
1859+
1860+
1861+
def test_format_request_video_s3_location(model, model_id):
1862+
"""Test that video with s3Location is properly formatted."""
1863+
messages = [
1864+
{
1865+
"role": "user",
1866+
"content": [
1867+
{
1868+
"video": {
1869+
"format": "mp4",
1870+
"source": {
1871+
"s3Location": {"uri": "s3://my-bucket/video.mp4"},
1872+
},
1873+
}
1874+
}
1875+
],
1876+
}
1877+
]
1878+
1879+
formatted_request = model._format_request(messages)
1880+
video_source = formatted_request["messages"][0]["content"][0]["video"]["source"]
1881+
1882+
assert video_source == {"s3Location": {"uri": "s3://my-bucket/video.mp4"}}
18141883

18151884

18161885
def test_format_request_filters_document_content_blocks(model, model_id):

tests/strands/tools/mcp/test_mcp_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -632,7 +632,7 @@ def test_call_tool_sync_embedded_nested_base64_textual_mime(mock_transport, mock
632632
def test_call_tool_sync_embedded_image_blob(mock_transport, mock_session):
633633
"""EmbeddedResource.resource (blob with image MIME) should map to image content."""
634634
# Read yellow.png file
635-
with open("tests_integ/yellow.png", "rb") as image_file:
635+
with open("tests_integ/resources/yellow.png", "rb") as image_file:
636636
png_data = image_file.read()
637637
payload = base64.b64encode(png_data).decode()
638638

0 commit comments

Comments
 (0)