Skip to content

Commit e85b6eb

Browse files
committed
Refactor to make location generic
1 parent 05840e4 commit e85b6eb

File tree

4 files changed

+49
-52
lines changed

4 files changed

+49
-52
lines changed

src/strands/models/bedrock.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from pydantic import BaseModel
1818
from typing_extensions import TypedDict, Unpack, override
1919

20+
from strands.types.media import Location, S3Location
21+
2022
from .._exception_notes import add_exception_note
2123
from ..event_loop import streaming
2224
from ..tools import convert_pydantic_to_tool_spec
@@ -459,6 +461,18 @@ def _should_include_tool_result_status(self) -> bool:
459461
else: # "auto"
460462
return any(model in self.config["model_id"] for model in _MODELS_INCLUDE_STATUS)
461463

464+
def _handle_location(self, location: Location) -> dict[str, Any]:
465+
"""Convert location content block to Bedrock format if its an S3Location."""
466+
if location["type"] == "s3":
467+
s3_location = cast(S3Location, location)
468+
formatted_document_s3: dict[str, Any] = {"uri": s3_location["uri"]}
469+
if "bucketOwner" in s3_location:
470+
formatted_document_s3["bucketOwner"] = s3_location["bucketOwner"]
471+
return {"s3Location": formatted_document_s3}
472+
else:
473+
logger.warning("Non s3 location sources are not supported by Bedrock, skipping content block")
474+
return {}
475+
462476
def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any]:
463477
"""Format a Bedrock content block.
464478
@@ -489,15 +503,11 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
489503
if "format" in document:
490504
result["format"] = document["format"]
491505

492-
# Handle source - supports bytes or s3Location
506+
# Handle source - supports bytes or location
493507
if "source" in document:
494508
source = document["source"]
495-
if "s3Location" in source:
496-
s3_loc = source["s3Location"]
497-
formatted_document_s3: dict[str, Any] = {"uri": s3_loc["uri"]}
498-
if "bucketOwner" in s3_loc:
499-
formatted_document_s3["bucketOwner"] = s3_loc["bucketOwner"]
500-
result["source"] = {"s3Location": formatted_document_s3}
509+
if "location" in source:
510+
result["source"] = self._handle_location(source["location"])
501511
elif "bytes" in source:
502512
result["source"] = {"bytes": source["bytes"]}
503513

@@ -521,12 +531,8 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
521531
image = content["image"]
522532
source = image["source"]
523533
formatted_image_source: dict[str, Any] = {}
524-
if "s3Location" in source:
525-
s3_loc = source["s3Location"]
526-
formatted_image_s3: dict[str, Any] = {"uri": s3_loc["uri"]}
527-
if "bucketOwner" in s3_loc:
528-
formatted_image_s3["bucketOwner"] = s3_loc["bucketOwner"]
529-
formatted_image_source = {"s3Location": formatted_image_s3}
534+
if "location" in source:
535+
formatted_image_source = self._handle_location(source["location"])
530536
elif "bytes" in source:
531537
formatted_image_source = {"bytes": source["bytes"]}
532538
result = {"format": image["format"], "source": formatted_image_source}
@@ -592,12 +598,8 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
592598
video = content["video"]
593599
source = video["source"]
594600
formatted_video_source: dict[str, Any] = {}
595-
if "s3Location" in source:
596-
s3_loc = source["s3Location"]
597-
formatted_video_s3: dict[str, Any] = {"uri": s3_loc["uri"]}
598-
if "bucketOwner" in s3_loc:
599-
formatted_video_s3["bucketOwner"] = s3_loc["bucketOwner"]
600-
formatted_video_source = {"s3Location": formatted_video_s3}
601+
if "location" in source:
602+
formatted_video_source = self._handle_location(source["location"])
601603
elif "bytes" in source:
602604
formatted_video_source = {"bytes": source["bytes"]}
603605
result = {"format": video["format"], "source": formatted_video_source}

src/strands/types/media.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,29 @@
1515
"""Supported document formats."""
1616

1717

18-
class S3Location(TypedDict, total=False):
18+
class Location(TypedDict, total=False):
19+
"""A location for a document.
20+
21+
This type is a generic location for a document. Its usage is determined by the underlying model provider
22+
"""
23+
24+
type: Required[str]
25+
26+
27+
class S3Location(Location, total=False):
1928
"""A storage location in an Amazon S3 bucket.
2029
2130
Used by Bedrock to reference media files stored in S3 instead of passing raw bytes.
2231
2332
- Docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_S3Location.html
2433
2534
Attributes:
35+
type: s3
2636
uri: An object URI starting with `s3://`. Required.
2737
bucketOwner: If the bucket belongs to another AWS account, specify that account's ID. Optional.
2838
"""
2939

40+
type: Literal["s3"]
3041
uri: Required[str]
3142
bucketOwner: str
3243

@@ -38,11 +49,11 @@ class DocumentSource(TypedDict, total=False):
3849
3950
Attributes:
4051
bytes: The binary content of the document.
41-
s3Location: S3 location of the document (Bedrock only).
52+
location: Location of the document.
4253
"""
4354

4455
bytes: bytes
45-
s3Location: S3Location
56+
location: Location | S3Location
4657

4758

4859
class DocumentContent(TypedDict, total=False):

