Skip to content

Commit 5f6f85f

Browse files
committed
fix: Properly handle prompt=None & avoid agent hanging
bedrock.py now catches all exceptions in _stream so it no longer hangs when invalid content is passed. In addition, since we don't allow agent(None), go ahead and validate that none is not passed throughout our agent calls.
1 parent ec5304c commit 5f6f85f

File tree

5 files changed

+53
-11
lines changed

5 files changed

+53
-11
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,8 @@ test-integ = [
234234
"hatch test tests_integ {args}"
235235
]
236236
prepare = [
237-
"hatch fmt --linter",
238237
"hatch fmt --formatter",
238+
"hatch fmt --linter",
239239
"hatch run test-lint",
240240
"hatch test --all"
241241
]

src/strands/agent/agent.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,9 @@ def __call__(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> Age
367367
- message: The final message from the model
368368
- metrics: Performance metrics from the event loop
369369
- state: The final state of the event loop
370+
371+
Raises:
372+
ValueError: If prompt is None.
370373
"""
371374

372375
def execute() -> AgentResult:
@@ -393,6 +396,9 @@ async def invoke_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A
393396
- message: The final message from the model
394397
- metrics: Performance metrics from the event loop
395398
- state: The final state of the event loop
399+
400+
Raises:
401+
ValueError: If prompt is None.
396402
"""
397403
events = self.stream_async(prompt, **kwargs)
398404
async for event in events:
@@ -452,8 +458,8 @@ async def structured_output_async(
452458

453459
# add the prompt as the last message
454460
if prompt:
455-
content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
456-
self._append_message({"role": "user", "content": content})
461+
message = self._standardize_prompt(prompt)
462+
self._append_message(message)
457463

458464
events = self.model.structured_output(output_model, self.messages, system_prompt=self.system_prompt)
459465
async for event in events:
@@ -487,6 +493,7 @@ async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A
487493
- And other event data provided by the callback handler
488494
489495
Raises:
496+
ValueError: If prompt is None.
490497
Exception: Any exceptions from the agent invocation will be propagated to the caller.
491498
492499
Example:
@@ -498,8 +505,7 @@ async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A
498505
"""
499506
callback_handler = kwargs.get("callback_handler", self.callback_handler)
500507

501-
content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
502-
message: Message = {"role": "user", "content": content}
508+
message = self._standardize_prompt(prompt)
503509

504510
self.trace_span = self._start_agent_trace_span(message)
505511
with trace_api.use_span(self.trace_span):
@@ -561,6 +567,15 @@ async def _run_loop(
561567
self.conversation_manager.apply_management(self)
562568
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
563569

570+
def _standardize_prompt(self, prompt: Union[str, list[ContentBlock]]) -> Message:
571+
"""Convert the prompt into a Message, validating it along the way."""
572+
if prompt is None:
573+
raise ValueError("User prompt must not be None")
574+
575+
content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
576+
message: Message = {"role": "user", "content": content}
577+
return message
578+
564579
async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
565580
"""Execute the event loop cycle with retry logic for context window limits.
566581

src/strands/models/bedrock.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -418,14 +418,14 @@ def _stream(
418418
ContextWindowOverflowException: If the input exceeds the model's context window.
419419
ModelThrottledException: If the model service is throttling requests.
420420
"""
421-
logger.debug("formatting request")
422-
request = self.format_request(messages, tool_specs, system_prompt)
423-
logger.debug("request=<%s>", request)
421+
try:
422+
logger.debug("formatting request")
423+
request = self.format_request(messages, tool_specs, system_prompt)
424+
logger.debug("request=<%s>", request)
424425

425-
logger.debug("invoking model")
426-
streaming = self.config.get("streaming", True)
426+
logger.debug("invoking model")
427+
streaming = self.config.get("streaming", True)
427428

428-
try:
429429
logger.debug("got response from model")
430430
if streaming:
431431
response = self.client.converse_stream(**request)

tests/strands/agent/test_agent.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -750,6 +750,24 @@ async def test_agent__call__in_async_context(mock_model, agent, agenerator):
750750
assert tru_message == exp_message
751751

752752

753+
@pytest.mark.asyncio
754+
async def test_agent_invocations_prompt_validation(agent, alist):
755+
with pytest.raises(ValueError):
756+
await agent.invoke_async(prompt=None)
757+
758+
with pytest.raises(ValueError):
759+
await agent(prompt=None)
760+
761+
with pytest.raises(ValueError):
762+
await alist(agent.stream_async(prompt=None))
763+
764+
with pytest.raises(ValueError):
765+
await agent.structured_output(type(user), prompt=None)
766+
767+
with pytest.raises(ValueError):
768+
await agent.structured_output_async(type(user), prompt=None)
769+
770+
753771
@pytest.mark.asyncio
754772
async def test_agent_invoke_async(mock_model, agent, agenerator):
755773
mock_model.mock_stream.return_value = agenerator(

tests/strands/models/test_bedrock.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,15 @@ async def test_stream_throttling_exception_from_event_stream_error(bedrock_clien
419419
)
420420

421421

422+
@pytest.mark.asyncio
423+
async def test_stream_with_invalid_content_throws(bedrock_client, model, alist):
424+
# We used to hang on None, so ensure we don't regress: https://github.com/strands-agents/sdk-python/issues/642
425+
messages = [{"role": "user", "content": None}]
426+
427+
with pytest.raises(TypeError):
428+
await alist(model.stream(messages))
429+
430+
422431
@pytest.mark.asyncio
423432
async def test_stream_throttling_exception_from_general_exception(bedrock_client, model, messages, alist):
424433
error_message = "ThrottlingException: Rate exceeded for ConverseStream"

0 commit comments

Comments
 (0)