Skip to content

Commit 76b1b0b

Browse files
committed
Refactor to make location generic
1 parent 05840e4 commit 76b1b0b

File tree

5 files changed

+125
-62
lines changed

5 files changed

+125
-62
lines changed

src/strands/models/bedrock.py

Lines changed: 40 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from pydantic import BaseModel
1818
from typing_extensions import TypedDict, Unpack, override
1919

20+
from strands.types.media import S3Location, SourceLocation
21+
2022
from .._exception_notes import add_exception_note
2123
from ..event_loop import streaming
2224
from ..tools import convert_pydantic_to_tool_spec
@@ -407,6 +409,8 @@ def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]:
407409

408410
# Format content blocks for Bedrock API compatibility
409411
formatted_content = self._format_request_message_content(content_block)
412+
if formatted_content is None:
413+
continue
410414

411415
# Wrap text or image content in guardrailContent if this is the last user message
412416
if (
@@ -459,7 +463,19 @@ def _should_include_tool_result_status(self) -> bool:
459463
else: # "auto"
460464
return any(model in self.config["model_id"] for model in _MODELS_INCLUDE_STATUS)
461465

462-
def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any]:
466+
def _handle_location(self, location: SourceLocation) -> dict[str, Any] | None:
467+
"""Convert location content block to Bedrock format if its an S3Location."""
468+
if location["type"] == "s3":
469+
s3_location = cast(S3Location, location)
470+
formatted_document_s3: dict[str, Any] = {"uri": s3_location["uri"]}
471+
if "bucketOwner" in s3_location:
472+
formatted_document_s3["bucketOwner"] = s3_location["bucketOwner"]
473+
return {"s3Location": formatted_document_s3}
474+
else:
475+
logger.warning("Non s3 location sources are not supported by Bedrock, skipping content block")
476+
return None
477+
478+
def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any] | None:
463479
"""Format a Bedrock content block.
464480
465481
Bedrock strictly validates content blocks and throws exceptions for unknown fields.
@@ -489,17 +505,17 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
489505
if "format" in document:
490506
result["format"] = document["format"]
491507

492-
# Handle source - supports bytes or s3Location
508+
# Handle source - supports bytes or location
493509
if "source" in document:
494510
source = document["source"]
495-
if "s3Location" in source:
496-
s3_loc = source["s3Location"]
497-
formatted_document_s3: dict[str, Any] = {"uri": s3_loc["uri"]}
498-
if "bucketOwner" in s3_loc:
499-
formatted_document_s3["bucketOwner"] = s3_loc["bucketOwner"]
500-
result["source"] = {"s3Location": formatted_document_s3}
511+
formatted_document_source: dict[str, Any] | None
512+
if "location" in source:
513+
formatted_document_source = self._handle_location(source["location"])
514+
if formatted_document_source is None:
515+
return None
501516
elif "bytes" in source:
502-
result["source"] = {"bytes": source["bytes"]}
517+
formatted_document_source = {"bytes": source["bytes"]}
518+
result["source"] = formatted_document_source
503519

504520
# Handle optional fields
505521
if "citations" in document and document["citations"] is not None:
@@ -520,13 +536,11 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
520536
if "image" in content:
521537
image = content["image"]
522538
source = image["source"]
523-
formatted_image_source: dict[str, Any] = {}
524-
if "s3Location" in source:
525-
s3_loc = source["s3Location"]
526-
formatted_image_s3: dict[str, Any] = {"uri": s3_loc["uri"]}
527-
if "bucketOwner" in s3_loc:
528-
formatted_image_s3["bucketOwner"] = s3_loc["bucketOwner"]
529-
formatted_image_source = {"s3Location": formatted_image_s3}
539+
formatted_image_source: dict[str, Any] | None
540+
if "location" in source:
541+
formatted_image_source = self._handle_location(source["location"])
542+
if formatted_image_source is None:
543+
return None
530544
elif "bytes" in source:
531545
formatted_image_source = {"bytes": source["bytes"]}
532546
result = {"format": image["format"], "source": formatted_image_source}
@@ -564,9 +578,12 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
564578
# Handle json field since not in ContentBlock but valid in ToolResultContent
565579
formatted_content.append({"json": tool_result_content["json"]})
566580
else:
567-
formatted_content.append(
568-
self._format_request_message_content(cast(ContentBlock, tool_result_content))
581+
formatted_message_content = self._format_request_message_content(
582+
cast(ContentBlock, tool_result_content)
569583
)
584+
if formatted_message_content is None:
585+
continue
586+
formatted_content.append(formatted_message_content)
570587

571588
result = {
572589
"content": formatted_content,
@@ -591,13 +608,11 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
591608
if "video" in content:
592609
video = content["video"]
593610
source = video["source"]
594-
formatted_video_source: dict[str, Any] = {}
595-
if "s3Location" in source:
596-
s3_loc = source["s3Location"]
597-
formatted_video_s3: dict[str, Any] = {"uri": s3_loc["uri"]}
598-
if "bucketOwner" in s3_loc:
599-
formatted_video_s3["bucketOwner"] = s3_loc["bucketOwner"]
600-
formatted_video_source = {"s3Location": formatted_video_s3}
611+
formatted_video_source: dict[str, Any] | None
612+
if "location" in source:
613+
formatted_video_source = self._handle_location(source["location"])
614+
if formatted_video_source is None:
615+
return None
601616
elif "bytes" in source:
602617
formatted_video_source = {"bytes": source["bytes"]}
603618
result = {"format": video["format"], "source": formatted_video_source}

src/strands/types/media.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
- Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html
66
"""
77

