13 changes: 13 additions & 0 deletions llama-index-core/llama_index/core/base/llms/types.py

@@ -443,6 +443,18 @@ class ThinkingBlock(BaseModel):
     )


+class ToolCallBlock(BaseModel):
+    block_type: Literal["tool_call"] = "tool_call"
+    tool_call_id: Optional[str] = Field(
+        default=None, description="ID of the tool call, if provided"
+    )
+    tool_name: str = Field(description="Name of the called tool")
+    tool_kwargs: dict[str, Any] | str = Field(
+        default_factory=dict,  # type: ignore
+        description="Arguments provided to the tool, if available",
+    )
+
+
 ContentBlock = Annotated[
     Union[
         TextBlock,
@@ -454,6 +466,7 @@ class ThinkingBlock(BaseModel):
         CitableBlock,
         CitationBlock,
         ThinkingBlock,
+        ToolCallBlock,
     ],
     Field(discriminator="block_type"),
 ]
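For context, a minimal usage sketch of the new block (this assumes the existing ChatMessage container from the same module; the values are illustrative, not from the PR):

from llama_index.core.base.llms.types import ChatMessage, TextBlock, ToolCallBlock

# Build an assistant message mixing text and a tool call; "tool_call" is the
# discriminator value that routes union validation to ToolCallBlock.
msg = ChatMessage(
    role="assistant",
    blocks=[
        TextBlock(text="Let me look that up."),
        ToolCallBlock(tool_call_id="call_1", tool_name="search", tool_kwargs={"q": "llama"}),
    ],
)

# Round-trip through dict form; the discriminator restores the right class.
restored = ChatMessage.model_validate(msg.model_dump())
assert isinstance(restored.blocks[1], ToolCallBlock)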
5 changes: 3 additions & 2 deletions llama-index-core/llama_index/core/memory/memory.py

@@ -29,6 +29,7 @@
     CitableBlock,
     CitationBlock,
     ThinkingBlock,
+    ToolCallBlock,
 )
 from llama_index.core.bridge.pydantic import (
     BaseModel,
@@ -349,7 +350,7 @@ def _estimate_token_count(
         ] = []

         for block in message_or_blocks.blocks:
-            if not isinstance(block, CachePoint):
+            if not isinstance(block, (CachePoint, ToolCallBlock)):
                 blocks.append(block)

         # Estimate the token count for the additional kwargs
@@ -367,7 +368,7 @@ def _estimate_token_count(
         blocks = []
         for msg in messages:
             for block in msg.blocks:
-                if not isinstance(block, CachePoint):
+                if not isinstance(block, (CachePoint, ToolCallBlock)):
                     blocks.append(block)

         # Estimate the token count for the additional kwargs
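The memory change means tool-call blocks no longer count toward token estimates. A sketch of the observable effect, under the assumption that the isinstance filter is the only step that matters here (the real _estimate_token_count does more):

from llama_index.core.base.llms.types import CachePoint, TextBlock, ToolCallBlock

blocks = [
    TextBlock(text="Checking the weather."),
    ToolCallBlock(tool_name="get_weather", tool_kwargs={"city": "Paris"}),
]
# Mirrors the filter in the hunk above: only the TextBlock is tokenized.
countable = [b for b in blocks if not isinstance(b, (CachePoint, ToolCallBlock))]
assert [type(b).__name__ for b in countable] == ["TextBlock"]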
17 changes: 17 additions & 0 deletions llama-index-core/tests/base/llms/test_types.py

@@ -18,6 +18,7 @@
     CachePoint,
     CacheControl,
     ThinkingBlock,
+    ToolCallBlock,
 )
 from llama_index.core.bridge.pydantic import BaseModel
 from llama_index.core.bridge.pydantic import ValidationError
@@ -473,3 +474,19 @@ def test_thinking_block():
     assert block.additional_information == {"total_thinking_tokens": 1000}
     assert block.content == "hello world"
     assert block.num_tokens == 100
+
+
+def test_tool_call_block():
+    default_block = ToolCallBlock(tool_name="hello_world")
+    assert default_block.block_type == "tool_call"
+    assert default_block.tool_call_id is None
+    assert default_block.tool_name == "hello_world"
+    assert default_block.tool_kwargs == {}
+    custom_block = ToolCallBlock(
+        tool_name="hello_world",
+        tool_call_id="1",
+        tool_kwargs={"test": 1},
+    )
+    assert custom_block.tool_call_id == "1"
+    assert custom_block.tool_name == "hello_world"
+    assert custom_block.tool_kwargs == {"test": 1}
llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/base.py

@@ -1,4 +1,3 @@
-import json
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -12,6 +11,7 @@
     Set,
     Tuple,
     Union,
+    cast,
 )

 from llama_index.core.base.llms.types import (
@@ -22,6 +22,7 @@
     LLMMetadata,
     MessageRole,
     ContentBlock,
+    ToolCallBlock,
 )
 from llama_index.core.base.llms.types import TextBlock as LITextBlock
 from llama_index.core.base.llms.types import CitationBlock as LICitationBlock
@@ -42,6 +43,7 @@
     force_single_tool_call,
     is_function_calling_model,
     messages_to_anthropic_messages,
+    _anthropic_tool_call_to_tool_call_block,
 )

 import anthropic
@@ -344,8 +346,7 @@ def _completion_response_from_chat_response(

     def _get_blocks_and_tool_calls_and_thinking(
         self, response: Any
-    ) -> Tuple[List[ContentBlock], List[Dict[str, Any]], List[Dict[str, Any]]]:
-        tool_calls = []
+    ) -> Tuple[List[ContentBlock], List[Dict[str, Any]]]:
         blocks: List[ContentBlock] = []
         citations: List[TextCitation] = []
         tracked_citations: Set[str] = set()
@@ -385,9 +386,15 @@ def _get_blocks_and_tool_calls_and_thinking(
                     )
                 )
             elif isinstance(content_block, ToolUseBlock):
-                tool_calls.append(content_block.model_dump())
+                blocks.append(
+                    ToolCallBlock(
+                        tool_call_id=content_block.id,
+                        tool_name=content_block.name,
+                        tool_kwargs=content_block.input,
+                    )
+                )

-        return blocks, tool_calls, [x.model_dump() for x in citations]
+        return blocks, [x.model_dump() for x in citations]

     @llm_chat_callback()
     def chat(
@@ -405,17 +412,12 @@ def chat(
             **all_kwargs,
         )

-        blocks, tool_calls, citations = self._get_blocks_and_tool_calls_and_thinking(
-            response
-        )
+        blocks, citations = self._get_blocks_and_tool_calls_and_thinking(response)

         return AnthropicChatResponse(
             message=ChatMessage(
                 role=MessageRole.ASSISTANT,
                 blocks=blocks,
-                additional_kwargs={
-                    "tool_calls": tool_calls,
-                },
             ),
             citations=citations,
             raw=dict(response),
@@ -526,7 +528,12 @@ def gen() -> Generator[AnthropicChatResponse, None, None]:
                     yield AnthropicChatResponse(
                         message=ChatMessage(
                             role=role,
-                            blocks=content,
+                            blocks=[
+                                *content,
+                                *_anthropic_tool_call_to_tool_call_block(
+                                    cur_tool_calls
+                                ),
+                            ],
                             additional_kwargs={
                                 "tool_calls": [
                                     t.model_dump() for t in tool_calls_to_send

Collaborator: I guess this assumes that tool calls always come after content (I think this is true? just flagging)

@AstraBert (member, author), Sep 30, 2025: Yeah, that was my assumption: first there is the "explanation" of the tool call, and then there is the tool call itself.

Collaborator: Just to be safe, it might be best to just match the order that the content came in, rather than assuming 👍🏻
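For reference, an order-preserving variant along the lines of the last comment could convert tool-use blocks in place while walking the provider's content, instead of appending them all after the text. This is a sketch only, not what the PR implements; raw_blocks stands in for the Anthropic content in arrival order:

def interleave_tool_calls(raw_blocks):
    # raw_blocks: provider blocks in arrival order, where ToolUseBlock entries
    # mark tool calls and everything else is already a converted ContentBlock.
    out = []
    for rb in raw_blocks:
        if isinstance(rb, ToolUseBlock):
            out.extend(_anthropic_tool_call_to_tool_call_block([rb]))
        else:
            out.append(rb)
    return out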
@@ -577,17 +584,12 @@ async def achat(
             **all_kwargs,
         )

-        blocks, tool_calls, citations = self._get_blocks_and_tool_calls_and_thinking(
-            response
-        )
+        blocks, citations = self._get_blocks_and_tool_calls_and_thinking(response)

         return AnthropicChatResponse(
             message=ChatMessage(
                 role=MessageRole.ASSISTANT,
                 blocks=blocks,
-                additional_kwargs={
-                    "tool_calls": tool_calls,
-                },
             ),
             citations=citations,
             raw=dict(response),
@@ -698,11 +700,12 @@ async def gen() -> ChatResponseAsyncGen:
                     yield AnthropicChatResponse(
                         message=ChatMessage(
                             role=role,
-                            blocks=content,
-                            additional_kwargs={
-                                "tool_calls": [t.dict() for t in tool_calls_to_send],
-                                "thinking": thinking.model_dump() if thinking else None,
-                            },
+                            blocks=[
+                                *content,
+                                *_anthropic_tool_call_to_tool_call_block(
+                                    cur_tool_calls
+                                ),
+                            ],
                         ),
                         citations=cur_citations,
                         delta=content_delta,
@@ -811,7 +814,11 @@ def get_tool_calls_from_response(
         **kwargs: Any,
     ) -> List[ToolSelection]:
         """Predict and call the tool."""
-        tool_calls = response.message.additional_kwargs.get("tool_calls", [])
+        tool_calls = [
+            block
+            for block in response.message.blocks
+            if isinstance(block, ToolCallBlock)
+        ]

         if len(tool_calls) < 1:
             if error_on_no_tool_call:
@@ -823,25 +830,13 @@

         tool_selections = []
         for tool_call in tool_calls:
-            if (
-                "input" not in tool_call
-                or "id" not in tool_call
-                or "name" not in tool_call
-            ):
-                raise ValueError("Invalid tool call.")
-            if tool_call["type"] != "tool_use":
-                raise ValueError("Invalid tool type. Unsupported by Anthropic")
-            argument_dict = (
-                json.loads(tool_call["input"])
-                if isinstance(tool_call["input"], str)
-                else tool_call["input"]
-            )
+            argument_dict = tool_call.tool_kwargs

             tool_selections.append(
                 ToolSelection(
-                    tool_id=tool_call["id"],
-                    tool_name=tool_call["name"],
-                    tool_kwargs=argument_dict,
+                    tool_id=tool_call.tool_call_id or "",
+                    tool_name=tool_call.tool_name,
+                    tool_kwargs=cast(Dict[str, Any], argument_dict),
                 )
             )
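With tool calls carried on the message blocks, callers can recover them with a plain isinstance scan instead of digging through additional_kwargs. A hedged usage sketch; response stands for an AnthropicChatResponse returned by chat():

from llama_index.core.base.llms.types import ToolCallBlock

# Illustrative only: list the tool invocations the model asked for.
calls = [b for b in response.message.blocks if isinstance(b, ToolCallBlock)]
for call in calls:
    print(call.tool_call_id, call.tool_name, call.tool_kwargs)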
llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/utils.py

@@ -2,7 +2,7 @@
 Utility functions for the Anthropic SDK LLM integration.
 """

-from typing import Any, Dict, List, Sequence, Tuple, Optional
+from typing import Any, Dict, List, Sequence, Tuple, Optional, cast, Union

 from llama_index.core.base.llms.types import (
     ChatMessage,
@@ -15,6 +15,7 @@
     CitableBlock,
     CitationBlock,
     ThinkingBlock,
+    ToolCallBlock,
     ContentBlock,
 )
@@ -26,6 +27,7 @@
     ImageBlockParam,
     CacheControlEphemeralParam,
     Base64PDFSourceParam,
+    ToolUseBlock,
 )
 from anthropic.types import ContentBlockParam as AnthropicContentBlock
 from anthropic.types.beta import (
@@ -198,6 +200,19 @@ def _to_anthropic_document_block(block: DocumentBlock) -> DocumentBlockParam:
     )


+def _anthropic_tool_call_to_tool_call_block(tool_calls: list[ToolUseBlock]):
+    blocks = []
+    for tool_call in tool_calls:
+        blocks.append(
+            ToolCallBlock(
+                tool_call_id=tool_call.id,
+                tool_kwargs=cast(Union[Dict[str, Any], str], tool_call.input),
+                tool_name=tool_call.name,
+            )
+        )
+    return blocks
+
+
 def blocks_to_anthropic_blocks(
     blocks: Sequence[ContentBlock], kwargs: dict[str, Any]
 ) -> List[AnthropicContentBlock]:
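A quick check of what the new helper does, assuming the Anthropic SDK's ToolUseBlock constructor as used elsewhere in this file (the values are made up):

from anthropic.types import ToolUseBlock

tub = ToolUseBlock(
    id="toolu_01", name="get_weather", input={"city": "Paris"}, type="tool_use"
)
[converted] = _anthropic_tool_call_to_tool_call_block([tub])
assert converted.tool_call_id == "toolu_01"
assert converted.tool_name == "get_weather"
assert converted.tool_kwargs == {"city": "Paris"}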
@@ -275,24 +290,18 @@ def blocks_to_anthropic_blocks(
         elif isinstance(block, CitationBlock):
             # No need to pass these back to Anthropic
             continue
+        elif isinstance(block, ToolCallBlock):
+            anthropic_blocks.append(
+                ToolUseBlockParam(
+                    id=block.tool_call_id or "",
+                    input=block.tool_kwargs,
+                    name=block.tool_name,
+                    type="tool_use",
+                )
+            )
         else:
             raise ValueError(f"Unsupported block type: {type(block)}")

-    tool_calls = kwargs.get("tool_calls", [])
-    for tool_call in tool_calls:
-        assert "id" in tool_call
-        assert "input" in tool_call
-        assert "name" in tool_call
-
-        anthropic_blocks.append(
-            ToolUseBlockParam(
-                id=tool_call["id"],
-                input=tool_call["input"],
-                name=tool_call["name"],
-                type="tool_use",
-            )
-        )
-
     return anthropic_blocks

Collaborator: I'm going to add back this handling (in each LLM), just to ensure that old chat histories don't suddenly stop working. No harm in keeping this here.


@@ -351,6 +360,12 @@ def messages_to_anthropic_messages(


 def force_single_tool_call(response: ChatResponse) -> None:
-    tool_calls = response.message.additional_kwargs.get("tool_calls", [])
+    tool_calls = [
+        block for block in response.message.blocks if isinstance(block, ToolCallBlock)
+    ]
     if len(tool_calls) > 1:
-        response.message.additional_kwargs["tool_calls"] = [tool_calls[0]]
+        response.message.blocks = [
+            block
+            for block in response.message.blocks
+            if not isinstance(block, ToolCallBlock)
+        ] + [tool_calls[0]]
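Worth noting: the rewritten helper keeps only the first tool call and re-appends it after the non-tool blocks, so trimming can reorder the message. A hedged sketch of that observable behavior, assuming the public types from llama_index.core.base.llms.types:

from llama_index.core.base.llms.types import (
    ChatMessage,
    ChatResponse,
    TextBlock,
    ToolCallBlock,
)

resp = ChatResponse(
    message=ChatMessage(
        role="assistant",
        blocks=[
            TextBlock(text="Calling tools."),
            ToolCallBlock(tool_name="tool_a"),
            ToolCallBlock(tool_name="tool_b"),
        ],
    )
)
force_single_tool_call(resp)
# Only the first tool call survives, now at the end of the block list.
names = [b.tool_name for b in resp.message.blocks if isinstance(b, ToolCallBlock)]
assert names == ["tool_a"]
assert isinstance(resp.message.blocks[-1], ToolCallBlock)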