diff --git a/pyproject.toml b/pyproject.toml index 586a956af..d4a4b79dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -234,8 +234,8 @@ test-integ = [ "hatch test tests_integ {args}" ] prepare = [ - "hatch fmt --linter", "hatch fmt --formatter", + "hatch fmt --linter", "hatch run test-lint", "hatch test --all" ] diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 4ea1453a4..ace35640a 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -418,14 +418,14 @@ def _stream( ContextWindowOverflowException: If the input exceeds the model's context window. ModelThrottledException: If the model service is throttling requests. """ - logger.debug("formatting request") - request = self.format_request(messages, tool_specs, system_prompt) - logger.debug("request=<%s>", request) + try: + logger.debug("formatting request") + request = self.format_request(messages, tool_specs, system_prompt) + logger.debug("request=<%s>", request) - logger.debug("invoking model") - streaming = self.config.get("streaming", True) + logger.debug("invoking model") + streaming = self.config.get("streaming", True) - try: logger.debug("got response from model") if streaming: response = self.client.converse_stream(**request) diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 0a2846adf..09e508845 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -419,6 +419,15 @@ async def test_stream_throttling_exception_from_event_stream_error(bedrock_clien ) +@pytest.mark.asyncio +async def test_stream_with_invalid_content_throws(bedrock_client, model, alist): + # We used to hang on None, so ensure we don't regress: https://github.com/strands-agents/sdk-python/issues/642 + messages = [{"role": "user", "content": None}] + + with pytest.raises(TypeError): + await alist(model.stream(messages)) + + @pytest.mark.asyncio async def test_stream_throttling_exception_from_general_exception(bedrock_client, model, messages, alist): error_message = "ThrottlingException: Rate exceeded for ConverseStream"