8-
from typing import Literal
8+
from typing import Literal, TypeAlias
99

1010
from typing_extensions import Required, TypedDict
1111

@@ -15,34 +15,50 @@
1515
"""Supported document formats."""
1616

1717

18-
class S3Location(TypedDict, total=False):
18+
class Location(TypedDict, total=False):
19+
"""A location for a document.
20+
21+
This type is a generic location for a document. Its usage is determined by the underlying model provider.
22+
"""
23+
24+
type: Required[str]
25+
26+
27+
class S3Location(Location, total=False):
1928
"""A storage location in an Amazon S3 bucket.
2029
2130
Used by Bedrock to reference media files stored in S3 instead of passing raw bytes.
2231
2332
- Docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_S3Location.html
2433
2534
Attributes:
35+
type: s3
2636
uri: An object URI starting with `s3://`. Required.
2737
bucketOwner: If the bucket belongs to another AWS account, specify that account's ID. Optional.
2838
"""
2939

40+
# mypy doesn't like overriding this field since its a subclass, but since its just a literal string, this is fine.
41+
42+
type: Literal["s3"] # type: ignore[misc]
3043
uri: Required[str]
3144
bucketOwner: str
3245

3346

47+
SourceLocation: TypeAlias = Location | S3Location
48+
49+
3450
class DocumentSource(TypedDict, total=False):
3551
"""Contains the content of a document.
3652
3753
Only one of `bytes` or `s3Location` should be specified.
3854
3955
Attributes:
4056
bytes: The binary content of the document.
41-
s3Location: S3 location of the document (Bedrock only).
57+
location: Location of the document.
4258
"""
4359

4460
bytes: bytes
45-
s3Location: S3Location
61+
location: SourceLocation
4662

4763

4864
class DocumentContent(TypedDict, total=False):
@@ -72,11 +88,11 @@ class ImageSource(TypedDict, total=False):
7288
7389
Attributes:
7490
bytes: The binary content of the image.
75-
s3Location: S3 location of the image (Bedrock only).
91+
location: Location of the image.
7692
"""
7793

7894
bytes: bytes
79-
s3Location: S3Location
95+
location: SourceLocation
8096

8197

8298
class ImageContent(TypedDict):
@@ -102,11 +118,11 @@ class VideoSource(TypedDict, total=False):
102118
103119
Attributes:
104120
bytes: The binary content of the video.
105-
s3Location: S3 location of the video (Bedrock only).
121+
location: Location of the video.
106122
"""
107123

108124
bytes: bytes
109-
s3Location: S3Location
125+
location: SourceLocation
110126

111127

112128
class VideoContent(TypedDict):

tests/strands/agent/hooks/test_events.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,6 @@ def test_invocation_state_is_available_in_model_call_events(agent):
206206
assert after_event.invocation_state["request_id"] == "req-456"
207207

208208

209-
210-
211209
def test_before_invocation_event_messages_default_none(agent):
212210
"""Test that BeforeInvocationEvent.messages defaults to None for backward compatibility."""
213211
event = BeforeInvocationEvent(agent=agent)

tests/strands/models/test_bedrock.py

Lines changed: 56 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import copy
2+
import logging
13
import os
24
import sys
35
import traceback
@@ -1519,7 +1521,6 @@ async def test_add_note_on_validation_exception_throughput(bedrock_client, model
15191521
@pytest.mark.asyncio
15201522
async def test_stream_logging(bedrock_client, model, messages, caplog, alist):
15211523
"""Test that stream method logs debug messages at the expected stages."""
1522-
import logging
15231524

15241525
# Set the logger to debug level to capture debug messages
15251526
caplog.set_level(logging.DEBUG, logger="strands.models.bedrock")
@@ -1797,28 +1798,18 @@ def test_format_request_image_s3_location_only(model, model_id):
17971798
"image": {
17981799
"format": "png",
17991800
"source": {
1800-
"s3Location": {"uri": "s3://my-bucket/image.png"},
1801+
"location": {"type": "s3", "uri": "s3://my-bucket/image.png"},
18011802
},
18021803
}
1803-
},
1804-
{
1805-
"image": {
1806-
"format": "png",
1807-
"source": {
1808-
"s3Location": {"uri": "s3://my-bucket/image.png", "bucketOwner": "12345"},
1809-
},
1810-
}
1811-
},
1804+
}
18121805
],
18131806
}
18141807
]
18151808

