@@ -1,3 +1,4 @@
import warnings
from typing import (
Any,
Callable,
@@ -20,6 +21,8 @@
CompletionResponseGen,
LLMMetadata,
MessageRole,
TextBlock,
ThinkingBlock,
)
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks import CallbackManager
@@ -46,6 +49,8 @@
join_two_dicts,
messages_to_converse_messages,
tools_to_converse_tools,
is_reasoning,
ThinkingDict,
)

if TYPE_CHECKING:
@@ -158,6 +163,10 @@ class BedrockConverse(FunctionCallingLLM):
trace: Optional[str] = Field(
description="Specifies whether to enable or disable the Bedrock trace. If enabled, you can see the full Bedrock trace."
)
thinking: Optional[ThinkingDict] = Field(
description="Specifies the thinking configuration of a reasoning model. Only applicable to Anthropic and DeepSeek models",
default=None,
)
additional_kwargs: Dict[str, Any] = Field(
default_factory=dict,
description="Additional kwargs for the bedrock invokeModel request.",
@@ -200,6 +209,7 @@ def __init__(
guardrail_version: Optional[str] = None,
application_inference_profile_arn: Optional[str] = None,
trace: Optional[str] = None,
thinking: Optional[ThinkingDict] = None,
) -> None:
additional_kwargs = additional_kwargs or {}
callback_manager = callback_manager or CallbackManager([])
@@ -213,6 +223,13 @@
"botocore_session": botocore_session,
}

if not is_reasoning(model) and thinking is not None:
thinking = None
warnings.warn(
"You set thinking parameters for a non-reasoning models, they will be ignored",
UserWarning,
)

super().__init__(
temperature=temperature,
max_tokens=max_tokens,
@@ -243,6 +260,7 @@ def __init__(
guardrail_version=guardrail_version,
application_inference_profile_arn=application_inference_profile_arn,
trace=trace,
thinking=thinking,
)

self._config = None
@@ -330,7 +348,9 @@ def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:

def _get_content_and_tool_calls(
self, response: Optional[Dict[str, Any]] = None, content: Dict[str, Any] = None
) -> Tuple[str, Dict[str, Any], List[str], List[str]]:
) -> Tuple[
List[Union[TextBlock, ThinkingBlock]], Dict[str, Any], List[str], List[str]
]:
assert response is not None or content is not None, (
f"Either response or content must be provided. Got response: {response}, content: {content}"
)
@@ -340,14 +360,25 @@ def _get_content_and_tool_calls(
tool_calls = []
tool_call_ids = []
status = []
text_content = ""
blocks = []
if content is not None:
content_list = [content]
else:
content_list = response["output"]["message"]["content"]
for content_block in content_list:
if text := content_block.get("text", None):
text_content += text
blocks.append(TextBlock(text=text))
if thinking := content_block.get("reasoningContent", None):
blocks.append(
ThinkingBlock(
content=thinking.get("reasoningText", {}).get("text", None),
additional_information={
"signature": thinking.get("reasoningText", {}).get(
"signature", None
)
},
)
)
if tool_usage := content_block.get("toolUse", None):
if "toolUseId" not in tool_usage:
tool_usage["toolUseId"] = content_block["toolUseId"]
@@ -361,7 +392,7 @@ def _get_content_and_tool_calls(
tool_call_ids.append(tool_result_content.get("toolUseId", ""))
status.append(tool_result.get("status", ""))

return text_content, tool_calls, tool_call_ids, status
return blocks, tool_calls, tool_call_ids, status

@llm_chat_callback()
def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
@@ -370,6 +401,8 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
messages, self.model
)
all_kwargs = self._get_all_kwargs(**kwargs)
if self.thinking is not None:
all_kwargs["thinking"] = self.thinking

# invoke LLM in AWS Bedrock Converse with retry
response = converse_with_retry(
@@ -386,14 +419,14 @@
**all_kwargs,
)

content, tool_calls, tool_call_ids, status = self._get_content_and_tool_calls(
blocks, tool_calls, tool_call_ids, status = self._get_content_and_tool_calls(
response
)

return ChatResponse(
message=ChatMessage(
role=MessageRole.ASSISTANT,
content=content,
blocks=blocks,
additional_kwargs={
"tool_calls": tool_calls,
"tool_call_id": tool_call_ids,
@@ -420,6 +453,8 @@ def stream_chat(
messages, self.model
)
all_kwargs = self._get_all_kwargs(**kwargs)
if self.thinking is not None:
all_kwargs["thinking"] = self.thinking

# invoke LLM in AWS Bedrock Converse with retry
response = converse_with_retry(
@@ -446,6 +481,12 @@ def gen() -> ChatResponseGen:
if content_block_delta := chunk.get("contentBlockDelta"):
content_delta = content_block_delta["delta"]
content = join_two_dicts(content, content_delta)
thinking = ""
Collaborator:

Is it only returning the full thinking/reasoning text? Or is it streaming?

If it's streaming, setting thinking = "" means we are removing the complete thinking string, right?

Collaborator:

Might want to update the tests to check that the thinking is over, say, 50 chars?

Member Author (@AstraBert, Sep 30, 2025):

As far as I tested the streaming behavior, this basically means that we have separate thinking chunks for each streamed response, instead of an incrementally growing response. So, rather than this:

The quick brown fox j
The quick brown fox jumps over th
The quick brown fox jumps over the lazy dog

we would have this:

The quick brown fox
jumps over the
lazy dog

Member Author (@AstraBert):

Also, the test for the streaming of thinking blocks does not check the character length; it checks the number of thinking blocks produced (which should be non-zero). I will add a character-length test, though, to cover streaming.


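To make that chunked behavior concrete, here is a minimal consumer-side sketch (not part of this diff). The model id, token budget, and prompt are placeholders, and it assumes AWS credentials are configured and that ThinkingBlock exposes a content attribute, as used elsewhere in this PR.

from llama_index.core.base.llms.types import ChatMessage, MessageRole, ThinkingBlock
from llama_index.llms.bedrock_converse import BedrockConverse

# Placeholder model id and budget -- any entry from BEDROCK_REASONING_MODELS works.
# max_tokens is set generously so budget_tokens fits under it (an Anthropic requirement).
llm = BedrockConverse(
    model="anthropic.claude-3-7-sonnet-20250219-v1:0",
    max_tokens=4096,
    thinking={"type": "enabled", "budget_tokens": 1024},
)

thinking_chunks = []
for response in llm.stream_chat(
    [ChatMessage(role=MessageRole.USER, content="What is 17 * 24?")]
):
    # Each streamed ChatResponse carries its own blocks; the thinking text
    # arrives as separate ThinkingBlock chunks, not as one growing string.
    for block in response.message.blocks:
        if isinstance(block, ThinkingBlock) and block.content:
            thinking_chunks.append(block.content)

full_thinking = "".join(thinking_chunks)
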
if "reasoningContent" in content_delta:
thinking += content_delta.get("reasoningContent", {}).get(
"text", ""
)

# If this delta contains tool call info, update current tool call
if "toolUse" in content_delta:
@@ -476,11 +517,15 @@ def gen() -> ChatResponseGen:
current_tool_call = join_two_dicts(
current_tool_call, tool_use_delta
)

blocks: List[Union[TextBlock, ThinkingBlock]] = [
TextBlock(text=content.get("text", ""))
]
if thinking != "":
blocks.append(ThinkingBlock(content=thinking))
yield ChatResponse(
message=ChatMessage(
role=role,
content=content.get("text", ""),
blocks=blocks,
additional_kwargs={
"tool_calls": tool_calls,
"tool_call_id": [
@@ -577,14 +622,14 @@ async def achat(
**all_kwargs,
)

content, tool_calls, tool_call_ids, status = self._get_content_and_tool_calls(
blocks, tool_calls, tool_call_ids, status = self._get_content_and_tool_calls(
response
)

return ChatResponse(
message=ChatMessage(
role=MessageRole.ASSISTANT,
content=content,
blocks=blocks,
additional_kwargs={
"tool_calls": tool_calls,
"tool_call_id": tool_call_ids,
@@ -611,6 +656,8 @@ async def astream_chat(
messages, self.model
)
all_kwargs = self._get_all_kwargs(**kwargs)
if self.thinking is not None:
all_kwargs["thinking"] = self.thinking

# invoke LLM in AWS Bedrock Converse with retry
response_gen = await converse_with_retry_async(
@@ -639,6 +686,12 @@ async def gen() -> ChatResponseAsyncGen:
if content_block_delta := chunk.get("contentBlockDelta"):
content_delta = content_block_delta["delta"]
content = join_two_dicts(content, content_delta)
thinking = ""

if "reasoningContent" in content_delta:
thinking += content_delta.get("reasoningContent", {}).get(
"text", ""
)

# If this delta contains tool call info, update current tool call
if "toolUse" in content_delta:
@@ -669,11 +722,15 @@ async def gen() -> ChatResponseAsyncGen:
current_tool_call = join_two_dicts(
current_tool_call, tool_use_delta
)

blocks: List[Union[TextBlock, ThinkingBlock]] = [
TextBlock(text=content.get("text", ""))
]
if thinking != "":
blocks.append(ThinkingBlock(content=thinking))
yield ChatResponse(
message=ChatMessage(
role=role,
content=content.get("text", ""),
blocks=blocks,
additional_kwargs={
"tool_calls": tool_calls,
"tool_call_id": [
@@ -1,7 +1,18 @@
import base64
import json
import logging
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
Literal,
Union,
)
from typing_extensions import TypedDict
from tenacity import (
before_sleep_log,
retry,
@@ -20,6 +31,7 @@
AudioBlock,
DocumentBlock,
CachePoint,
ThinkingBlock,
)


@@ -151,6 +163,18 @@
"amazon.nova-micro-v1:0",
)

BEDROCK_REASONING_MODELS = (
"anthropic.claude-3-7-sonnet-20250219-v1:0",
"anthropic.claude-opus-4-20250514-v1:0",
"anthropic.claude-sonnet-4-20250514-v1:0",
"deepseek.r1-v1:0",
)


def is_reasoning(model_name: str) -> bool:
model_name = get_model_name(model_name)
return model_name in BEDROCK_REASONING_MODELS


def get_model_name(model_name: str) -> str:
"""Extract base model name from region-prefixed model identifier."""
@@ -220,6 +244,13 @@ def _content_block_to_bedrock_format(
return {
"text": block.text,
}
elif isinstance(block, ThinkingBlock):
if block.content:
return {
"text": block.content,
}
else:
return None
elif isinstance(block, DocumentBlock):
if not block.data:
file_buffer = block.resolve_document()
@@ -518,6 +549,10 @@ def converse_with_retry(
"temperature": temperature,
},
}
if "thinking" in kwargs:
converse_kwargs["additionalModelRequestFields"] = {
"thinking": kwargs["thinking"]
}
if system_prompt:
if isinstance(system_prompt, str):
# if the system prompt is a simple text (for retro compatibility)
@@ -547,7 +582,14 @@ def converse_with_retry(
{
k: v
for k, v in kwargs.items()
if k not in ["tools", "guardrail_identifier", "guardrail_version", "trace"]
if k
not in [
"tools",
"guardrail_identifier",
"guardrail_version",
"trace",
"thinking",
]
},
)

@@ -688,3 +730,8 @@ def join_two_dicts(dict1: Dict[str, Any], dict2: Dict[str, Any]) -> Dict[str, Any]:
else:
new_dict[key] += value
return new_dict


class ThinkingDict(TypedDict):
type: Literal["enabled"]
budget_tokens: int
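
For context on how this ThinkingDict is consumed, a small hedged sketch follows; the values and the import path from this package's utils module are illustrative assumptions. The dict is passed as the thinking argument of BedrockConverse and, per the converse_with_retry change above, forwarded to Bedrock under additionalModelRequestFields.

from llama_index.llms.bedrock_converse import BedrockConverse
from llama_index.llms.bedrock_converse.utils import ThinkingDict

# Illustrative values: "enabled" is the only accepted type, and budget_tokens
# caps how many tokens the model may spend on internal reasoning.
thinking_config: ThinkingDict = {"type": "enabled", "budget_tokens": 2048}

llm = BedrockConverse(
    model="deepseek.r1-v1:0",  # must be in BEDROCK_REASONING_MODELS, otherwise the config is dropped with a warning
    thinking=thinking_config,
)

# converse_with_retry then sends:
#   additionalModelRequestFields={"thinking": {"type": "enabled", "budget_tokens": 2048}}
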
@@ -29,7 +29,7 @@ dev = [

[project]
name = "llama-index-llms-bedrock-converse"
version = "0.9.5"
version = "0.10.0"
description = "llama-index llms bedrock converse integration"
authors = [{name = "Your Name", email = "[email protected]"}]
requires-python = ">=3.9,<4.0"
@@ -38,7 +38,7 @@ license = "MIT"
dependencies = [
"boto3>=1.38.27,<2",
"aioboto3>=15.0.0,<16",
"llama-index-core>=0.13.0,<0.15",
"llama-index-core>=0.14.3,<0.15",
]

[tool.codespell]
@@ -15,6 +15,7 @@
TextBlock,
CacheControl,
CachePoint,
ThinkingBlock,
ChatMessage,
)
from llama_index.core.tools import FunctionTool
@@ -75,6 +76,12 @@ def test_content_block_to_bedrock_format_text():
assert result == {"text": "Hello, world!"}


def test_content_block_to_bedrock_format_thinking():
think_block = ThinkingBlock(content="Hello, world!")
result = _content_block_to_bedrock_format(think_block, MessageRole.USER)
assert result == {"text": "Hello, world!"}


def test_cache_point_block():
cache_point = CachePoint(cache_control=CacheControl(type="default"))
result = _content_block_to_bedrock_format(cache_point, MessageRole.USER)
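
Following up on the review thread about stream testing, a hedged sketch of the kind of assertion described there (non-zero thinking chunks plus a minimum total length). The reasoning_llm fixture and the prompt are hypothetical, and the mocking of the Bedrock stream is elided; this is not part of the diff above.

def test_stream_chat_emits_thinking_blocks(reasoning_llm):  # hypothetical fixture
    responses = list(
        reasoning_llm.stream_chat(
            [ChatMessage(role=MessageRole.USER, content="Think step by step: what is 12 * 13?")]
        )
    )
    thinking_blocks = [
        block
        for response in responses
        for block in response.message.blocks
        if isinstance(block, ThinkingBlock)
    ]
    # At least one thinking chunk should be produced while streaming...
    assert len(thinking_blocks) > 0
    # ...and the joined chunks should amount to a non-trivial amount of text.
    assert len("".join(block.content or "" for block in thinking_blocks)) > 50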