Adding ThinkingBlock to Ollama and Bedrock Converse #19936
File 1 of 2: the BedrockConverse LLM implementation (Python).
@@ -1,3 +1,4 @@
+import warnings
 from typing import (
     Any,
     Callable,
@@ -20,6 +21,8 @@
     CompletionResponseGen,
     LLMMetadata,
     MessageRole,
+    TextBlock,
+    ThinkingBlock,
 )
 from llama_index.core.bridge.pydantic import Field, PrivateAttr
 from llama_index.core.callbacks import CallbackManager
@@ -46,6 +49,8 @@
     join_two_dicts,
     messages_to_converse_messages,
     tools_to_converse_tools,
+    is_reasoning,
+    ThinkingDict,
 )

 if TYPE_CHECKING:
@@ -158,6 +163,10 @@ class BedrockConverse(FunctionCallingLLM):
     trace: Optional[str] = Field(
         description="Specifies whether to enable or disable the Bedrock trace. If enabled, you can see the full Bedrock trace."
     )
+    thinking: Optional[ThinkingDict] = Field(
+        description="Specifies the thinking configuration of a reasoning model. Only applicable to Anthropic and DeepSeek models",
+        default=None,
+    )
     additional_kwargs: Dict[str, Any] = Field(
         default_factory=dict,
         description="Additional kwargs for the bedrock invokeModel request.",
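A hedged usage sketch of the new field. `ThinkingDict` is defined in the integration's utils (not shown in this diff); the dict below assumes it mirrors Anthropic's extended-thinking config on Bedrock, and the model ID is only an example:

```python
from llama_index.llms.bedrock_converse import BedrockConverse

llm = BedrockConverse(
    model="us.anthropic.claude-3-7-sonnet-20250219-v1:0",  # example reasoning model
    max_tokens=4096,
    # Assumed ThinkingDict shape, following Anthropic's extended-thinking config.
    thinking={"type": "enabled", "budget_tokens": 1024},
)
```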
@@ -200,6 +209,7 @@ def __init__(
         guardrail_version: Optional[str] = None,
         application_inference_profile_arn: Optional[str] = None,
         trace: Optional[str] = None,
+        thinking: Optional[ThinkingDict] = None,
     ) -> None:
         additional_kwargs = additional_kwargs or {}
         callback_manager = callback_manager or CallbackManager([])
@@ -213,6 +223,13 @@ def __init__(
             "botocore_session": botocore_session,
         }

+        if not is_reasoning(model) and thinking is not None:
+            thinking = None
+            warnings.warn(
+                "You set thinking parameters for a non-reasoning models, they will be ignored",
+                UserWarning,
+            )
+
         super().__init__(
             temperature=temperature,
             max_tokens=max_tokens,
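Given this guard, constructing the LLM with `thinking` set but a model that `is_reasoning` rejects should warn and null the config. A hypothetical pytest sketch (the Titan model ID stands in for any non-reasoning model):

```python
import pytest

from llama_index.llms.bedrock_converse import BedrockConverse


def test_thinking_ignored_for_non_reasoning_model():
    with pytest.warns(UserWarning):
        llm = BedrockConverse(
            model="amazon.titan-text-express-v1",  # assumed non-reasoning
            thinking={"type": "enabled", "budget_tokens": 1024},
        )
    # The constructor dropped the config rather than sending it to Bedrock.
    assert llm.thinking is None
```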
@@ -243,6 +260,7 @@ def __init__(
             guardrail_version=guardrail_version,
             application_inference_profile_arn=application_inference_profile_arn,
             trace=trace,
+            thinking=thinking,
         )

         self._config = None
@@ -330,7 +348,9 @@ def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:

     def _get_content_and_tool_calls(
         self, response: Optional[Dict[str, Any]] = None, content: Dict[str, Any] = None
-    ) -> Tuple[str, Dict[str, Any], List[str], List[str]]:
+    ) -> Tuple[
+        List[Union[TextBlock, ThinkingBlock]], Dict[str, Any], List[str], List[str]
+    ]:
         assert response is not None or content is not None, (
             f"Either response or content must be provided. Got response: {response}, content: {content}"
         )
@@ -340,14 +360,25 @@ def _get_content_and_tool_calls(
         tool_calls = []
         tool_call_ids = []
         status = []
-        text_content = ""
+        blocks = []
         if content is not None:
             content_list = [content]
         else:
             content_list = response["output"]["message"]["content"]
         for content_block in content_list:
             if text := content_block.get("text", None):
-                text_content += text
+                blocks.append(TextBlock(text=text))
+            if thinking := content_block.get("reasoningContent", None):
+                blocks.append(
+                    ThinkingBlock(
+                        content=thinking.get("reasoningText", {}).get("text", None),
+                        additional_information={
+                            "signature": thinking.get("reasoningText", {}).get(
+                                "signature", None
+                            )
+                        },
+                    )
+                )
             if tool_usage := content_block.get("toolUse", None):
                 if "toolUseId" not in tool_usage:
                     tool_usage["toolUseId"] = content_block["toolUseId"]
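To make the mapping concrete, here is an invented Converse content list (field names follow the AWS Converse API) and the blocks the loop above would build from it:

```python
from llama_index.core.base.llms.types import TextBlock, ThinkingBlock

content_list = [
    {
        "reasoningContent": {
            "reasoningText": {"text": "Let me multiply step by step...", "signature": "abc123"}
        }
    },
    {"text": "17 * 23 = 391."},
]

blocks = []
for content_block in content_list:
    if text := content_block.get("text", None):
        blocks.append(TextBlock(text=text))
    if thinking := content_block.get("reasoningContent", None):
        reasoning_text = thinking.get("reasoningText", {})
        blocks.append(
            ThinkingBlock(
                content=reasoning_text.get("text", None),
                additional_information={"signature": reasoning_text.get("signature", None)},
            )
        )

# blocks -> [ThinkingBlock("Let me multiply step by step..."), TextBlock("17 * 23 = 391.")]
```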
@@ -361,13 +392,15 @@ def _get_content_and_tool_calls(
                 tool_call_ids.append(tool_result_content.get("toolUseId", ""))
                 status.append(tool_result.get("status", ""))

-        return text_content, tool_calls, tool_call_ids, status
+        return blocks, tool_calls, tool_call_ids, status

     @llm_chat_callback()
     def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
         # convert Llama Index messages to AWS Bedrock Converse messages
         converse_messages, system_prompt = messages_to_converse_messages(messages)
         all_kwargs = self._get_all_kwargs(**kwargs)
+        if self.thinking is not None:
+            all_kwargs["thinking"] = self.thinking

         # invoke LLM in AWS Bedrock Converse with retry
         response = converse_with_retry(
@@ -384,14 +417,14 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
             **all_kwargs,
         )

-        content, tool_calls, tool_call_ids, status = self._get_content_and_tool_calls(
+        blocks, tool_calls, tool_call_ids, status = self._get_content_and_tool_calls(
             response
         )

         return ChatResponse(
             message=ChatMessage(
                 role=MessageRole.ASSISTANT,
-                content=content,
+                blocks=blocks,
                 additional_kwargs={
                     "tool_calls": tool_calls,
                     "tool_call_id": tool_call_ids,
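Downstream, callers read the reasoning out of `message.blocks` instead of a flat content string. A sketch, reusing the `llm` from the constructor example above:

```python
from llama_index.core.base.llms.types import ChatMessage, TextBlock, ThinkingBlock

resp = llm.chat([ChatMessage(role="user", content="Why is the sky blue?")])
for block in resp.message.blocks:
    if isinstance(block, ThinkingBlock):
        print("[thinking]", block.content)
    elif isinstance(block, TextBlock):
        print("[answer]", block.text)
```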
@@ -416,6 +449,8 @@ def stream_chat(
         # convert Llama Index messages to AWS Bedrock Converse messages
         converse_messages, system_prompt = messages_to_converse_messages(messages)
         all_kwargs = self._get_all_kwargs(**kwargs)
+        if self.thinking is not None:
+            all_kwargs["thinking"] = self.thinking

         # invoke LLM in AWS Bedrock Converse with retry
         response = converse_with_retry(
@@ -442,6 +477,12 @@ def gen() -> ChatResponseGen:
                 if content_block_delta := chunk.get("contentBlockDelta"):
                     content_delta = content_block_delta["delta"]
                     content = join_two_dicts(content, content_delta)
+                    thinking = ""
+
+                    if "reasoningContent" in content_delta:
+                        thinking += content_delta.get("reasoningContent", {}).get(
+                            "text", ""
+                        )

                     # If this delta contains tool call info, update current tool call
                     if "toolUse" in content_delta:
@@ -472,11 +513,15 @@ def gen() -> ChatResponseGen:
                         current_tool_call = join_two_dicts(
                             current_tool_call, tool_use_delta
                         )

+                    blocks: List[Union[TextBlock, ThinkingBlock]] = [
+                        TextBlock(text=content.get("text", ""))
+                    ]
+                    if thinking != "":
+                        blocks.append(ThinkingBlock(content=thinking))
                     yield ChatResponse(
                         message=ChatMessage(
                             role=role,
-                            content=content.get("text", ""),
+                            blocks=blocks,
                             additional_kwargs={
                                 "tool_calls": tool_calls,
                                 "tool_call_id": [
@@ -485,7 +530,9 @@ def gen() -> ChatResponseGen:
                                 "status": [],  # Will be populated when tool results come in
                             },
                         ),
-                        delta=content_delta.get("text", ""),
+                        delta=content_delta.get("text", None)
+                        or content_delta.get("thinking", None)
+                        or "",
                         raw=chunk,
                         additional_kwargs=self._get_response_token_counts(dict(chunk)),
                     )
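A streaming consumer then watches `chunk.delta` for text (with the thinking fallback above) and `chunk.message.blocks` for per-chunk `ThinkingBlock`s; note the blocks carry per-chunk deltas rather than cumulative text (see the review discussion below). A sketch:

```python
from llama_index.core.base.llms.types import ChatMessage, ThinkingBlock

thinking_chars = 0
for chunk in llm.stream_chat([ChatMessage(role="user", content="What is 17 * 23?")]):
    # Count the thinking text delivered in this chunk's blocks.
    for block in chunk.message.blocks:
        if isinstance(block, ThinkingBlock) and block.content:
            thinking_chars += len(block.content)
    print(chunk.delta, end="", flush=True)
print(f"\nstreamed {thinking_chars} characters of thinking")
```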
@@ -571,14 +618,14 @@ async def achat(
             **all_kwargs,
         )

-        content, tool_calls, tool_call_ids, status = self._get_content_and_tool_calls(
+        blocks, tool_calls, tool_call_ids, status = self._get_content_and_tool_calls(
             response
         )

         return ChatResponse(
             message=ChatMessage(
                 role=MessageRole.ASSISTANT,
-                content=content,
+                blocks=blocks,
                 additional_kwargs={
                     "tool_calls": tool_calls,
                     "tool_call_id": tool_call_ids,
@@ -603,6 +650,8 @@ async def astream_chat(
         # convert Llama Index messages to AWS Bedrock Converse messages
         converse_messages, system_prompt = messages_to_converse_messages(messages)
         all_kwargs = self._get_all_kwargs(**kwargs)
+        if self.thinking is not None:
+            all_kwargs["thinking"] = self.thinking

         # invoke LLM in AWS Bedrock Converse with retry
         response_gen = await converse_with_retry_async(
@@ -631,6 +680,12 @@ async def gen() -> ChatResponseAsyncGen:
                 if content_block_delta := chunk.get("contentBlockDelta"):
                     content_delta = content_block_delta["delta"]
                     content = join_two_dicts(content, content_delta)
+                    thinking = ""
+
+                    if "reasoningContent" in content_delta:
+                        thinking += content_block_delta.get("reasoningContent", {}).get(
+                            "text", ""
+                        )

                     # If this delta contains tool call info, update current tool call
                     if "toolUse" in content_delta:
@@ -661,11 +716,15 @@ async def gen() -> ChatResponseAsyncGen:
                         current_tool_call = join_two_dicts(
                             current_tool_call, tool_use_delta
                         )

+                    blocks: List[Union[TextBlock, ThinkingBlock]] = [
+                        TextBlock(text=content.get("text", ""))
+                    ]
+                    if thinking != "":
+                        blocks.append(ThinkingBlock(content=thinking))
                     yield ChatResponse(
                         message=ChatMessage(
                             role=role,
-                            content=content.get("text", ""),
+                            blocks=blocks,
                             additional_kwargs={
                                 "tool_calls": tool_calls,
                                 "tool_call_id": [
File 2 of 2: the integration's pyproject.toml.

@@ -29,7 +29,7 @@ dev = [

 [project]
 name = "llama-index-llms-bedrock-converse"
-version = "0.9.2"
+version = "0.10.0"
 description = "llama-index llms bedrock converse integration"
 authors = [{name = "Your Name", email = "[email protected]"}]
 requires-python = ">=3.9,<4.0"

@@ -38,7 +38,7 @@ license = "MIT"
 dependencies = [
     "boto3>=1.34.122,<2",
     "aioboto3>=13.1.1,<16",
-    "llama-index-core>=0.13.0,<0.15",
+    "llama-index-core>=0.14.3,<0.15",
 ]

 [tool.codespell]
Review discussion:
Is it only returning the full thinking/reasoning text? Or is it streaming? If it's streaming, setting `thinking = ""` means we are removing the complete thinking string, right?
Might want to update the tests to check if the thinking is over, like, 50 chars?
As far as I tested the streaming behavior, this basically means that we have separate thinking chunks for each streamed response, instead of an incrementally growing response: each chunk's ThinkingBlock carries only the newly streamed thinking text, not the full text accumulated so far.
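Illustratively (the chunk contents here are invented):

```python
# Cumulative thinking per chunk (NOT what this PR does):
#   chunk 1 -> ThinkingBlock(content="The user wants")
#   chunk 2 -> ThinkingBlock(content="The user wants the product of 17 and 23")
#
# Separate per-chunk deltas (what this PR yields):
#   chunk 1 -> ThinkingBlock(content="The user wants")
#   chunk 2 -> ThinkingBlock(content=" the product of 17 and 23")
```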
Also, the test for the streaming of thinking blocks does not check the character length; it checks the number of thinking blocks produced (which should be non-zero). I will add a test for character length, though, to test the streaming.