diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py
index 4ea1453a4..eeddaaa2a 100644
--- a/src/strands/models/bedrock.py
+++ b/src/strands/models/bedrock.py
@@ -7,7 +7,17 @@
 import json
 import logging
 import os
-from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union
+from typing import (
+    Any,
+    AsyncGenerator,
+    Callable,
+    Iterable,
+    Literal,
+    Optional,
+    Type,
+    TypeVar,
+    Union,
+)
 
 import boto3
 from botocore.config import Config as BotocoreConfig
@@ -131,11 +141,18 @@ def __init__(
             else:
                 new_user_agent = "strands-agents"
 
-            client_config = boto_client_config.merge(BotocoreConfig(user_agent_extra=new_user_agent))
+            client_config = boto_client_config.merge(
+                BotocoreConfig(user_agent_extra=new_user_agent)
+            )
         else:
             client_config = BotocoreConfig(user_agent_extra="strands-agents")
 
-        resolved_region = region_name or session.region_name or os.environ.get("AWS_REGION") or DEFAULT_BEDROCK_REGION
+        resolved_region = (
+            region_name
+            or session.region_name
+            or os.environ.get("AWS_REGION")
+            or DEFAULT_BEDROCK_REGION
+        )
 
         self.client = session.client(
             service_name="bedrock-runtime",
@@ -143,7 +160,9 @@ def __init__(
             region_name=resolved_region,
         )
 
-        logger.debug("region=<%s> | bedrock client created", self.client.meta.region_name)
+        logger.debug(
+            "region=<%s> | bedrock client created", self.client.meta.region_name
+        )
 
     @override
     def update_config(self, **model_config: Unpack[BedrockConfig]) -> None:  # type: ignore
@@ -184,7 +203,11 @@ def format_request(
             "messages": self._format_bedrock_messages(messages),
             "system": [
                 *([{"text": system_prompt}] if system_prompt else []),
-                *([{"cachePoint": {"type": self.config["cache_prompt"]}}] if self.config.get("cache_prompt") else []),
+                *(
+                    [{"cachePoint": {"type": self.config["cache_prompt"]}}]
+                    if self.config.get("cache_prompt")
+                    else []
+                ),
             ],
             **(
                 {
@@ -204,12 +227,20 @@ def format_request(
                 else {}
             ),
             **(
-                {"additionalModelRequestFields": self.config["additional_request_fields"]}
+                {
+                    "additionalModelRequestFields": self.config[
+                        "additional_request_fields"
+                    ]
+                }
                 if self.config.get("additional_request_fields")
                 else {}
             ),
             **(
-                {"additionalModelResponseFieldPaths": self.config["additional_response_field_paths"]}
+                {
+                    "additionalModelResponseFieldPaths": self.config[
+                        "additional_response_field_paths"
+                    ]
+                }
                 if self.config.get("additional_response_field_paths")
                 else {}
             ),
@@ -220,13 +251,18 @@ def format_request(
                         "guardrailVersion": self.config["guardrail_version"],
                         "trace": self.config.get("guardrail_trace", "enabled"),
                         **(
-                            {"streamProcessingMode": self.config.get("guardrail_stream_processing_mode")}
+                            {
+                                "streamProcessingMode": self.config.get(
+                                    "guardrail_stream_processing_mode"
+                                )
+                            }
                             if self.config.get("guardrail_stream_processing_mode")
                             else {}
                         ),
                     }
                 }
-                if self.config.get("guardrail_id") and self.config.get("guardrail_version")
+                if self.config.get("guardrail_id")
+                and self.config.get("guardrail_version")
                 else {}
            ),
            "inferenceConfig": {
@@ -241,7 +277,8 @@ def format_request(
             },
             **(
                 self.config["additional_args"]
-                if "additional_args" in self.config and self.config["additional_args"] is not None
+                if "additional_args" in self.config
+                and self.config["additional_args"] is not None
                 else {}
             ),
         }
@@ -278,7 +315,9 @@ def _format_bedrock_messages(self, messages: Messages) -> Messages:
 
                     # Keep only the required fields for Bedrock
                     cleaned_tool_result = ToolResult(
-                        content=tool_result["content"], toolUseId=tool_result["toolUseId"], status=tool_result["status"]
+                        content=tool_result["content"],
+                        toolUseId=tool_result["toolUseId"],
+                        status=tool_result["status"],
                     )
 
                     cleaned_block: ContentBlock = {"toolResult": cleaned_tool_result}
@@ -288,7 +327,9 @@ def _format_bedrock_messages(self, messages: Messages) -> Messages:
                     cleaned_content.append(content_block)
 
             # Create new message with cleaned content
-            cleaned_message: Message = Message(content=cleaned_content, role=message["role"])
+            cleaned_message: Message = Message(
+                content=cleaned_content, role=message["role"]
+            )
             cleaned_messages.append(cleaned_message)
 
         return cleaned_messages
@@ -306,11 +347,17 @@ def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> bool:
         output_assessments = guardrail_data.get("outputAssessments", {})
 
         # Check input assessments
-        if any(self._find_detected_and_blocked_policy(assessment) for assessment in input_assessment.values()):
+        if any(
+            self._find_detected_and_blocked_policy(assessment)
+            for assessment in input_assessment.values()
+        ):
             return True
 
         # Check output assessments
-        if any(self._find_detected_and_blocked_policy(assessment) for assessment in output_assessments.values()):
+        if any(
+            self._find_detected_and_blocked_policy(assessment)
+            for assessment in output_assessments.values()
+        ):
             return True
 
         return False
@@ -341,7 +388,8 @@ def _generate_redaction_events(self) -> list[StreamEvent]:
                 {
                     "redactContent": {
                         "redactAssistantContentMessage": self.config.get(
-                            "guardrail_redact_output_message", "[Assistant output redacted.]"
+                            "guardrail_redact_output_message",
+                            "[Assistant output redacted.]",
                         )
                     }
                 }
@@ -384,7 +432,9 @@ def callback(event: Optional[StreamEvent] = None) -> None:
         loop = asyncio.get_event_loop()
         queue: asyncio.Queue[Optional[StreamEvent]] = asyncio.Queue()
 
-        thread = asyncio.to_thread(self._stream, callback, messages, tool_specs, system_prompt)
+        thread = asyncio.to_thread(
+            self._stream, callback, messages, tool_specs, system_prompt
+        )
         task = asyncio.create_task(thread)
 
         while True:
@@ -396,6 +446,18 @@ def callback(event: Optional[StreamEvent] = None) -> None:
 
         await task
 
+    def _strip_reasoning_content_from_message(self, message: Message) -> Message:
+        # Deep copy the message to avoid mutating original
+        import copy
+
+        msg_copy = copy.deepcopy(message)
+
+        content = msg_copy.get("content", [])
+        # Filter out any content blocks with reasoningContent
+        filtered_content = [c for c in content if "reasoningContent" not in c]
+        msg_copy["content"] = filtered_content
+        return msg_copy
+
     def _stream(
         self,
         callback: Callable[..., None],
@@ -418,8 +480,14 @@ def _stream(
             ContextWindowOverflowException: If the input exceeds the model's context window.
             ModelThrottledException: If the model service is throttling requests.
""" - logger.debug("formatting request") - request = self.format_request(messages, tool_specs, system_prompt) + logger.debug("stripping reasoning content from messages") + cleaned_messages = [ + self._strip_reasoning_content_from_message(m) for m in messages + ] + + logger.debug("formatting request with cleaned messages") + request = self.format_request(cleaned_messages, tool_specs, system_prompt) + logger.debug("request=<%s>", request) logger.debug("invoking model") @@ -461,7 +529,10 @@ def _stream( if e.response["Error"]["Code"] == "ThrottlingException": raise ModelThrottledException(error_message) from e - if any(overflow_message in error_message for overflow_message in BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES): + if any( + overflow_message in error_message + for overflow_message in BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES + ): logger.warning("bedrock threw context window overflow error") raise ContextWindowOverflowException(e) from e @@ -497,7 +568,9 @@ def _stream( callback() logger.debug("finished streaming response from model") - def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]: + def _convert_non_streaming_to_streaming( + self, response: dict[str, Any] + ) -> Iterable[StreamEvent]: """Convert a non-streaming response to the streaming format. Args: @@ -527,7 +600,9 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera # For tool use, we need to yield the input as a delta input_value = json.dumps(content["toolUse"]["input"]) - yield {"contentBlockDelta": {"delta": {"toolUse": {"input": input_value}}}} + yield { + "contentBlockDelta": {"delta": {"toolUse": {"input": input_value}}} + } elif "text" in content: # Then yield the text as a delta yield { @@ -539,7 +614,13 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera # Then yield the reasoning content as a delta yield { "contentBlockDelta": { - "delta": {"reasoningContent": {"text": content["reasoningContent"]["reasoningText"]["text"]}} + "delta": { + "reasoningContent": { + "text": content["reasoningContent"]["reasoningText"][ + "text" + ] + } + } } } @@ -548,7 +629,9 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera "contentBlockDelta": { "delta": { "reasoningContent": { - "signature": content["reasoningContent"]["reasoningText"]["signature"] + "signature": content["reasoningContent"][ + "reasoningText" + ]["signature"] } } } @@ -561,7 +644,9 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera yield { "messageStop": { "stopReason": response["stopReason"], - "additionalModelResponseFields": response.get("additionalModelResponseFields"), + "additionalModelResponseFields": response.get( + "additionalModelResponseFields" + ), } } @@ -589,7 +674,11 @@ def _find_detected_and_blocked_policy(self, input: Any) -> bool: # Check if input is a dictionary if isinstance(input, dict): # Check if current dictionary has action: BLOCKED and detected: true - if input.get("action") == "BLOCKED" and input.get("detected") and isinstance(input.get("detected"), bool): + if ( + input.get("action") == "BLOCKED" + and input.get("detected") + and isinstance(input.get("detected"), bool) + ): return True # Recursively check all values in the dictionary @@ -609,7 +698,11 @@ def _find_detected_and_blocked_policy(self, input: Any) -> bool: @override async def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any + self, + 
output_model: Type[T], + prompt: Messages, + system_prompt: Optional[str] = None, + **kwargs: Any, ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: """Get structured output from the model. @@ -624,14 +717,21 @@ async def structured_output( """ tool_spec = convert_pydantic_to_tool_spec(output_model) - response = self.stream(messages=prompt, tool_specs=[tool_spec], system_prompt=system_prompt, **kwargs) + response = self.stream( + messages=prompt, + tool_specs=[tool_spec], + system_prompt=system_prompt, + **kwargs, + ) async for event in streaming.process_stream(response): yield event stop_reason, messages, _, _ = event["stop"] if stop_reason != "tool_use": - raise ValueError(f'Model returned stop_reason: {stop_reason} instead of "tool_use".') + raise ValueError( + f'Model returned stop_reason: {stop_reason} instead of "tool_use".' + ) content = messages["content"] output_response: dict[str, Any] | None = None @@ -644,6 +744,8 @@ async def structured_output( continue if output_response is None: - raise ValueError("No valid tool use or tool use input was found in the Bedrock response.") + raise ValueError( + "No valid tool use or tool use input was found in the Bedrock response." + ) yield {"output": output_model(**output_response)}
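
Illustration only, not part of the patch: a minimal sketch of what the new reasoning-content stripping does to a message, assuming the plain-dict Message/ContentBlock shapes used above (the standalone helper name strip_reasoning is hypothetical).

import copy

def strip_reasoning(message: dict) -> dict:
    # Work on a deep copy so the caller's message is not mutated,
    # mirroring _strip_reasoning_content_from_message in the patch.
    msg = copy.deepcopy(message)
    # Drop any content block that carries a reasoningContent key.
    msg["content"] = [c for c in msg.get("content", []) if "reasoningContent" not in c]
    return msg

original = {
    "role": "assistant",
    "content": [
        {"reasoningContent": {"reasoningText": {"text": "chain of thought"}}},
        {"text": "final answer"},
    ],
}
cleaned = strip_reasoning(original)
# cleaned["content"] == [{"text": "final answer"}]; original is left untouched.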