18161809
formatted_request = model._format_request(messages)
18171810
image_source = formatted_request["messages"][0]["content"][0]["image"]["source"]
1818-
image_source_with_bucket_owner = formatted_request["messages"][0]["content"][1]["image"]["source"]
18191811

18201812
assert image_source == {"s3Location": {"uri": "s3://my-bucket/image.png"}}
1821-
assert image_source_with_bucket_owner == {"s3Location": {"uri": "s3://my-bucket/image.png", "bucketOwner": "12345"}}
18221813

18231814

18241815
def test_format_request_image_bytes_only(model, model_id):
@@ -1854,7 +1845,7 @@ def test_format_request_document_s3_location(model, model_id):
18541845
"name": "report.pdf",
18551846
"format": "pdf",
18561847
"source": {
1857-
"s3Location": {"uri": "s3://my-bucket/report.pdf"},
1848+
"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"},
18581849
},
18591850
}
18601851
},
@@ -1863,7 +1854,11 @@ def test_format_request_document_s3_location(model, model_id):
18631854
"name": "report.pdf",
18641855
"format": "pdf",
18651856
"source": {
1866-
"s3Location": {"uri": "s3://my-bucket/report.pdf", "bucketOwner": "123456789012"},
1857+
"location": {
1858+
"type": "s3",
1859+
"uri": "s3://my-bucket/report.pdf",
1860+
"bucketOwner": "123456789012",
1861+
},
18671862
},
18681863
}
18691864
},
@@ -1882,25 +1877,67 @@ def test_format_request_document_s3_location(model, model_id):
18821877
}
18831878

18841879

1885-
def test_format_request_video_s3_location(model, model_id):
1886-
"""Test that video with s3Location is properly formatted."""
1880+
def test_format_request_unsupported_location(model, caplog):
1881+
"""Test that document with s3Location is properly formatted."""
1882+
1883+
caplog.set_level(logging.WARNING, logger="strands.models.bedrock")
1884+
18871885
messages = [
18881886
{
18891887
"role": "user",
18901888
"content": [
1889+
{"text": "Hello!"},
1890+
{
1891+
"document": {
1892+
"name": "report.pdf",
1893+
"format": "pdf",
1894+
"source": {
1895+
"location": {
1896+
"type": "other",
1897+
},
1898+
},
1899+
}
1900+
},
18911901
{
18921902
"video": {
18931903
"format": "mp4",
18941904
"source": {
1895-
"s3Location": {"uri": "s3://my-bucket/video.mp4"},
1905+
"location": {
1906+
"type": "other",
1907+
},
1908+
},
1909+
}
1910+
},
1911+
{
1912+
"image": {
1913+
"format": "png",
1914+
"source": {
1915+
"location": {
1916+
"type": "other",
1917+
},
18961918
},
18971919
}
18981920
},
1921+
],
1922+
}
1923+
]
1924+
1925+
formatted_request = model._format_request(messages)
1926+
assert len(formatted_request["messages"][0]["content"]) == 1
1927+
assert "Non s3 location sources are not supported by Bedrock, skipping content block" in caplog.text
1928+
1929+
1930+
def test_format_request_video_s3_location(model, model_id):
1931+
"""Test that video with s3Location is properly formatted."""
1932+
messages = [
1933+
{
1934+
"role": "user",
1935+
"content": [
18991936
{
19001937
"video": {
19011938
"format": "mp4",
19021939
"source": {
1903-
"s3Location": {"uri": "s3://my-bucket/video.mp4", "bucketOwner": "12345"},
1940+
"location": {"type": "s3", "uri": "s3://my-bucket/video.mp4"},
19041941
},
19051942
}
19061943
},
@@ -1910,10 +1947,8 @@ def test_format_request_video_s3_location(model, model_id):
19101947

19111948
formatted_request = model._format_request(messages)
19121949
video_source = formatted_request["messages"][0]["content"][0]["video"]["source"]
1913-
video_source_with_bucket_owner = formatted_request["messages"][0]["content"][1]["video"]["source"]
19141950

19151951
assert video_source == {"s3Location": {"uri": "s3://my-bucket/video.mp4"}}
1916-
assert video_source_with_bucket_owner == {"s3Location": {"uri": "s3://my-bucket/video.mp4", "bucketOwner": "12345"}}
19171952

19181953

19191954
def test_format_request_filters_document_content_blocks(model, model_id):
@@ -2413,7 +2448,6 @@ def test_inject_cache_point_skipped_for_non_claude(bedrock_client):
24132448

24142449
def test_format_bedrock_messages_does_not_mutate_original(bedrock_client):
24152450
"""Test that _format_bedrock_messages does not mutate original messages."""
2416-
import copy
24172451

24182452
model = BedrockModel(
24192453
model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto")

0 commit comments

Comments
 (0)