tests/strands/models/test_bedrock.py

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1797,28 +1797,18 @@ def test_format_request_image_s3_location_only(model, model_id):
17971797
"image": {
17981798
"format": "png",
17991799
"source": {
1800-
"s3Location": {"uri": "s3://my-bucket/image.png"},
1800+
"location": {"type": "s3", "uri": "s3://my-bucket/image.png"},
18011801
},
18021802
}
1803-
},
1804-
{
1805-
"image": {
1806-
"format": "png",
1807-
"source": {
1808-
"s3Location": {"uri": "s3://my-bucket/image.png", "bucketOwner": "12345"},
1809-
},
1810-
}
1811-
},
1803+
}
18121804
],
18131805
}
18141806
]
18151807

18161808
formatted_request = model._format_request(messages)
18171809
image_source = formatted_request["messages"][0]["content"][0]["image"]["source"]
1818-
image_source_with_bucket_owner = formatted_request["messages"][0]["content"][1]["image"]["source"]
18191810

18201811
assert image_source == {"s3Location": {"uri": "s3://my-bucket/image.png"}}
1821-
assert image_source_with_bucket_owner == {"s3Location": {"uri": "s3://my-bucket/image.png", "bucketOwner": "12345"}}
18221812

18231813

18241814
def test_format_request_image_bytes_only(model, model_id):
@@ -1854,7 +1844,7 @@ def test_format_request_document_s3_location(model, model_id):
18541844
"name": "report.pdf",
18551845
"format": "pdf",
18561846
"source": {
1857-
"s3Location": {"uri": "s3://my-bucket/report.pdf"},
1847+
"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"},
18581848
},
18591849
}
18601850
},
@@ -1863,7 +1853,11 @@ def test_format_request_document_s3_location(model, model_id):
18631853
"name": "report.pdf",
18641854
"format": "pdf",
18651855
"source": {
1866-
"s3Location": {"uri": "s3://my-bucket/report.pdf", "bucketOwner": "123456789012"},
1856+
"location": {
1857+
"type": "s3",
1858+
"uri": "s3://my-bucket/report.pdf",
1859+
"bucketOwner": "123456789012",
1860+
},
18671861
},
18681862
}
18691863
},
@@ -1892,15 +1886,7 @@ def test_format_request_video_s3_location(model, model_id):
18921886
"video": {
18931887
"format": "mp4",
18941888
"source": {
1895-
"s3Location": {"uri": "s3://my-bucket/video.mp4"},
1896-
},
1897-
}
1898-
},
1899-
{
1900-
"video": {
1901-
"format": "mp4",
1902-
"source": {
1903-
"s3Location": {"uri": "s3://my-bucket/video.mp4", "bucketOwner": "12345"},
1889+
"location": {"type": "s3", "uri": "s3://my-bucket/video.mp4"},
19041890
},
19051891
}
19061892
},
@@ -1910,10 +1896,8 @@ def test_format_request_video_s3_location(model, model_id):
19101896

19111897
formatted_request = model._format_request(messages)
19121898
video_source = formatted_request["messages"][0]["content"][0]["video"]["source"]
1913-
video_source_with_bucket_owner = formatted_request["messages"][0]["content"][1]["video"]["source"]
19141899

19151900
assert video_source == {"s3Location": {"uri": "s3://my-bucket/video.mp4"}}
1916-
assert video_source_with_bucket_owner == {"s3Location": {"uri": "s3://my-bucket/video.mp4", "bucketOwner": "12345"}}
19171901

19181902

19191903
def test_format_request_filters_document_content_blocks(model, model_id):

tests_integ/test_bedrock_s3_location.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,14 +120,14 @@ def test_document_s3_location(s3_document, account_id):
120120
"document": {
121121
"format": "pdf",
122122
"name": "letter",
123-
"source": {"s3Location": {"uri": s3_document, "bucketOwner": account_id}},
123+
"source": {"location": {"type": "s3", "uri": s3_document, "bucketOwner": account_id}},
124124
},
125125
},
126126
],
127127
},
128128
]
129129

130-
agent = Agent(model=BedrockModel(model_id="amazon.nova-2-lite-v1:0", region_name="us-west-2"))
130+
agent = Agent(model=BedrockModel(model_id="us.amazon.nova-2-lite-v1:0", region_name="us-west-2"))
131131
result = agent(messages)
132132

133133
assert "amazon" in str(result).lower()
@@ -143,14 +143,14 @@ def test_image_s3_location(s3_image):
143143
{
144144
"image": {
145145
"format": "png",
146-
"source": {"s3Location": {"uri": s3_image}},
146+
"source": {"location": {"type": "s3", "uri": s3_image}},
147147
},
148148
},
149149
],
150150
},
151151
]
152152

153-
agent = Agent(model=BedrockModel(model_id="amazon.nova-2-lite-v1:0", region_name="us-west-2"))
153+
agent = Agent(model=BedrockModel(model_id="us.amazon.nova-2-lite-v1:0", region_name="us-west-2"))
154154
result = agent(messages)
155155

156156
assert "yellow" in str(result).lower()
@@ -163,7 +163,7 @@ def test_video_s3_location(s3_video):
163163
"role": "user",
164164
"content": [
165165
{"text": "Describe the colors is in this video?"},
166-
{"video": {"format": "mp4", "source": {"s3Location": {"uri": s3_video}}}},
166+
{"video": {"format": "mp4", "source": {"location": {"type": "s3", "uri": s3_video}}}},
167167
],
168168
},
169169
]

0 commit comments

Comments
 (0)