diff --git a/llama-index-core/llama_index/core/base/llms/types.py b/llama-index-core/llama_index/core/base/llms/types.py
index 479f4c0df3..8ca960cfca 100644
--- a/llama-index-core/llama_index/core/base/llms/types.py
+++ b/llama-index-core/llama_index/core/base/llms/types.py
@@ -443,6 +443,18 @@ class ThinkingBlock(BaseModel):
     )


+class ToolCallBlock(BaseModel):
+    block_type: Literal["tool_call"] = "tool_call"
+    tool_call_id: Optional[str] = Field(
+        default=None, description="ID of the tool call, if provided"
+    )
+    tool_name: str = Field(description="Name of the called tool")
+    tool_kwargs: dict[str, Any] | str = Field(
+        default_factory=dict,  # type: ignore
+        description="Arguments provided to the tool, if available",
+    )
+
+
 ContentBlock = Annotated[
     Union[
         TextBlock,
@@ -454,6 +466,7 @@ class ThinkingBlock(BaseModel):
         CitableBlock,
         CitationBlock,
         ThinkingBlock,
+        ToolCallBlock,
     ],
     Field(discriminator="block_type"),
 ]
diff --git a/llama-index-core/llama_index/core/memory/memory.py b/llama-index-core/llama_index/core/memory/memory.py
index a1365fba13..335b35be1f 100644
--- a/llama-index-core/llama_index/core/memory/memory.py
+++ b/llama-index-core/llama_index/core/memory/memory.py
@@ -29,6 +29,7 @@
     CitableBlock,
     CitationBlock,
     ThinkingBlock,
+    ToolCallBlock,
 )
 from llama_index.core.bridge.pydantic import (
     BaseModel,
@@ -349,7 +350,7 @@ def _estimate_token_count(
         ] = []

         for block in message_or_blocks.blocks:
-            if not isinstance(block, CachePoint):
+            if not isinstance(block, (CachePoint, ToolCallBlock)):
                 blocks.append(block)

         # Estimate the token count for the additional kwargs
@@ -367,7 +368,7 @@ def _estimate_token_count(
             blocks = []
             for msg in messages:
                 for block in msg.blocks:
-                    if not isinstance(block, CachePoint):
+                    if not isinstance(block, (CachePoint, ToolCallBlock)):
                         blocks.append(block)

             # Estimate the token count for the additional kwargs
diff --git a/llama-index-core/tests/base/llms/test_types.py b/llama-index-core/tests/base/llms/test_types.py
index b4ed17216b..edeab7b900 100644
--- a/llama-index-core/tests/base/llms/test_types.py
+++ b/llama-index-core/tests/base/llms/test_types.py
@@ -18,6 +18,7 @@
     CachePoint,
     CacheControl,
     ThinkingBlock,
+    ToolCallBlock,
 )
 from llama_index.core.bridge.pydantic import BaseModel
 from llama_index.core.bridge.pydantic import ValidationError
@@ -473,3 +474,19 @@ def test_thinking_block():
     assert block.additional_information == {"total_thinking_tokens": 1000}
     assert block.content == "hello world"
     assert block.num_tokens == 100
+
+
+def test_tool_call_block():
+    default_block = ToolCallBlock(tool_name="hello_world")
+    assert default_block.block_type == "tool_call"
+    assert default_block.tool_call_id is None
+    assert default_block.tool_name == "hello_world"
+    assert default_block.tool_kwargs == {}
+    custom_block = ToolCallBlock(
+        tool_name="hello_world",
+        tool_call_id="1",
+        tool_kwargs={"test": 1},
+    )
+    assert custom_block.tool_call_id == "1"
+    assert custom_block.tool_name == "hello_world"
+    assert custom_block.tool_kwargs == {"test": 1}
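(Illustrative note, not part of the patch: a minimal sketch of how the new block is intended to be used. Assistant messages now carry tool calls as content blocks rather than in `additional_kwargs["tool_calls"]`. Only names defined by the patch are used; the tool name and arguments below are hypothetical.)

```python
from llama_index.core.base.llms.types import (
    ChatMessage,
    MessageRole,
    TextBlock,
    ToolCallBlock,
)

# An assistant message whose tool call lives in the content blocks
# (tool name and arguments here are made-up examples).
message = ChatMessage(
    role=MessageRole.ASSISTANT,
    blocks=[
        TextBlock(text="Let me look that up."),
        ToolCallBlock(
            tool_call_id="call_1",
            tool_name="search",
            tool_kwargs={"query": "weather in Paris"},
        ),
    ],
)

# Consumers filter blocks by type instead of reading additional_kwargs["tool_calls"].
tool_calls = [b for b in message.blocks if isinstance(b, ToolCallBlock)]
assert tool_calls[0].tool_name == "search"
```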
diff --git a/llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/base.py b/llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/base.py
index 51b93c700c..6615c8ea55 100644
--- a/llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/base.py
@@ -1,4 +1,3 @@
-import json
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -12,6 +11,7 @@
     Set,
     Tuple,
     Union,
+    cast,
 )
 from llama_index.core.base.llms.types import (
@@ -22,6 +22,7 @@
     LLMMetadata,
     MessageRole,
     ContentBlock,
+    ToolCallBlock,
 )
 from llama_index.core.base.llms.types import TextBlock as LITextBlock
 from llama_index.core.base.llms.types import CitationBlock as LICitationBlock
@@ -42,6 +43,7 @@
     force_single_tool_call,
     is_function_calling_model,
     messages_to_anthropic_messages,
+    _anthropic_tool_call_to_tool_call_block,
 )

 import anthropic
@@ -344,8 +346,7 @@ def _completion_response_from_chat_response(

     def _get_blocks_and_tool_calls_and_thinking(
         self, response: Any
-    ) -> Tuple[List[ContentBlock], List[Dict[str, Any]], List[Dict[str, Any]]]:
-        tool_calls = []
+    ) -> Tuple[List[ContentBlock], List[Dict[str, Any]]]:
         blocks: List[ContentBlock] = []
         citations: List[TextCitation] = []
         tracked_citations: Set[str] = set()
@@ -385,9 +386,15 @@ def _get_blocks_and_tool_calls_and_thinking(
                     )
                 )
             elif isinstance(content_block, ToolUseBlock):
-                tool_calls.append(content_block.model_dump())
+                blocks.append(
+                    ToolCallBlock(
+                        tool_call_id=content_block.id,
+                        tool_name=content_block.name,
+                        tool_kwargs=content_block.input,
+                    )
+                )

-        return blocks, tool_calls, [x.model_dump() for x in citations]
+        return blocks, [x.model_dump() for x in citations]

     @llm_chat_callback()
     def chat(
@@ -405,17 +412,12 @@ def chat(
             **all_kwargs,
         )

-        blocks, tool_calls, citations = self._get_blocks_and_tool_calls_and_thinking(
-            response
-        )
+        blocks, citations = self._get_blocks_and_tool_calls_and_thinking(response)

         return AnthropicChatResponse(
             message=ChatMessage(
                 role=MessageRole.ASSISTANT,
                 blocks=blocks,
-                additional_kwargs={
-                    "tool_calls": tool_calls,
-                },
             ),
             citations=citations,
             raw=dict(response),
@@ -526,7 +528,12 @@ def gen() -> Generator[AnthropicChatResponse, None, None]:
                     yield AnthropicChatResponse(
                         message=ChatMessage(
                             role=role,
-                            blocks=content,
+                            blocks=[
+                                *content,
+                                *_anthropic_tool_call_to_tool_call_block(
+                                    cur_tool_calls
+                                ),
+                            ],
                             additional_kwargs={
                                 "tool_calls": [
                                     t.model_dump() for t in tool_calls_to_send
@@ -577,17 +584,12 @@ async def achat(
             **all_kwargs,
         )

-        blocks, tool_calls, citations = self._get_blocks_and_tool_calls_and_thinking(
-            response
-        )
+        blocks, citations = self._get_blocks_and_tool_calls_and_thinking(response)

         return AnthropicChatResponse(
             message=ChatMessage(
                 role=MessageRole.ASSISTANT,
                 blocks=blocks,
-                additional_kwargs={
-                    "tool_calls": tool_calls,
-                },
             ),
             citations=citations,
             raw=dict(response),
@@ -698,11 +700,12 @@ async def gen() -> ChatResponseAsyncGen:
                     yield AnthropicChatResponse(
                         message=ChatMessage(
                             role=role,
-                            blocks=content,
-                            additional_kwargs={
-                                "tool_calls": [t.dict() for t in tool_calls_to_send],
-                                "thinking": thinking.model_dump() if thinking else None,
-                            },
+                            blocks=[
+                                *content,
+                                *_anthropic_tool_call_to_tool_call_block(
+                                    cur_tool_calls
+                                ),
+                            ],
                         ),
                         citations=cur_citations,
                         delta=content_delta,
@@ -811,7 +814,11 @@ def get_tool_calls_from_response(
         **kwargs: Any,
     ) -> List[ToolSelection]:
         """Predict and call the tool."""
-        tool_calls = response.message.additional_kwargs.get("tool_calls", [])
+        tool_calls = [
+            block
+            for block in response.message.blocks
+            if isinstance(block, ToolCallBlock)
+        ]

         if len(tool_calls) < 1:
             if error_on_no_tool_call:
@@ -823,25 +830,13 @@ def get_tool_calls_from_response(
         tool_selections = []
         for tool_call in tool_calls:
-            if (
-                "input" not in tool_call
-                or "id" not in tool_call
-                or "name" not in tool_call
-            ):
-                raise ValueError("Invalid tool call.")
-            if tool_call["type"] != "tool_use":
-                raise
ValueError("Invalid tool type. Unsupported by Anthropic") - argument_dict = ( - json.loads(tool_call["input"]) - if isinstance(tool_call["input"], str) - else tool_call["input"] - ) + argument_dict = tool_call.tool_kwargs tool_selections.append( ToolSelection( - tool_id=tool_call["id"], - tool_name=tool_call["name"], - tool_kwargs=argument_dict, + tool_id=tool_call.tool_call_id or "", + tool_name=tool_call.tool_name, + tool_kwargs=cast(Dict[str, Any], argument_dict), ) ) diff --git a/llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/utils.py b/llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/utils.py index abc1a5d01e..4650fa0521 100644 --- a/llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/utils.py +++ b/llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/utils.py @@ -2,7 +2,7 @@ Utility functions for the Anthropic SDK LLM integration. """ -from typing import Any, Dict, List, Sequence, Tuple, Optional +from typing import Any, Dict, List, Sequence, Tuple, Optional, cast, Union from llama_index.core.base.llms.types import ( ChatMessage, @@ -15,6 +15,7 @@ CitableBlock, CitationBlock, ThinkingBlock, + ToolCallBlock, ContentBlock, ) @@ -26,6 +27,7 @@ ImageBlockParam, CacheControlEphemeralParam, Base64PDFSourceParam, + ToolUseBlock, ) from anthropic.types import ContentBlockParam as AnthropicContentBlock from anthropic.types.beta import ( @@ -198,6 +200,19 @@ def _to_anthropic_document_block(block: DocumentBlock) -> DocumentBlockParam: ) +def _anthropic_tool_call_to_tool_call_block(tool_calls: list[ToolUseBlock]): + blocks = [] + for tool_call in tool_calls: + blocks.append( + ToolCallBlock( + tool_call_id=tool_call.id, + tool_kwargs=cast(Union[Dict[str, Any], str], tool_call.input), + tool_name=tool_call.name, + ) + ) + return blocks + + def blocks_to_anthropic_blocks( blocks: Sequence[ContentBlock], kwargs: dict[str, Any] ) -> List[AnthropicContentBlock]: @@ -275,24 +290,18 @@ def blocks_to_anthropic_blocks( elif isinstance(block, CitationBlock): # No need to pass these back to Anthropic continue + elif isinstance(block, ToolCallBlock): + anthropic_blocks.append( + ToolUseBlockParam( + id=block.tool_call_id or "", + input=block.tool_kwargs, + name=block.tool_name, + type="tool_use", + ) + ) else: raise ValueError(f"Unsupported block type: {type(block)}") - tool_calls = kwargs.get("tool_calls", []) - for tool_call in tool_calls: - assert "id" in tool_call - assert "input" in tool_call - assert "name" in tool_call - - anthropic_blocks.append( - ToolUseBlockParam( - id=tool_call["id"], - input=tool_call["input"], - name=tool_call["name"], - type="tool_use", - ) - ) - return anthropic_blocks @@ -351,6 +360,12 @@ def messages_to_anthropic_messages( def force_single_tool_call(response: ChatResponse) -> None: - tool_calls = response.message.additional_kwargs.get("tool_calls", []) + tool_calls = [ + block for block in response.message.blocks if isinstance(block, ToolCallBlock) + ] if len(tool_calls) > 1: - response.message.additional_kwargs["tool_calls"] = [tool_calls[0]] + response.message.blocks = [ + block + for block in response.message.blocks + if not isinstance(block, ToolCallBlock) + ] + [tool_calls[0]] diff --git a/llama-index-integrations/llms/llama-index-llms-anthropic/tests/test_llms_anthropic.py b/llama-index-integrations/llms/llama-index-llms-anthropic/tests/test_llms_anthropic.py index 14f3be7435..830ed8fbc0 100644 --- 
a/llama-index-integrations/llms/llama-index-llms-anthropic/tests/test_llms_anthropic.py +++ b/llama-index-integrations/llms/llama-index-llms-anthropic/tests/test_llms_anthropic.py @@ -17,7 +17,7 @@ CachePoint, CacheControl, ) -from llama_index.core.base.llms.types import ThinkingBlock +from llama_index.core.base.llms.types import ThinkingBlock, ToolCallBlock from llama_index.core.tools import FunctionTool from llama_index.llms.anthropic import Anthropic from llama_index.llms.anthropic.base import AnthropicChatResponse @@ -253,8 +253,16 @@ def test_tool_required(): tool_required=True, ) assert isinstance(response, AnthropicChatResponse) - assert response.message.additional_kwargs["tool_calls"] is not None - assert len(response.message.additional_kwargs["tool_calls"]) > 0 + assert ( + len( + [ + block + for block in response.message.blocks + if isinstance(block, ToolCallBlock) + ] + ) + > 0 + ) # Test with tool_required=False response = llm.chat_with_tools( @@ -264,7 +272,16 @@ def test_tool_required(): ) assert isinstance(response, AnthropicChatResponse) # Should not use tools for a simple greeting - assert not response.message.additional_kwargs.get("tool_calls") + assert ( + len( + [ + block + for block in response.message.blocks + if isinstance(block, ToolCallBlock) + ] + ) + == 0 + ) # should not blow up with no tools (regression test) response = llm.chat_with_tools( @@ -273,7 +290,16 @@ def test_tool_required(): tool_required=False, ) assert isinstance(response, AnthropicChatResponse) - assert not response.message.additional_kwargs.get("tool_calls") + assert ( + len( + [ + block + for block in response.message.blocks + if isinstance(block, ToolCallBlock) + ] + ) + == 0 + ) @pytest.mark.skipif( diff --git a/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/base.py b/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/base.py index 98ca8141a8..9099db0db4 100644 --- a/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/base.py +++ b/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/base.py @@ -8,6 +8,7 @@ Sequence, Tuple, Union, + cast, TYPE_CHECKING, ) @@ -22,6 +23,7 @@ LLMMetadata, MessageRole, TextBlock, + ToolCallBlock, ThinkingBlock, ) from llama_index.core.bridge.pydantic import Field, PrivateAttr @@ -38,7 +40,6 @@ stream_chat_to_completion_decorator, ) from llama_index.core.llms.function_calling import FunctionCallingLLM, ToolSelection -from llama_index.core.llms.utils import parse_partial_json from llama_index.core.types import BaseOutputParser, PydanticProgramMode from llama_index.llms.bedrock_converse.utils import ( bedrock_modelname_to_context_size, @@ -349,7 +350,7 @@ def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]: def _get_content_and_tool_calls( self, response: Optional[Dict[str, Any]] = None, content: Dict[str, Any] = None ) -> Tuple[ - List[Union[TextBlock, ThinkingBlock]], Dict[str, Any], List[str], List[str] + List[Union[TextBlock, ToolCallBlock, ThinkingBlock]], List[str], List[str] ]: assert response is not None or content is not None, ( f"Either response or content must be provided. Got response: {response}, content: {content}" @@ -357,7 +358,7 @@ def _get_content_and_tool_calls( assert response is None or content is None, ( f"Only one of response or content should be provided. 
Got response: {response}, content: {content}" ) - tool_calls = [] + blocks: List[Union[TextBlock, ToolCallBlock, ThinkingBlock]] = [] tool_call_ids = [] status = [] blocks = [] @@ -381,19 +382,22 @@ def _get_content_and_tool_calls( ) ) if tool_usage := content_block.get("toolUse", None): - if "toolUseId" not in tool_usage: - tool_usage["toolUseId"] = content_block["toolUseId"] - if "name" not in tool_usage: - tool_usage["name"] = content_block["name"] - tool_calls.append(tool_usage) + blocks.append( + ToolCallBlock( + tool_name=tool_usage.get("name", None) + or content_block.get("name", ""), + tool_call_id=tool_usage.get("toolUseId", None) + or content_block.get("toolUseId", None), + tool_kwargs=tool_usage.get("input", {}), + ) + ) if tool_result := content_block.get("toolResult", None): for tool_result_content in tool_result["content"]: if text := tool_result_content.get("text", None): - text_content += text - tool_call_ids.append(tool_result_content.get("toolUseId", "")) + blocks.append(TextBlock(text=text)) + tool_call_ids.append(tool_result_content.get("toolUseId", "")) status.append(tool_result.get("status", "")) - - return blocks, tool_calls, tool_call_ids, status + return blocks, tool_call_ids, status @llm_chat_callback() def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: @@ -420,16 +424,13 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: **all_kwargs, ) - blocks, tool_calls, tool_call_ids, status = self._get_content_and_tool_calls( - response - ) + blocks, tool_call_ids, status = self._get_content_and_tool_calls(response) return ChatResponse( message=ChatMessage( role=MessageRole.ASSISTANT, blocks=blocks, additional_kwargs={ - "tool_calls": tool_calls, "tool_call_id": tool_call_ids, "status": status, }, @@ -522,6 +523,16 @@ def gen() -> ChatResponseGen: current_tool_call = join_two_dicts( current_tool_call, tool_use_delta ) + else: + current_tool_call = tool_use_delta + + tool_calls.append( + ToolCallBlock( + tool_call_id=current_tool_call.get("toolUseId"), + tool_name=current_tool_call.get("name", ""), + tool_kwargs=current_tool_call.get("input", {}), + ) + ) blocks: List[Union[TextBlock, ThinkingBlock]] = [ TextBlock(text=content.get("text", "")) @@ -540,11 +551,13 @@ def gen() -> ChatResponseGen: yield ChatResponse( message=ChatMessage( role=role, - blocks=blocks, + blocks=[ + *blocks, + *tool_calls, + ], additional_kwargs={ - "tool_calls": tool_calls, "tool_call_id": [ - tc.get("toolUseId", "") for tc in tool_calls + tc.tool_call_id or "" for tc in tool_calls ], "status": [], # Will be populated when tool results come in }, @@ -560,7 +573,13 @@ def gen() -> ChatResponseGen: # Start tracking a new tool call current_tool_call = tool_use # Add to our list of tool calls - tool_calls.append(current_tool_call) + tool_calls.append( + ToolCallBlock( + tool_call_id=current_tool_call.get("toolUseId"), + tool_name=current_tool_call.get("name", ""), + tool_kwargs=current_tool_call.get("input", {}), + ) + ) blocks: List[Union[TextBlock, ThinkingBlock]] = [ TextBlock(text=content.get("text", "")) @@ -579,12 +598,12 @@ def gen() -> ChatResponseGen: yield ChatResponse( message=ChatMessage( role=role, - blocks=blocks, + blocks=[ + *blocks, + *tool_calls, + ], additional_kwargs={ - "tool_calls": tool_calls, - "tool_call_id": [ - tc.get("toolUseId", "") for tc in tool_calls - ], + "tool_call_id": [tc.tool_call_id for tc in tool_calls], "status": [], # Will be populated when tool results come in }, ), @@ -615,9 +634,11 @@ def gen() -> 
ChatResponseGen: yield ChatResponse( message=ChatMessage( role=role, - blocks=blocks, + blocks=[ + *blocks, + *tool_calls, + ], additional_kwargs={ - "tool_calls": tool_calls, "tool_call_id": [ tc.get("toolUseId", "") for tc in tool_calls ], @@ -667,16 +688,13 @@ async def achat( **all_kwargs, ) - blocks, tool_calls, tool_call_ids, status = self._get_content_and_tool_calls( - response - ) + blocks, tool_call_ids, status = self._get_content_and_tool_calls(response) return ChatResponse( message=ChatMessage( role=MessageRole.ASSISTANT, blocks=blocks, additional_kwargs={ - "tool_calls": tool_calls, "tool_call_id": tool_call_ids, "status": status, }, @@ -771,6 +789,16 @@ async def gen() -> ChatResponseAsyncGen: current_tool_call = join_two_dicts( current_tool_call, tool_use_delta ) + else: + current_tool_call = tool_use_delta + + tool_calls.append( + ToolCallBlock( + tool_call_id=current_tool_call.get("toolUseId"), + tool_name=current_tool_call.get("name", ""), + tool_kwargs=current_tool_call.get("input", {}), + ) + ) blocks: List[Union[TextBlock, ThinkingBlock]] = [ TextBlock(text=content.get("text", "")) ] @@ -788,12 +816,12 @@ async def gen() -> ChatResponseAsyncGen: yield ChatResponse( message=ChatMessage( role=role, - blocks=blocks, + blocks=[ + *blocks, + *tool_calls, + ], additional_kwargs={ - "tool_calls": tool_calls, - "tool_call_id": [ - tc.get("toolUseId", "") for tc in tool_calls - ], + "tool_call_id": [tc.tool_call_id for tc in tool_calls], "status": [], # Will be populated when tool results come in }, ), @@ -808,7 +836,13 @@ async def gen() -> ChatResponseAsyncGen: # Start tracking a new tool call current_tool_call = tool_use # Add to our list of tool calls - tool_calls.append(current_tool_call) + tool_calls.append( + ToolCallBlock( + tool_call_id=current_tool_call.get("toolUseId"), + tool_name=current_tool_call.get("name", ""), + tool_kwargs=current_tool_call.get("input", {}), + ) + ) blocks: List[Union[TextBlock, ThinkingBlock]] = [ TextBlock(text=content.get("text", "")) @@ -827,9 +861,11 @@ async def gen() -> ChatResponseAsyncGen: yield ChatResponse( message=ChatMessage( role=role, - blocks=blocks, + blocks=[ + *blocks, + *tool_calls, + ], additional_kwargs={ - "tool_calls": tool_calls, "tool_call_id": [ tc.get("toolUseId", "") for tc in tool_calls ], @@ -863,11 +899,13 @@ async def gen() -> ChatResponseAsyncGen: yield ChatResponse( message=ChatMessage( role=role, - blocks=blocks, + blocks=[ + *blocks, + *tool_calls, + ], additional_kwargs={ - "tool_calls": tool_calls, "tool_call_id": [ - tc.get("toolUseId", "") for tc in tool_calls + tc.tool_call_id for tc in tool_calls ], "status": [], }, @@ -941,7 +979,11 @@ def get_tool_calls_from_response( **kwargs: Any, ) -> List[ToolSelection]: """Predict and call the tool.""" - tool_calls = response.message.additional_kwargs.get("tool_calls", []) + tool_calls = [ + block + for block in response.message.blocks + if isinstance(block, ToolCallBlock) + ] if len(tool_calls) < 1: if error_on_no_tool_call: @@ -953,26 +995,13 @@ def get_tool_calls_from_response( tool_selections = [] for tool_call in tool_calls: - if "toolUseId" not in tool_call or "name" not in tool_call: - raise ValueError("Invalid tool call.") - # handle empty inputs - argument_dict = {} - if "input" in tool_call and isinstance(tool_call["input"], str): - # TODO parse_partial_json is not perfect - try: - argument_dict = parse_partial_json(tool_call["input"]) - except ValueError: - argument_dict = {} - elif "input" in tool_call and isinstance(tool_call["input"], dict): - 
argument_dict = tool_call["input"] - else: - continue + argument_dict = cast(Dict[str, Any], tool_call.tool_kwargs) or {} tool_selections.append( ToolSelection( - tool_id=tool_call["toolUseId"], - tool_name=tool_call["name"], + tool_id=tool_call.tool_call_id or "", + tool_name=tool_call.tool_name, tool_kwargs=argument_dict, ) ) diff --git a/llama-index-integrations/llms/llama-index-llms-bedrock-converse/tests/test_llms_bedrock_converse.py b/llama-index-integrations/llms/llama-index-llms-bedrock-converse/tests/test_llms_bedrock_converse.py index 6dfe01fa94..e2e0bf0fb3 100644 --- a/llama-index-integrations/llms/llama-index-llms-bedrock-converse/tests/test_llms_bedrock_converse.py +++ b/llama-index-integrations/llms/llama-index-llms-bedrock-converse/tests/test_llms_bedrock_converse.py @@ -11,6 +11,7 @@ CompletionResponse, ImageBlock, TextBlock, + ToolCallBlock, ThinkingBlock, CachePoint, CacheControl, @@ -254,7 +255,6 @@ def test_complete(bedrock_converse): assert response.text == EXP_RESPONSE assert response.additional_kwargs["status"] == [] assert response.additional_kwargs["tool_call_id"] == [] - assert response.additional_kwargs["tool_calls"] == [] def test_stream_chat(bedrock_converse): @@ -335,7 +335,6 @@ async def test_acomplete(bedrock_converse): assert response.text == EXP_RESPONSE assert response.additional_kwargs["status"] == [] assert response.additional_kwargs["tool_call_id"] == [] - assert response.additional_kwargs["tool_calls"] == [] @pytest.mark.asyncio @@ -846,6 +845,62 @@ async def test_bedrock_converse_agent_with_void_tool_and_continued_conversation( assert len(str(response_5)) > 0 +@needs_aws_creds +@pytest.mark.asyncio +async def test_bedrock_converse_tool_calling(bedrock_converse_integration): + tool = FunctionTool.from_defaults( + fn=get_temperature, + name="get_temperature", + description="Get the temperature of a location (str) in Celsius degree", + ) + response = bedrock_converse_integration.chat_with_tools( + tools=[tool], + user_msg="What is the temperature in San Francisco?", + tool_required=True, + ) + assert ( + len( + [ + block + for block in response.message.blocks + if isinstance(block, ToolCallBlock) + ] + ) + >= 1 + ) + assert any( + block.tool_name == "get_temperature" + for block in [ + block + for block in response.message.blocks + if isinstance(block, ToolCallBlock) + ] + ) + response = await bedrock_converse_integration.achat_with_tools( + tools=[tool], + user_msg="What is the temperature in San Francisco?", + tool_required=True, + ) + assert ( + len( + [ + block + for block in response.message.blocks + if isinstance(block, ToolCallBlock) + ] + ) + >= 1 + ) + assert any( + block.tool_name == "get_temperature" + for block in [ + block + for block in response.message.blocks + if isinstance(block, ToolCallBlock) + ] + ) + + @needs_aws_creds @pytest.mark.asyncio async def test_bedrock_converse_thinking(bedrock_converse_integration_thinking): diff --git a/llama-index-integrations/llms/llama-index-llms-google-genai/llama_index/llms/google_genai/base.py b/llama-index-integrations/llms/llama-index-llms-google-genai/llama_index/llms/google_genai/base.py index 69e5ad4055..c31044e557 100644 --- a/llama-index-integrations/llms/llama-index-llms-google-genai/llama_index/llms/google_genai/base.py +++ b/llama-index-integrations/llms/llama-index-llms-google-genai/llama_index/llms/google_genai/base.py @@ -16,6 +16,7 @@ Type, Union, Callable, + cast, ) @@ -37,6 +38,7 @@ LLMMetadata, MessageRole, ThinkingBlock, + ToolCallBlock, TextBlock, ) from 
llama_index.core.bridge.pydantic import BaseModel, Field, PrivateAttr @@ -376,7 +378,6 @@ def _stream_chat( def gen() -> ChatResponseGen: content = "" - existing_tool_calls = [] thoughts = "" for r in response: if not r.candidates: @@ -390,13 +391,9 @@ def gen() -> ChatResponseGen: else: content += content_delta llama_resp = chat_from_gemini_response(r) - existing_tool_calls.extend( - llama_resp.message.additional_kwargs.get("tool_calls", []) - ) llama_resp.delta = content_delta - llama_resp.message.blocks = [TextBlock(text=content)] + llama_resp.message.blocks += [TextBlock(text=content)] llama_resp.message.blocks.append(ThinkingBlock(content=thoughts)) - llama_resp.message.additional_kwargs["tool_calls"] = existing_tool_calls yield llama_resp if self.use_file_api: @@ -429,7 +426,6 @@ async def _astream_chat( async def gen() -> ChatResponseAsyncGen: content = "" - existing_tool_calls = [] thoughts = "" async for r in await chat.send_message_stream( next_msg.parts if isinstance(next_msg, types.Content) else next_msg @@ -448,19 +444,11 @@ async def gen() -> ChatResponseAsyncGen: else: content += content_delta llama_resp = chat_from_gemini_response(r) - existing_tool_calls.extend( - llama_resp.message.additional_kwargs.get( - "tool_calls", [] - ) - ) - llama_resp.delta = content_delta - llama_resp.message.blocks = [TextBlock(text=content)] + llama_resp.message.blocks += [TextBlock(text=content)] llama_resp.message.blocks.append( ThinkingBlock(content=thoughts) ) - llama_resp.message.additional_kwargs["tool_calls"] = ( - existing_tool_calls - ) + llama_resp.delta = content_delta yield llama_resp if self.use_file_api: @@ -551,7 +539,11 @@ def get_tool_calls_from_response( **kwargs: Any, ) -> List[ToolSelection]: """Predict and call the tool.""" - tool_calls = response.message.additional_kwargs.get("tool_calls", []) + tool_calls = [ + block + for block in response.message.blocks + if isinstance(block, ToolCallBlock) + ] if len(tool_calls) < 1: if error_on_no_tool_call: @@ -565,9 +557,9 @@ def get_tool_calls_from_response( for tool_call in tool_calls: tool_selections.append( ToolSelection( - tool_id=tool_call["name"], - tool_name=tool_call["name"], - tool_kwargs=tool_call["args"], + tool_id=tool_call.tool_call_id or "", + tool_name=tool_call.tool_name, + tool_kwargs=cast(Dict[str, Any], tool_call.tool_kwargs), ) ) diff --git a/llama-index-integrations/llms/llama-index-llms-google-genai/llama_index/llms/google_genai/utils.py b/llama-index-integrations/llms/llama-index-llms-google-genai/llama_index/llms/google_genai/utils.py index fe02a5a69b..56a2b090dd 100644 --- a/llama-index-integrations/llms/llama-index-llms-google-genai/llama_index/llms/google_genai/utils.py +++ b/llama-index-integrations/llms/llama-index-llms-google-genai/llama_index/llms/google_genai/utils.py @@ -2,15 +2,7 @@ import logging from collections.abc import Sequence from io import BytesIO -from typing import ( - TYPE_CHECKING, - Any, - Dict, - Union, - Optional, - Type, - Tuple, -) +from typing import TYPE_CHECKING, Any, Dict, Union, Optional, Type, Tuple, cast import typing import google.genai.types as types @@ -29,6 +21,7 @@ DocumentBlock, VideoBlock, ThinkingBlock, + ToolCallBlock, ) from llama_index.core.program.utils import _repair_incomplete_json from tenacity import ( @@ -188,15 +181,12 @@ def chat_from_gemini_response( ) additional_kwargs["thought_signatures"].append(part.thought_signature) if part.function_call: - if "tool_calls" not in additional_kwargs: - additional_kwargs["tool_calls"] = [] - 
additional_kwargs["tool_calls"].append( - { - "id": part.function_call.id if part.function_call.id else "", - "name": part.function_call.name, - "args": part.function_call.args, - "thought_signature": part.thought_signature, - } + content_blocks.append( + ToolCallBlock( + tool_call_id=part.function_call.id, + tool_name=part.function_call.name or "", + tool_kwargs=part.function_call.args, + ) ) if thought_tokens: thinking_blocks = [ @@ -326,6 +316,10 @@ async def chat_message_to_gemini( part.thought_signature = block.additional_information.get( "thought_signature", None ) + elif isinstance(block, ToolCallBlock): + part = types.Part.from_function_call( + name=block.tool_name or "", args=cast(Dict[str, Any], block.tool_kwargs) + ) else: msg = f"Unsupported content block type: {type(block).__name__}" raise ValueError(msg) @@ -341,22 +335,11 @@ async def chat_message_to_gemini( ) parts.append(part) - for tool_call in message.additional_kwargs.get("tool_calls", []): - if isinstance(tool_call, dict): - part = types.Part.from_function_call( - name=tool_call.get("name"), args=tool_call.get("args") - ) - part.thought_signature = tool_call.get("thought_signature") - else: - part = types.Part.from_function_call( - name=tool_call.name, args=tool_call.args - ) - part.thought_signature = tool_call.thought_signature - parts.append(part) - # the tool call id is the name of the tool # the tool call response is the content of the message, overriding the existing content # (the only content before this should be the tool call) + # we do not use ToolCallBlock here because the message that gets returned + # with the result of the query already has the 'Tool' role if message.additional_kwargs.get("tool_call_id"): function_response_part = types.Part.from_function_response( name=message.additional_kwargs.get("tool_call_id"), diff --git a/llama-index-integrations/llms/llama-index-llms-google-genai/tests/test_llms_google_genai.py b/llama-index-integrations/llms/llama-index-llms-google-genai/tests/test_llms_google_genai.py index 6be1799fc3..bc5f29d537 100644 --- a/llama-index-integrations/llms/llama-index-llms-google-genai/tests/test_llms_google_genai.py +++ b/llama-index-integrations/llms/llama-index-llms-google-genai/tests/test_llms_google_genai.py @@ -11,6 +11,7 @@ TextBlock, VideoBlock, ThinkingBlock, + ToolCallBlock, ) from llama_index.core.llms.llm import ToolSelection from llama_index.core.program.function_program import get_function_tool @@ -564,8 +565,17 @@ def test_tool_required_integration(llm: GoogleGenAI) -> None: tools=[search_tool], tool_required=True, ) - assert response.message.additional_kwargs.get("tool_calls") is not None - assert len(response.message.additional_kwargs["tool_calls"]) > 0 + + assert ( + len( + [ + block + for block in response.message.blocks + if isinstance(block, ToolCallBlock) + ] + ) + > 0 + ) # Test with tool_required=False response = llm.chat_with_tools( @@ -730,15 +740,13 @@ async def test_prepare_chat_params_more_than_2_tool_calls(): ], ), ChatMessage( - content="Let me search for puppies.", + blocks=[ + TextBlock(text="Let me search for puppies."), + ToolCallBlock(tool_name="tool_1"), + ToolCallBlock(tool_name="tool_2"), + ToolCallBlock(tool_name="tool_3"), + ], role=MessageRole.ASSISTANT, - additional_kwargs={ - "tool_calls": [ - {"name": "tool_1"}, - {"name": "tool_2"}, - {"name": "tool_3"}, - ] - }, ), ChatMessage( content="Tool 1 Response", @@ -778,9 +786,9 @@ async def test_prepare_chat_params_more_than_2_tool_calls(): thought=True, ), types.Part(text="Let me search 
for puppies."), - types.Part.from_function_call(name="tool_1", args=None), - types.Part.from_function_call(name="tool_2", args=None), - types.Part.from_function_call(name="tool_3", args=None), + types.Part.from_function_call(name="tool_1", args={}), + types.Part.from_function_call(name="tool_2", args={}), + types.Part.from_function_call(name="tool_3", args={}), ], role=MessageRole.MODEL, ), @@ -872,6 +880,7 @@ def test_cached_content_in_response() -> None: mock_response.candidates[0].content.parts[0].text = "Test response" mock_response.candidates[0].content.parts[0].thought = False mock_response.candidates[0].content.parts[0].inline_data = None + mock_response.candidates[0].content.parts[0].function_call = None mock_response.prompt_feedback = None mock_response.usage_metadata = None mock_response.function_calls = None @@ -899,6 +908,7 @@ def test_cached_content_without_cached_content() -> None: mock_response.candidates[0].content.parts[0].text = "Test response" mock_response.candidates[0].content.parts[0].thought = False mock_response.candidates[0].content.parts[0].inline_data = None + mock_response.candidates[0].content.parts[0].function_call = None mock_response.prompt_feedback = None mock_response.usage_metadata = None mock_response.function_calls = None @@ -923,9 +933,11 @@ def test_thoughts_in_response() -> None: mock_response.candidates[0].content.parts[0].text = "This is a thought." mock_response.candidates[0].content.parts[0].inline_data = None mock_response.candidates[0].content.parts[0].thought = True + mock_response.candidates[0].content.parts[0].function_call = None mock_response.candidates[0].content.parts[1].text = "This is not a thought." mock_response.candidates[0].content.parts[1].inline_data = None mock_response.candidates[0].content.parts[1].thought = None + mock_response.candidates[0].content.parts[1].function_call = None mock_response.candidates[0].content.parts[0].model_dump = MagicMock(return_value={}) mock_response.candidates[0].content.parts[1].model_dump = MagicMock(return_value={}) mock_response.prompt_feedback = None @@ -967,6 +979,7 @@ def test_thoughts_without_thought_response() -> None: mock_response.candidates[0].content.parts[0].text = "This is not a thought." 
mock_response.candidates[0].content.parts[0].inline_data = None mock_response.candidates[0].content.parts[0].thought = None + mock_response.candidates[0].content.parts[0].function_call = None mock_response.prompt_feedback = None mock_response.usage_metadata = None mock_response.function_calls = None @@ -1084,6 +1097,7 @@ def test_built_in_tool_in_response() -> None: ].text = "Test response with search results" mock_response.candidates[0].content.parts[0].inline_data = None mock_response.candidates[0].content.parts[0].thought = None + mock_response.candidates[0].content.parts[0].function_call = None mock_response.prompt_feedback = None mock_response.usage_metadata = MagicMock() mock_response.usage_metadata.model_dump.return_value = { @@ -1523,6 +1537,7 @@ def test_code_execution_response_parts() -> None: ) mock_text_part.inline_data = None mock_text_part.thought = None + mock_text_part.function_call = None mock_code_part = MagicMock() mock_code_part.text = None @@ -1532,11 +1547,13 @@ def test_code_execution_response_parts() -> None: "code": "def is_prime(n):\n if n < 2:\n return False\n for i in range(2, int(n**0.5) + 1):\n if n % i == 0:\n return False\n return True\n\nprimes = []\nn = 2\nwhile len(primes) < 50:\n if is_prime(n):\n primes.append(n)\n n += 1\n\nprint(f'Sum of first 50 primes: {sum(primes)}')", "language": types.Language.PYTHON, } + mock_code_part.function_call = None mock_result_part = MagicMock() mock_result_part.text = None mock_result_part.inline_data = None mock_result_part.thought = None + mock_result_part.function_call = None mock_result_part.code_execution_result = { "outcome": types.Outcome.OUTCOME_OK, "output": "Sum of first 50 primes: 5117", @@ -1545,6 +1562,7 @@ def test_code_execution_response_parts() -> None: mock_final_text_part = MagicMock() mock_final_text_part.text = "The sum of the first 50 prime numbers is 5117." 
mock_final_text_part.inline_data = None + mock_final_text_part.function_call = None mock_final_text_part.thought = None mock_candidate.content.parts = [ diff --git a/llama-index-integrations/llms/llama-index-llms-google-genai/tests/test_llms_google_genai_vertex.py b/llama-index-integrations/llms/llama-index-llms-google-genai/tests/test_llms_google_genai_vertex.py index 9262364e5d..bd731e1c22 100644 --- a/llama-index-integrations/llms/llama-index-llms-google-genai/tests/test_llms_google_genai_vertex.py +++ b/llama-index-integrations/llms/llama-index-llms-google-genai/tests/test_llms_google_genai_vertex.py @@ -144,6 +144,7 @@ def test_cached_content_in_response_vertexai() -> None: mock_response.candidates[0].content.parts[0].text = "Test response" mock_response.candidates[0].content.parts[0].inline_data = None mock_response.candidates[0].content.parts[0].thought = False + mock_response.candidates[0].content.parts[0].function_call = None mock_response.prompt_feedback = None mock_response.usage_metadata = None mock_response.function_calls = None @@ -171,6 +172,7 @@ def test_cached_content_without_cached_content_vertexai() -> None: mock_response.candidates[0].content.parts[0].text = "Test response" mock_response.candidates[0].content.parts[0].inline_data = None mock_response.candidates[0].content.parts[0].thought = False + mock_response.candidates[0].content.parts[0].function_call = None mock_response.prompt_feedback = None mock_response.usage_metadata = None mock_response.function_calls = None diff --git a/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py b/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py index 7cf54d7682..f6033ed534 100644 --- a/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py +++ b/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py @@ -1,4 +1,3 @@ -import json from typing import ( Any, Callable, @@ -9,6 +8,7 @@ Tuple, Union, TYPE_CHECKING, + cast, ) from llama_index.core.base.llms.types import ( @@ -25,6 +25,7 @@ TextBlock, ImageBlock, ThinkingBlock, + ToolCallBlock, ) from llama_index.core.bridge.pydantic import Field, PrivateAttr from llama_index.core.callbacks import CallbackManager @@ -64,6 +65,8 @@ ImageURLChunk, ContentChunk, ThinkChunk, + ToolCall, + FunctionCall, ) if TYPE_CHECKING: @@ -79,6 +82,16 @@ def to_mistral_chunks(content_blocks: Sequence[ContentBlock]) -> Sequence[Conten for content_block in content_blocks: if isinstance(content_block, TextBlock): content_chunks.append(TextChunk(text=content_block.text)) + elif isinstance(content_block, ToolCallBlock): + content_chunks.append( + ToolCall( + function=FunctionCall( + name=content_block.tool_name, + arguments=cast(Dict[str, Any], content_block.tool_kwargs), + ), + id=content_block.tool_call_id or "", + ) + ) elif isinstance(content_block, ThinkingBlock): if content_block.content: content_chunks.append( @@ -112,12 +125,11 @@ def to_mistral_chatmessage( ) -> List[Messages]: new_messages = [] for m in messages: - tool_calls = m.additional_kwargs.get("tool_calls") chunks = to_mistral_chunks(m.blocks) if m.role == MessageRole.USER: new_messages.append(UserMessage(content=chunks)) elif m.role == MessageRole.ASSISTANT: - new_messages.append(AssistantMessage(content=chunks, tool_calls=tool_calls)) + new_messages.append(AssistantMessage(content=chunks)) elif m.role == MessageRole.SYSTEM: new_messages.append(SystemMessage(content=chunks)) elif m.role == 
MessageRole.TOOL or m.role == MessageRole.FUNCTION: @@ -135,9 +147,15 @@ def to_mistral_chatmessage( def force_single_tool_call(response: ChatResponse) -> None: - tool_calls = response.message.additional_kwargs.get("tool_calls", []) + tool_calls = [ + block for block in response.message.blocks if isinstance(block, ToolCallBlock) + ] if len(tool_calls) > 1: - response.message.additional_kwargs["tool_calls"] = [tool_calls[0]] + response.message.blocks = [ + block + for block in response.message.blocks + if not isinstance(block, ToolCallBlock) + ] + [tool_calls[0]] class MistralAI(FunctionCallingLLM): @@ -315,7 +333,7 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: messages = to_mistral_chatmessage(messages) all_kwargs = self._get_all_kwargs(**kwargs) response = self._client.chat.complete(messages=messages, **all_kwargs) - blocks: List[TextBlock | ThinkingBlock] = [] + blocks: List[TextBlock | ThinkingBlock | ToolCallBlock] = [] additional_kwargs = {} if self.model in MISTRAL_AI_REASONING_MODELS: @@ -336,7 +354,19 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: blocks.append(TextBlock(text=response_txt)) tool_calls = response.choices[0].message.tool_calls if tool_calls is not None: - additional_kwargs["tool_calls"] = tool_calls + for tool_call in tool_calls: + if isinstance(tool_call, ToolCall): + blocks.append( + ToolCallBlock( + tool_call_id=tool_call.id, + tool_name=tool_call.function.name, + tool_kwargs=tool_call.function.arguments, + ) + ) + else: + blocks.append( + ToolCallBlock(tool_name=tool_call[0], tool_kwargs=tool_call[1]) + ) return ChatResponse( message=ChatMessage( @@ -367,7 +397,7 @@ def stream_chat( def gen() -> ChatResponseGen: content = "" - blocks: List[TextBlock | ThinkingBlock] = [] + blocks: List[TextBlock | ThinkingBlock | ToolCallBlock] = [] for chunk in response: delta = chunk.data.choices[0].delta role = delta.role or MessageRole.ASSISTANT @@ -375,7 +405,14 @@ def gen() -> ChatResponseGen: # NOTE: Unlike openAI, we are directly injecting the tool calls additional_kwargs = {} if delta.tool_calls: - additional_kwargs["tool_calls"] = delta.tool_calls + for tool_call in delta.tool_calls: + blocks.append( + ToolCallBlock( + tool_call_id=tool_call.id, + tool_name=tool_call.function.name, + tool_kwargs=tool_call.function.arguments, + ) + ) content_delta = delta.content or "" content += content_delta @@ -425,7 +462,7 @@ async def achat( messages=messages, **all_kwargs ) - blocks: List[TextBlock | ThinkingBlock] = [] + blocks: List[TextBlock | ThinkingBlock | ToolCallBlock] = [] additional_kwargs = {} if self.model in MISTRAL_AI_REASONING_MODELS: thinking_txt, response_txt = self._separate_thinking( @@ -446,7 +483,19 @@ async def achat( tool_calls = response.choices[0].message.tool_calls if tool_calls is not None: - additional_kwargs["tool_calls"] = tool_calls + for tool_call in tool_calls: + if isinstance(tool_call, ToolCall): + blocks.append( + ToolCallBlock( + tool_call_id=tool_call.id, + tool_name=tool_call.function.name, + tool_kwargs=tool_call.function.arguments, + ) + ) + else: + blocks.append( + ToolCallBlock(tool_name=tool_call[0], tool_kwargs=tool_call[1]) + ) return ChatResponse( message=ChatMessage( @@ -477,14 +526,21 @@ async def astream_chat( async def gen() -> ChatResponseAsyncGen: content = "" - blocks: List[ThinkingBlock | TextBlock] = [] + blocks: List[ThinkingBlock | TextBlock | ToolCallBlock] = [] async for chunk in response: delta = chunk.data.choices[0].delta role = delta.role or 
MessageRole.ASSISTANT # NOTE: Unlike openAI, we are directly injecting the tool calls additional_kwargs = {} if delta.tool_calls: - additional_kwargs["tool_calls"] = delta.tool_calls + for tool_call in delta.tool_calls: + blocks.append( + ToolCallBlock( + tool_call_id=tool_call.id, + tool_name=tool_call.function.name, + tool_kwargs=tool_call.function.arguments, + ) + ) content_delta = delta.content or "" content += content_delta @@ -570,7 +626,11 @@ def get_tool_calls_from_response( error_on_no_tool_call: bool = True, ) -> List[ToolSelection]: """Predict and call the tool.""" - tool_calls = response.message.additional_kwargs.get("tool_calls", []) + tool_calls = [ + block + for block in response.message.blocks + if isinstance(block, ToolCallBlock) + ] if len(tool_calls) < 1: if error_on_no_tool_call: @@ -585,13 +645,13 @@ def get_tool_calls_from_response( if not isinstance(tool_call, ToolCall): raise ValueError("Invalid tool_call object") - argument_dict = json.loads(tool_call.function.arguments) + argument_dict = tool_call.tool_kwargs tool_selections.append( ToolSelection( - tool_id=tool_call.id, - tool_name=tool_call.function.name, - tool_kwargs=argument_dict, + tool_id=tool_call.tool_call_id or "", + tool_name=tool_call.tool_name, + tool_kwargs=cast(Dict[str, Any], argument_dict), ) ) diff --git a/llama-index-integrations/llms/llama-index-llms-mistralai/tests/test_llms_mistral.py b/llama-index-integrations/llms/llama-index-llms-mistralai/tests/test_llms_mistral.py index ce4ed450aa..e385839a0e 100644 --- a/llama-index-integrations/llms/llama-index-llms-mistralai/tests/test_llms_mistral.py +++ b/llama-index-integrations/llms/llama-index-llms-mistralai/tests/test_llms_mistral.py @@ -3,13 +3,14 @@ import base64 from pathlib import Path from unittest.mock import patch +from typing import cast, Any from mistralai import ToolCall, ImageURLChunk, TextChunk, ThinkChunk import pytest from llama_index.core.base.llms.base import BaseLLM from llama_index.core.llms import ChatMessage, ImageBlock, TextBlock -from llama_index.core.base.llms.types import ThinkingBlock +from llama_index.core.base.llms.types import ThinkingBlock, ToolCallBlock from llama_index.core.tools import FunctionTool from llama_index.llms.mistralai import MistralAI from llama_index.llms.mistralai.base import to_mistral_chunks @@ -40,14 +41,13 @@ def test_tool_required(): user_msg="What is the capital of France?", tool_required=True, ) - additional_kwargs = result.message.additional_kwargs - assert "tool_calls" in additional_kwargs - tool_calls = additional_kwargs["tool_calls"] + tool_calls = [ + block for block in result.message.blocks if isinstance(block, ToolCallBlock) + ] assert len(tool_calls) == 1 tool_call = tool_calls[0] - assert isinstance(tool_call, ToolCall) - assert tool_call.function.name == "search_tool" - assert "query" in tool_call.function.arguments + assert tool_call.tool_name == "search_tool" + assert "query" in cast(dict[str, Any], tool_call.tool_kwargs) @patch("mistralai.Mistral") @@ -184,3 +184,19 @@ def test_to_mistral_chunks(tmp_path: Path, image_url: str) -> None: ) assert isinstance(thinking_chunks[1], TextChunk) assert thinking_chunks[1].text == "This is some text" + tool_call_blocks = [ + ToolCallBlock( + tool_call_id="1", tool_kwargs={"a": 1, "b": 2}, tool_name="sum_tool" + ), + TextBlock(text="The result of 1+2 is 3"), + ] + tool_call_chunks = to_mistral_chunks(tool_call_blocks) + assert len(tool_call_blocks) == 2 + assert isinstance(tool_call_chunks[0], ToolCall) + assert ( + 
tool_call_chunks[0].function.name == "sum_tool" + and tool_call_chunks[0].function.arguments == {"a": 1, "b": 2} + and tool_call_chunks[0].id == "1" + ) + assert isinstance(tool_call_chunks[1], TextChunk) + assert tool_call_chunks[1].text == "The result of 1+2 is 3" diff --git a/llama-index-integrations/llms/llama-index-llms-ollama/llama_index/llms/ollama/base.py b/llama-index-integrations/llms/llama-index-llms-ollama/llama_index/llms/ollama/base.py index 9df88a3b5a..ba121c30a7 100644 --- a/llama-index-integrations/llms/llama-index-llms-ollama/llama_index/llms/ollama/base.py +++ b/llama-index-integrations/llms/llama-index-llms-ollama/llama_index/llms/ollama/base.py @@ -10,6 +10,7 @@ Tuple, Type, Union, + cast, ) from ollama import AsyncClient, Client @@ -32,6 +33,7 @@ LLMMetadata, MessageRole, TextBlock, + ToolCallBlock, ThinkingBlock, ) from llama_index.core.bridge.pydantic import Field, PrivateAttr @@ -58,9 +60,15 @@ def get_additional_kwargs( def force_single_tool_call(response: ChatResponse) -> None: - tool_calls = response.message.additional_kwargs.get("tool_calls", []) or [] + tool_calls = [ + block for block in response.message.blocks if isinstance(block, ToolCallBlock) + ] if len(tool_calls) > 1: - response.message.additional_kwargs["tool_calls"] = [tool_calls[0]] + response.message.blocks = [ + block + for block in response.message.blocks + if not isinstance(block, ToolCallBlock) + ] + [tool_calls[0]] class Ollama(FunctionCallingLLM): @@ -237,6 +245,17 @@ def _convert_to_ollama_messages(self, messages: Sequence[ChatMessage]) -> Dict: cur_ollama_message["images"].append( block.resolve_image(as_base64=True).read().decode("utf-8") ) + elif isinstance(block, ToolCallBlock): + if "tool_calls" not in cur_ollama_message: + cur_ollama_message["tool_calls"] = [] + cur_ollama_message["tool_calls"].append( + { + "function": { + "name": block.tool_name, + "arguments": block.tool_kwargs, + } + } + ) elif isinstance(block, ThinkingBlock): if block.content: cur_ollama_message["thinking"] = block.content @@ -312,7 +331,11 @@ def get_tool_calls_from_response( error_on_no_tool_call: bool = True, ) -> List[ToolSelection]: """Predict and call the tool.""" - tool_calls = response.message.additional_kwargs.get("tool_calls", []) or [] + tool_calls = [ + block + for block in response.message.blocks + if isinstance(block, ToolCallBlock) + ] if len(tool_calls) < 1: if error_on_no_tool_call: raise ValueError( @@ -323,13 +346,13 @@ def get_tool_calls_from_response( tool_selections = [] for tool_call in tool_calls: - argument_dict = tool_call["function"]["arguments"] + argument_dict = cast(Dict[str, Any], tool_call.tool_kwargs) tool_selections.append( ToolSelection( # tool ids not provided by Ollama - tool_id=tool_call["function"]["name"], - tool_name=tool_call["function"]["name"], + tool_id=tool_call.tool_call_id or "", + tool_name=tool_call.tool_name, tool_kwargs=argument_dict, ) ) @@ -357,9 +380,16 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: response = dict(response) - blocks: List[TextBlock | ThinkingBlock] = [] - + blocks: list[TextBlock | ToolCallBlock | ThinkingBlock] = [] tool_calls = response["message"].get("tool_calls", []) or [] + blocks.append(TextBlock(text=response["message"].get("content", ""))) + for tool_call in tool_calls: + blocks.append( + ToolCallBlock( + tool_name=tool_call["function"]["name"], + tool_kwargs=tool_call["function"]["arguments"], + ) + ) thinking = response["message"].get("thinking", None) if thinking: 
blocks.append(ThinkingBlock(content=thinking)) @@ -373,7 +403,6 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: message=ChatMessage( blocks=blocks, role=response["message"].get("role", MessageRole.ASSISTANT), - additional_kwargs={"tool_calls": tool_calls}, ), raw=response, ) @@ -402,8 +431,7 @@ def gen() -> ChatResponseGen: response_txt = "" thinking_txt = "" - seen_tool_calls = set() - all_tool_calls = [] + blocks: List[TextBlock | ToolCallBlock | ThinkingBlock] = [] for r in response: if r["message"]["content"] is None: @@ -415,34 +443,25 @@ def gen() -> ChatResponseGen: thinking_txt += r["message"].get("thinking", "") or "" new_tool_calls = [dict(t) for t in r["message"].get("tool_calls") or []] + blocks.append(TextBlock(text=response_txt)) for tool_call in new_tool_calls: - if ( - str(tool_call["function"]["name"]), - str(tool_call["function"]["arguments"]), - ) in seen_tool_calls: - continue - seen_tool_calls.add( - ( - str(tool_call["function"]["name"]), - str(tool_call["function"]["arguments"]), + blocks.append( + ToolCallBlock( + tool_name=tool_call["function"]["name"], + tool_kwargs=tool_call["function"]["arguments"], ) ) - all_tool_calls.append(tool_call) token_counts = self._get_response_token_counts(r) if token_counts: r["usage"] = token_counts - output_blocks = [TextBlock(text=response_txt)] if thinking_txt: - output_blocks.insert(0, ThinkingBlock(content=thinking_txt)) + blocks.insert(0, ThinkingBlock(content=thinking_txt)) yield ChatResponse( message=ChatMessage( - blocks=output_blocks, + blocks=blocks, role=r["message"].get("role", MessageRole.ASSISTANT), - additional_kwargs={ - "tool_calls": all_tool_calls, - }, ), delta=r["message"].get("content", ""), raw=r, @@ -477,8 +496,7 @@ async def gen() -> ChatResponseAsyncGen: response_txt = "" thinking_txt = "" - seen_tool_calls = set() - all_tool_calls = [] + blocks: List[TextBlock | ToolCallBlock | ThinkingBlock] = [] async for r in response: if r["message"]["content"] is None: @@ -490,34 +508,25 @@ async def gen() -> ChatResponseAsyncGen: thinking_txt += r["message"].get("thinking", "") or "" new_tool_calls = [dict(t) for t in r["message"].get("tool_calls") or []] + blocks.append(TextBlock(text=response_txt)) for tool_call in new_tool_calls: - if ( - str(tool_call["function"]["name"]), - str(tool_call["function"]["arguments"]), - ) in seen_tool_calls: - continue - seen_tool_calls.add( - ( - str(tool_call["function"]["name"]), - str(tool_call["function"]["arguments"]), + blocks.append( + ToolCallBlock( + tool_name=tool_call["function"]["name"], + tool_kwargs=tool_call["function"]["arguments"], ) ) - all_tool_calls.append(tool_call) token_counts = self._get_response_token_counts(r) if token_counts: r["usage"] = token_counts - output_blocks = [TextBlock(text=response_txt)] if thinking_txt: - output_blocks.insert(0, ThinkingBlock(content=thinking_txt)) + blocks.insert(0, ThinkingBlock(content=thinking_txt)) yield ChatResponse( message=ChatMessage( - blocks=output_blocks, + blocks=blocks, role=r["message"].get("role", MessageRole.ASSISTANT), - additional_kwargs={ - "tool_calls": all_tool_calls, - }, ), delta=r["message"].get("content", ""), raw=r, @@ -551,12 +560,19 @@ async def achat( response = dict(response) - blocks: List[TextBlock | ThinkingBlock] = [] - + blocks: list[TextBlock | ToolCallBlock | ThinkingBlock] = [] tool_calls = response["message"].get("tool_calls", []) or [] + blocks.append(TextBlock(text=response["message"].get("content", ""))) + for tool_call in tool_calls: + 
blocks.append( + ToolCallBlock( + tool_name=tool_call["function"]["name"], + tool_kwargs=tool_call["function"]["arguments"], + ) + ) thinking = response["message"].get("thinking", None) if thinking: - blocks.append(ThinkingBlock(content=thinking)) + blocks.insert(0, ThinkingBlock(content=thinking)) blocks.append(TextBlock(text=response["message"].get("content", ""))) token_counts = self._get_response_token_counts(response) if token_counts: @@ -566,7 +582,6 @@ async def achat( message=ChatMessage( blocks=blocks, role=response["message"].get("role", MessageRole.ASSISTANT), - additional_kwargs={"tool_calls": tool_calls}, ), raw=response, ) diff --git a/llama-index-integrations/llms/llama-index-llms-ollama/tests/test_llms_ollama.py b/llama-index-integrations/llms/llama-index-llms-ollama/tests/test_llms_ollama.py index 600bb4888c..bc475aa5be 100644 --- a/llama-index-integrations/llms/llama-index-llms-ollama/tests/test_llms_ollama.py +++ b/llama-index-integrations/llms/llama-index-llms-ollama/tests/test_llms_ollama.py @@ -6,12 +6,13 @@ from llama_index.core.base.llms.types import ThinkingBlock, TextBlock from llama_index.core.base.llms.base import BaseLLM +from llama_index.core.base.llms.types import ToolCallBlock from llama_index.core.bridge.pydantic import BaseModel, Field from llama_index.core.llms import ChatMessage from llama_index.core.tools import FunctionTool from llama_index.llms.ollama import Ollama -test_model = os.environ.get("OLLAMA_TEST_MODEL", "llama3.1:latest") +test_model = os.environ.get("OLLAMA_TEST_MODEL", "qwen3:0.6b") thinking_test_model = os.environ.get("OLLAMA_THINKING_TEST_MODEL", "qwen3:0.6b") try: @@ -330,7 +331,16 @@ def test_chat_with_tools_returns_empty_array_if_no_tools_were_called() -> None: ], ) - assert response.message.additional_kwargs.get("tool_calls", []) == [] + assert ( + len( + [ + block + for block in response.message.blocks + if isinstance(block, ToolCallBlock) + ] + ) + == 0 + ) tool_calls = llm.get_tool_calls_from_response(response, error_on_no_tool_call=False) assert len(tool_calls) == 0 @@ -358,4 +368,13 @@ async def test_async_chat_with_tools_returns_empty_array_if_no_tools_were_called ChatMessage(role="user", content="Hello, how are you?"), ], ) - assert response.message.additional_kwargs.get("tool_calls", []) == [] + assert ( + len( + [ + block + for block in response.message.blocks + if isinstance(block, ToolCallBlock) + ] + ) + == 0 + ) diff --git a/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/base.py b/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/base.py index 91ed29cc00..69a48cd0f8 100644 --- a/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/base.py +++ b/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/base.py @@ -43,6 +43,8 @@ CompletionResponseGen, LLMMetadata, MessageRole, + ToolCallBlock, + TextBlock, ) from llama_index.core.bridge.pydantic import ( Field, @@ -121,9 +123,15 @@ def encode(self, text: str) -> List[int]: # fmt: skip def force_single_tool_call(response: ChatResponse) -> None: - tool_calls = response.message.additional_kwargs.get("tool_calls", []) + tool_calls = [ + block for block in response.message.blocks if isinstance(block, ToolCallBlock) + ] if len(tool_calls) > 1: - response.message.additional_kwargs["tool_calls"] = [tool_calls[0]] + response.message.blocks = [ + block + for block in response.message.blocks + if not isinstance(block, ToolCallBlock) + ] + [tool_calls[0]] class 
diff --git a/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/base.py b/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/base.py
index 91ed29cc00..69a48cd0f8 100644
--- a/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/base.py
@@ -43,6 +43,8 @@
     CompletionResponseGen,
     LLMMetadata,
     MessageRole,
+    ToolCallBlock,
+    TextBlock,
 )
 from llama_index.core.bridge.pydantic import (
     Field,
@@ -121,9 +123,15 @@ def encode(self, text: str) -> List[int]: # fmt: skip


 def force_single_tool_call(response: ChatResponse) -> None:
-    tool_calls = response.message.additional_kwargs.get("tool_calls", [])
+    tool_calls = [
+        block for block in response.message.blocks if isinstance(block, ToolCallBlock)
+    ]
     if len(tool_calls) > 1:
-        response.message.additional_kwargs["tool_calls"] = [tool_calls[0]]
+        response.message.blocks = [
+            block
+            for block in response.message.blocks
+            if not isinstance(block, ToolCallBlock)
+        ] + [tool_calls[0]]


 class OpenAI(FunctionCallingLLM):
@@ -528,6 +536,7 @@ def gen() -> ChatResponseGen:
                 messages=message_dicts,
                 **self._get_model_kwargs(stream=True, **kwargs),
             ):
+                blocks = []
                 response = cast(ChatCompletionChunk, response)
                 if len(response.choices) > 0:
                     delta = response.choices[0].delta
@@ -545,17 +554,27 @@
                 role = delta.role or MessageRole.ASSISTANT
                 content_delta = delta.content or ""
                 content += content_delta
+                blocks.append(TextBlock(text=content))

                 additional_kwargs = {}
                 if is_function:
                     tool_calls = update_tool_calls(tool_calls, delta.tool_calls)
                     if tool_calls:
                         additional_kwargs["tool_calls"] = tool_calls
+                        for tool_call in tool_calls:
+                            if tool_call.function:
+                                blocks.append(
+                                    ToolCallBlock(
+                                        tool_call_id=tool_call.id,
+                                        tool_kwargs=tool_call.function.arguments or {},
+                                        tool_name=tool_call.function.name or "",
+                                    )
+                                )

                 yield ChatResponse(
                     message=ChatMessage(
                         role=role,
-                        content=content,
+                        blocks=blocks,
                         additional_kwargs=additional_kwargs,
                     ),
                     delta=content_delta,
@@ -785,6 +804,7 @@ async def gen() -> ChatResponseAsyncGen:
                 messages=message_dicts,
                 **self._get_model_kwargs(stream=True, **kwargs),
             ):
+                blocks = []
                 response = cast(ChatCompletionChunk, response)
                 if len(response.choices) > 0:
                     # check if the first chunk has neither content nor tool_calls
@@ -812,17 +832,27 @@
                 role = delta.role or MessageRole.ASSISTANT
                 content_delta = delta.content or ""
                 content += content_delta
+                blocks.append(TextBlock(text=content))

                 additional_kwargs = {}
                 if is_function:
                     tool_calls = update_tool_calls(tool_calls, delta.tool_calls)
                     if tool_calls:
                         additional_kwargs["tool_calls"] = tool_calls
+                        for tool_call in tool_calls:
+                            if tool_call.function:
+                                blocks.append(
+                                    ToolCallBlock(
+                                        tool_call_id=tool_call.id,
+                                        tool_kwargs=tool_call.function.arguments or {},
+                                        tool_name=tool_call.function.name or "",
+                                    )
+                                )

                 yield ChatResponse(
                     message=ChatMessage(
                         role=role,
-                        content=content,
+                        blocks=blocks,
                         additional_kwargs=additional_kwargs,
                     ),
                     delta=content_delta,
@@ -960,36 +990,71 @@ def get_tool_calls_from_response(
         **kwargs: Any,
     ) -> List[ToolSelection]:
         """Predict and call the tool."""
-        tool_calls = response.message.additional_kwargs.get("tool_calls", [])
-
-        if len(tool_calls) < 1:
-            if error_on_no_tool_call:
-                raise ValueError(
-                    f"Expected at least one tool call, but got {len(tool_calls)} tool calls."
+        tool_calls = [
+            block
+            for block in response.message.blocks
+            if isinstance(block, ToolCallBlock)
+        ]
+        if tool_calls:
+            if len(tool_calls) < 1:
+                if error_on_no_tool_call:
+                    raise ValueError(
+                        f"Expected at least one tool call, but got {len(tool_calls)} tool calls."
+                    )
+                else:
+                    return []
+
+            tool_selections = []
+            for tool_call in tool_calls:
+                # this should handle both complete and partial jsons
+                try:
+                    if isinstance(tool_call.tool_kwargs, str):
+                        argument_dict = parse_partial_json(tool_call.tool_kwargs)
+                    else:
+                        argument_dict = tool_call.tool_kwargs
+                except (ValueError, TypeError, JSONDecodeError):
+                    argument_dict = {}
+
+                tool_selections.append(
+                    ToolSelection(
+                        tool_id=tool_call.tool_call_id or "",
+                        tool_name=tool_call.tool_name,
+                        tool_kwargs=argument_dict,
+                    )
                 )
-            else:
-                return []
-        tool_selections = []
-        for tool_call in tool_calls:
-            if tool_call.type != "function":
-                raise ValueError("Invalid tool type. Unsupported by OpenAI llm")
+            return tool_selections
+        else:  # keep it backward-compatible
+            tool_calls = response.message.additional_kwargs.get("tool_calls", [])
-
-            # this should handle both complete and partial jsons
-            try:
-                argument_dict = parse_partial_json(tool_call.function.arguments)
-            except (ValueError, TypeError, JSONDecodeError):
-                argument_dict = {}
-
-            tool_selections.append(
-                ToolSelection(
-                    tool_id=tool_call.id,
-                    tool_name=tool_call.function.name,
-                    tool_kwargs=argument_dict,
+            if len(tool_calls) < 1:
+                if error_on_no_tool_call:
+                    raise ValueError(
+                        f"Expected at least one tool call, but got {len(tool_calls)} tool calls."
+                    )
+                else:
+                    return []
+
+            tool_selections = []
+            for tool_call in tool_calls:
+                if tool_call.type != "function":
+                    raise ValueError("Invalid tool type. Unsupported by OpenAI llm")
+
+                # this should handle both complete and partial jsons
+                try:
+                    argument_dict = parse_partial_json(tool_call.function.arguments)
+                except (ValueError, TypeError, JSONDecodeError):
+                    argument_dict = {}
+
+                tool_selections.append(
+                    ToolSelection(
+                        tool_id=tool_call.id,
+                        tool_name=tool_call.function.name,
+                        tool_kwargs=argument_dict,
+                    )
                 )
-            )
-        return tool_selections
+            return tool_selections

     def _prepare_schema(
         self, llm_kwargs: Optional[Dict[str, Any]], output_cls: Type[Model]
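Reviewer note (illustration only, not part of the patch): the new block-based branch above accepts tool_kwargs either as a dict or as a (possibly partial) JSON string. The sketch below mirrors that normalization; json.loads stands in for the integration's parse_partial_json helper, and the function name is the editor's own:

import json
from typing import Union


def normalize_tool_kwargs(tool_kwargs: Union[str, dict]) -> dict:
    """Return a kwargs dict whether the arguments arrived as a dict or a JSON string."""
    if isinstance(tool_kwargs, dict):
        return tool_kwargs
    try:
        return json.loads(tool_kwargs)
    except (ValueError, TypeError):
        return {}


print(normalize_tool_kwargs('{"location": "Paris"}'))  # {'location': 'Paris'}
print(normalize_tool_kwargs({"location": "Paris"}))    # {'location': 'Paris'}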
diff --git a/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/responses.py b/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/responses.py
index d126b5af1a..1ebf7b761a 100644
--- a/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/responses.py
+++ b/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/responses.py
@@ -44,6 +44,7 @@
     Type,
     Union,
     runtime_checkable,
+    cast,
 )

 import llama_index.core.instrumentation as instrument
@@ -67,6 +68,7 @@
     TextBlock,
     ImageBlock,
     ThinkingBlock,
+    ToolCallBlock,
 )
 from llama_index.core.bridge.pydantic import (
     Field,
@@ -131,9 +133,15 @@ def encode(self, text: str) -> List[int]: # fmt: skip


 def force_single_tool_call(response: ChatResponse) -> None:
-    tool_calls = response.message.additional_kwargs.get("tool_calls", [])
+    tool_calls = [
+        block for block in response.message.blocks if isinstance(block, ToolCallBlock)
+    ]
     if len(tool_calls) > 1:
-        response.message.additional_kwargs["tool_calls"] = [tool_calls[0]]
+        response.message.blocks = [
+            block
+            for block in response.message.blocks
+            if not isinstance(block, ToolCallBlock)
+        ] + [tool_calls[0]]


 class OpenAIResponses(FunctionCallingLLM):
@@ -454,7 +462,6 @@ def stream_complete(
     def _parse_response_output(output: List[ResponseOutputItem]) -> ChatResponse:
         message = ChatMessage(role=MessageRole.ASSISTANT, blocks=[])
         additional_kwargs = {"built_in_tool_calls": []}
-        tool_calls = []
         blocks: List[ContentBlock] = []
         for item in output:
             if isinstance(item, ResponseOutputMessage):
@@ -481,7 +488,13 @@ def _parse_response_output(output: List[ResponseOutputItem]) -> ChatResponse:
             elif isinstance(item, ResponseFileSearchToolCall):
                 additional_kwargs["built_in_tool_calls"].append(item)
             elif isinstance(item, ResponseFunctionToolCall):
-                tool_calls.append(item)
+                message.blocks.append(
+                    ToolCallBlock(
+                        tool_name=item.name,
+                        tool_call_id=item.call_id,
+                        tool_kwargs=item.arguments,
+                    )
+                )
             elif isinstance(item, ResponseFunctionWebSearch):
                 additional_kwargs["built_in_tool_calls"].append(item)
             elif isinstance(item, ResponseComputerToolCall):
@@ -504,9 +517,6 @@ def _parse_response_output(output: List[ResponseOutputItem]) -> ChatResponse:
                 )
             )

-        if tool_calls and message:
-            message.additional_kwargs["tool_calls"] = tool_calls
-
         return ChatResponse(message=message, additional_kwargs=additional_kwargs)

     @llm_retry_decorator
@@ -542,7 +552,6 @@ def _chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
     @staticmethod
     def process_response_event(
         event: ResponseStreamEvent,
-        tool_calls: List[ResponseFunctionToolCall],
         built_in_tool_calls: List[Any],
         additional_kwargs: Dict[str, Any],
         current_tool_call: Optional[ResponseFunctionToolCall],
@@ -550,7 +559,6 @@ def process_response_event(
         previous_response_id: Optional[str] = None,
     ) -> Tuple[
         List[ContentBlock],
-        List[ResponseFunctionToolCall],
         List[Any],
         Dict[str, Any],
         Optional[ResponseFunctionToolCall],
@@ -591,6 +599,7 @@
         elif isinstance(event, ResponseTextDeltaEvent):
             # Text content is being added
             delta = event.delta
+            blocks.append(TextBlock(text=delta))
         elif isinstance(event, ResponseImageGenCallPartialImageEvent):
             # Partial image
             if event.partial_image_b64:
@@ -610,9 +619,12 @@
                 current_tool_call.arguments = event.arguments
                 current_tool_call.status = "completed"

-                # append a copy of the tool call to the list
-                tool_calls.append(
-                    ResponseFunctionToolCall(**current_tool_call.model_dump())
+                blocks.append(
+                    ToolCallBlock(
+                        tool_name=current_tool_call.name,
+                        tool_kwargs=current_tool_call.arguments,
+                        tool_call_id=current_tool_call.call_id,
+                    )
                 )

                 # clear the current tool call
@@ -658,7 +670,6 @@
         return (
             blocks,
-            tool_calls,
             built_in_tool_calls,
             additional_kwargs,
             current_tool_call,
@@ -677,7 +688,6 @@ def _stream_chat(
         )

         def gen() -> ChatResponseGen:
-            tool_calls = []
             built_in_tool_calls = []
             additional_kwargs = {"built_in_tool_calls": []}
             current_tool_call: Optional[ResponseFunctionToolCall] = None
@@ -691,7 +701,6 @@ def gen() -> ChatResponseGen:
                 # Process the event and update state
                 (
                     blocks,
-                    tool_calls,
                     built_in_tool_calls,
                     additional_kwargs,
                     current_tool_call,
@@ -699,7 +708,6 @@ def gen() -> ChatResponseGen:
                     delta,
                 ) = OpenAIResponses.process_response_event(
                     event=event,
-                    tool_calls=tool_calls,
                     built_in_tool_calls=built_in_tool_calls,
                     additional_kwargs=additional_kwargs,
                     current_tool_call=current_tool_call,
@@ -721,9 +729,6 @@ def gen() -> ChatResponseGen:
                     message=ChatMessage(
                         role=MessageRole.ASSISTANT,
                         blocks=blocks,
-                        additional_kwargs={"tool_calls": tool_calls}
-                        if tool_calls
-                        else {},
                     ),
                     delta=delta,
                     raw=event,
@@ -801,7 +806,6 @@ async def _astream_chat(
         )

         async def gen() -> ChatResponseAsyncGen:
-            tool_calls = []
             built_in_tool_calls = []
             additional_kwargs = {"built_in_tool_calls": []}
             current_tool_call: Optional[ResponseFunctionToolCall] = None
@@ -817,7 +821,6 @@ async def gen() -> ChatResponseAsyncGen:
                 # Process the event and update state
                 (
                     blocks,
-                    tool_calls,
                     built_in_tool_calls,
                     additional_kwargs,
                     current_tool_call,
@@ -825,7 +828,6 @@ async def gen() -> ChatResponseAsyncGen:
                     delta,
                 ) = OpenAIResponses.process_response_event(
                     event=event,
-                    tool_calls=tool_calls,
                     built_in_tool_calls=built_in_tool_calls,
                     additional_kwargs=additional_kwargs,
                     current_tool_call=current_tool_call,
@@ -847,9 +849,6 @@ async def gen() -> ChatResponseAsyncGen:
                     message=ChatMessage(
                         role=MessageRole.ASSISTANT,
                         blocks=blocks,
-                        additional_kwargs={"tool_calls": tool_calls}
-                        if tool_calls
-                        else {},
                     ),
                     delta=delta,
                     raw=event,
@@ -915,9 +914,11 @@ def get_tool_calls_from_response(
         **kwargs: Any,
     ) -> List[ToolSelection]:
         """Predict and call the tool."""
-        tool_calls: List[ResponseFunctionToolCall] = (
-            response.message.additional_kwargs.get("tool_calls", [])
-        )
+        tool_calls: List[ToolCallBlock] = [
+            block
+            for block in response.message.blocks
+            if isinstance(block, ToolCallBlock)
+        ]

         if len(tool_calls) < 1:
             if error_on_no_tool_call:
@@ -931,14 +932,14 @@ def get_tool_calls_from_response(
         for tool_call in tool_calls:
             # this should handle both complete and partial jsons
             try:
-                argument_dict = parse_partial_json(tool_call.arguments)
-            except ValueError:
+                argument_dict = parse_partial_json(cast(str, tool_call.tool_kwargs))
+            except Exception:
                 argument_dict = {}

             tool_selections.append(
                 ToolSelection(
-                    tool_id=tool_call.call_id,
-                    tool_name=tool_call.name,
+                    tool_id=tool_call.tool_call_id or "",
+                    tool_name=tool_call.tool_name,
                     tool_kwargs=argument_dict,
                 )
             )
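Reviewer note (illustration only, not part of the patch): in the Responses streaming path above, a completed function call is now surfaced as a ToolCallBlock whose tool_kwargs holds the raw arguments string. A small sketch of what such a block looks like, assuming this patch is applied; the tool name, call id, and arguments are made up:

from llama_index.core.base.llms.types import ToolCallBlock

# Mirrors what process_response_event() appends once a function call's arguments are complete.
block = ToolCallBlock(
    tool_name="get_weather",
    tool_call_id="call_123",
    tool_kwargs='{"city": "Berlin"}',  # the Responses API delivers arguments as a JSON string
)
assert block.block_type == "tool_call"
assert isinstance(block.tool_kwargs, str)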
diff --git a/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/utils.py b/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/utils.py
index c7536bda56..f0d78a3cae 100644
--- a/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/utils.py
+++ b/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/utils.py
@@ -30,6 +30,8 @@
     AudioBlock,
     DocumentBlock,
     ThinkingBlock,
+    ToolCallBlock,
+    ContentBlock,
 )

 from llama_index.core.bridge.pydantic import BaseModel
@@ -398,6 +400,14 @@ def to_openai_message_dict(
                 },
             }
         )
+    elif isinstance(block, ToolCallBlock):
+        try:
+            content.append({"type": "text", "text": block.model_dump_json()})
+        except Exception:
+            logger.warning(
+                f"It was not possible to convert ToolCallBlock with call id {block.tool_call_id or '`no call id`'} to a valid message, skipping..."
+            )
+            continue
     else:
         msg = f"Unsupported content block type: {type(block).__name__}"
         raise ValueError(msg)
@@ -515,6 +525,14 @@ def to_openai_responses_message_dict(
             if block.content:
                 content.append({"type": "output_text", "text": block.content})
                 content_txt += block.content
+        elif isinstance(block, ToolCallBlock):
+            try:
+                content.append({"type": "output_text", "text": block.model_dump_json()})
+            except Exception:
+                logger.warning(
+                    f"It was not possible to convert ToolCallBlock with call id {block.tool_call_id or '`no call id`'} to a valid message, skipping..."
+                )
+                continue
         else:
             msg = f"Unsupported content block type: {type(block).__name__}"
             raise ValueError(msg)
@@ -553,13 +571,6 @@ def to_openai_responses_message_dict(
         }
         return message_dict
-    elif "tool_calls" in message.additional_kwargs:
-        message_dicts = [
-            tool_call if isinstance(tool_call, dict) else tool_call.model_dump()
-            for tool_call in message.additional_kwargs["tool_calls"]
-        ]
-
-        return message_dicts

     # there are some cases (like image generation or MCP tool call) that only support the string input
     # this is why, if context_txt is a non-empty string, all the blocks are TextBlocks and the role is user, we return directly context_txt
@@ -648,13 +659,22 @@ def from_openai_message(
     role = openai_message.role
     # NOTE: Azure OpenAI returns function calling messages without a content key
     if "text" in modalities and openai_message.content:
-        blocks = [TextBlock(text=openai_message.content or "")]
+        blocks: List[ContentBlock] = [TextBlock(text=openai_message.content or "")]
     else:
-        blocks = []
+        blocks: List[ContentBlock] = []

     additional_kwargs: Dict[str, Any] = {}
     if openai_message.tool_calls:
         tool_calls: List[ChatCompletionMessageToolCall] = openai_message.tool_calls
+        for tool_call in tool_calls:
+            if tool_call.function:
+                blocks.append(
+                    ToolCallBlock(
+                        tool_call_id=tool_call.id,
+                        tool_name=tool_call.function.name or "",
+                        tool_kwargs=tool_call.function.arguments or {},
+                    )
+                )
         additional_kwargs.update(tool_calls=tool_calls)

     if openai_message.audio and "audio" in modalities:
@@ -742,6 +762,14 @@ def from_openai_message_dict(message_dict: dict) -> ChatMessage:
                     blocks.append(ImageBlock(image=img, detail=detail))
                 else:
                     blocks.append(ImageBlock(url=img, detail=detail))
+            elif t == "function_call":
+                blocks.append(
+                    ToolCallBlock(
+                        tool_call_id=elem.get("call_id"),
+                        tool_name=elem.get("name", ""),
+                        tool_kwargs=elem.get("arguments", {}),
+                    )
+                )
             else:
                 msg = f"Unsupported message type: {t}"
                 raise ValueError(msg)
diff --git a/llama-index-integrations/llms/llama-index-llms-openai/tests/test_llms_openai.py b/llama-index-integrations/llms/llama-index-llms-openai/tests/test_llms_openai.py
index 6b0adbd564..5e39ad1e70 100644
--- a/llama-index-integrations/llms/llama-index-llms-openai/tests/test_llms_openai.py
+++ b/llama-index-integrations/llms/llama-index-llms-openai/tests/test_llms_openai.py
@@ -4,6 +4,7 @@
 import pytest
 from llama_index.llms.openai import OpenAI
 from llama_index.llms.openai.utils import resolve_tool_choice
+from llama_index.core.base.llms.types import ToolCallBlock


 def test_text_inference_embedding_class():
@@ -165,6 +166,16 @@ def test_tool_required():
     )
     print(repr(response))
     assert len(response.message.additional_kwargs["tool_calls"]) == 1
+    assert (
+        len(
+            [
+                block
+                for block in response.message.blocks
+                if isinstance(block, ToolCallBlock)
+            ]
+        )
+        == 1
+    )


 @pytest.mark.skipif(
diff --git a/llama-index-integrations/llms/llama-index-llms-openai/tests/test_openai.py b/llama-index-integrations/llms/llama-index-llms-openai/tests/test_openai.py
index 184728a8d8..9c7fd30eed 100644
--- a/llama-index-integrations/llms/llama-index-llms-openai/tests/test_openai.py
+++ b/llama-index-integrations/llms/llama-index-llms-openai/tests/test_openai.py
@@ -369,7 +369,7 @@ def test_chat_model_streaming(MockSyncOpenAI: MagicMock) -> None:
     )
     chat_response_gen = llm.stream_chat([message])
     chat_responses = list(chat_response_gen)
-    assert chat_responses[-1].message.content == "\n\n2"
+    assert chat_responses[-1].message.blocks[-1].text == "\n\n2"
     assert chat_responses[-1].message.role == "assistant"
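Reviewer note (illustration only, not part of the patch): the utils changes above serialize a ToolCallBlock into a plain text content part via Pydantic's model_dump_json when converting messages to OpenAI dicts. A sketch of the resulting part, assuming this patch is applied; the field values are hypothetical:

from llama_index.core.base.llms.types import ToolCallBlock

block = ToolCallBlock(
    tool_call_id="1",
    tool_name="search_hotels",
    tool_kwargs={"location": "San Diego"},
)
# The block is embedded as serialized JSON text, matching the new branches above.
content_part = {"type": "text", "text": block.model_dump_json()}
print(content_part["text"])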
diff --git a/llama-index-integrations/llms/llama-index-llms-openai/tests/test_openai_responses.py b/llama-index-integrations/llms/llama-index-llms-openai/tests/test_openai_responses.py
index 48ef4e5c37..b6720160e4 100644
--- a/llama-index-integrations/llms/llama-index-llms-openai/tests/test_openai_responses.py
+++ b/llama-index-integrations/llms/llama-index-llms-openai/tests/test_openai_responses.py
@@ -12,6 +12,7 @@
     DocumentBlock,
     ChatResponse,
     ThinkingBlock,
+    ToolCallBlock,
 )
 from llama_index.llms.openai.responses import OpenAIResponses, ResponseFunctionToolCall
 from llama_index.llms.openai.utils import to_openai_message_dicts
@@ -127,7 +128,6 @@ def test_parse_response_output():
 def test_process_response_event():
     """Test the static process_response_event method for streaming responses."""
     # Initial state
-    tool_calls = []
     built_in_tool_calls = []
     additional_kwargs = {}
     current_tool_call = None
@@ -145,17 +145,15 @@ def test_process_response_event():

     result = OpenAIResponses.process_response_event(
         event=event,
-        tool_calls=tool_calls,
         built_in_tool_calls=built_in_tool_calls,
         additional_kwargs=additional_kwargs,
         current_tool_call=current_tool_call,
         track_previous_responses=False,
     )

-    updated_blocks, updated_tool_calls, _, _, _, _, delta = result
-    assert updated_blocks == []
+    updated_blocks, _, _, _, _, delta = result
+    assert updated_blocks == [TextBlock(text="Hello")]
     assert delta == "Hello"
-    assert updated_tool_calls == []

     event = ResponseOutputItemDoneEvent(
         item=ResponseReasoningItem(
@@ -176,14 +174,13 @@ def test_process_response_event():

     result = OpenAIResponses.process_response_event(
         event=event,
-        tool_calls=tool_calls,
         built_in_tool_calls=built_in_tool_calls,
         additional_kwargs=additional_kwargs,
         current_tool_call=current_tool_call,
         track_previous_responses=False,
     )

-    updated_blocks, _, _, _, _, _, _ = result
+    updated_blocks, _, _, _, _, _ = result
     assert updated_blocks == [
         ThinkingBlock(
             block_type="thinking",
@@ -209,7 +206,6 @@ def test_process_response_event():
     )

     event = ResponseFunctionCallArgumentsDeltaEvent(
-        content_index=0,
         item_id="123",
         output_index=0,
         type="response.function_call_arguments.delta",
@@ -219,14 +215,13 @@ def test_process_response_event():

     result = OpenAIResponses.process_response_event(
         event=event,
-        tool_calls=updated_tool_calls,
         built_in_tool_calls=built_in_tool_calls,
         additional_kwargs=additional_kwargs,
         current_tool_call=current_tool_call,
         track_previous_responses=False,
     )

-    _, _, _, _, updated_call, _, _ = result
+    _, _, _, updated_call, _, _ = result
     assert updated_call.arguments == '{"arg": "value"'

     # Test function call arguments done
@@ -240,23 +235,25 @@ def test_process_response_event():

     result = OpenAIResponses.process_response_event(
         event=event,
-        tool_calls=updated_tool_calls,
         built_in_tool_calls=built_in_tool_calls,
         additional_kwargs=additional_kwargs,
         current_tool_call=updated_call,
         track_previous_responses=False,
     )

-    _, completed_tool_calls, _, _, final_current_call, _, _ = result
+    final_blocks, _, _, final_current_call, _, _ = result
+    completed_tool_calls = [
+        block for block in final_blocks if isinstance(block, ToolCallBlock)
+    ]
     assert len(completed_tool_calls) == 1
-    assert completed_tool_calls[0].arguments == '{"arg": "value"}'
-    assert completed_tool_calls[0].status == "completed"
+    assert completed_tool_calls[0].tool_kwargs == '{"arg": "value"}'
+    assert completed_tool_calls[0].tool_call_id == "123"
+    assert completed_tool_calls[0].tool_name == "test_function"
     assert final_current_call is None


 def test_process_response_event_with_text_annotation():
     """Test process_response_event handles ResponseOutputTextAnnotationAddedEvent."""
-    tool_calls = []
     built_in_tool_calls = []
     additional_kwargs = {}
     current_tool_call = None
@@ -274,7 +271,6 @@ def test_process_response_event_with_text_annotation():

     result = OpenAIResponses.process_response_event(
         event=event,
-        tool_calls=tool_calls,
         built_in_tool_calls=built_in_tool_calls,
         additional_kwargs=additional_kwargs,
         current_tool_call=current_tool_call,
@@ -282,7 +278,7 @@ def test_process_response_event_with_text_annotation():
     )

     # The annotation should be added to additional_kwargs["annotations"]
-    _, _, _, updated_additional_kwargs, _, _, _ = result
+    _, _, updated_additional_kwargs, _, _, _ = result
     assert "annotations" in updated_additional_kwargs
     assert updated_additional_kwargs["annotations"] == [
         {"type": "test_annotation", "value": 42}
@@ -291,18 +287,15 @@ def test_process_response_event_with_text_annotation():


 def test_get_tool_calls_from_response():
     """Test extracting tool calls from a chat response."""
-    tool_call = ResponseFunctionToolCall(
-        id="call_123",
-        call_id="123",
-        type="function_call",
-        name="test_function",
-        arguments='{"arg1": "value1", "arg2": 42}',
-        status="completed",
-    )
-
     # Create a mock chat response with tool calls
     chat_response = MagicMock()
-    chat_response.message.additional_kwargs = {"tool_calls": [tool_call]}
+    chat_response.message.blocks = [
+        ToolCallBlock(
+            tool_call_id="123",
+            tool_name="test_function",
+            tool_kwargs='{"arg1": "value1", "arg2": 42}',
+        )
+    ]

     with (
         patch("llama_index.llms.openai.responses.SyncOpenAI"),
@@ -606,15 +599,44 @@ def test_tool_required():
         tools=[search_tool],
         tool_required=True,
     )
-    assert len(response.message.additional_kwargs["tool_calls"]) == 1
+    assert (
+        len(
+            [
+                block
+                for block in response.message.blocks
+                if isinstance(block, ToolCallBlock)
+            ]
+        )
+        == 1
+    )


 def test_messages_to_openai_responses_messages():
     messages = [
         ChatMessage(role=MessageRole.SYSTEM, content="You are a helpful assistant."),
         ChatMessage(role=MessageRole.USER, content="What is the capital of France?"),
+        ChatMessage(
+            role=MessageRole.ASSISTANT,
+            blocks=[
+                ToolCallBlock(
+                    tool_call_id="1",
+                    tool_name="get_capital_city_by_state",
+                    tool_kwargs="{'state': 'France'}",
+                )
+            ],
+        ),
         ChatMessage(role=MessageRole.ASSISTANT, content="Paris"),
         ChatMessage(role=MessageRole.USER, content="What is the capital of Germany?"),
+        ChatMessage(
+            role=MessageRole.ASSISTANT,
+            blocks=[
+                ToolCallBlock(
+                    tool_call_id="2",
+                    tool_name="get_capital_city_by_state",
+                    tool_kwargs="{'state': 'Germany'}",
+                )
+            ],
+        ),
         ChatMessage(
             role=MessageRole.ASSISTANT,
             blocks=[
@@ -626,19 +648,37 @@ def test_messages_to_openai_responses_messages():
         ),
     ]
     openai_messages = to_openai_message_dicts(messages, is_responses_api=True)
-    assert len(openai_messages) == 5
+    assert len(openai_messages) == 7
     assert openai_messages[0]["role"] == "developer"
     assert openai_messages[0]["content"] == "You are a helpful assistant."
     assert openai_messages[1]["role"] == "user"
     assert openai_messages[1]["content"] == "What is the capital of France?"
     assert openai_messages[2]["role"] == "assistant"
-    assert openai_messages[2]["content"] == "Paris"
-    assert openai_messages[3]["role"] == "user"
-    assert openai_messages[3]["content"] == "What is the capital of Germany?"
-    assert openai_messages[4]["role"] == "assistant"
-    assert len(openai_messages[4]["content"]) == 2
-    assert openai_messages[4]["content"][0]["text"] == messages[4].blocks[0].content
-    assert openai_messages[4]["content"][1]["text"] == messages[4].blocks[1].text
+    assert (
+        openai_messages[2]["content"][0]["text"]
+        == ToolCallBlock(
+            tool_call_id="1",
+            tool_name="get_capital_city_by_state",
+            tool_kwargs="{'state': 'France'}",
+        ).model_dump_json()
+    )
+    assert openai_messages[3]["role"] == "assistant"
+    assert openai_messages[3]["content"] == "Paris"
+    assert openai_messages[4]["role"] == "user"
+    assert openai_messages[4]["content"] == "What is the capital of Germany?"
+    assert openai_messages[5]["role"] == "assistant"
+    assert (
+        openai_messages[5]["content"][0]["text"]
+        == ToolCallBlock(
+            tool_call_id="2",
+            tool_name="get_capital_city_by_state",
+            tool_kwargs="{'state': 'Germany'}",
+        ).model_dump_json()
+    )
+    assert openai_messages[6]["role"] == "assistant"
+    assert len(openai_messages[6]["content"]) == 2
+    assert openai_messages[6]["content"][0]["text"] == messages[6].blocks[0].content
+    assert openai_messages[6]["content"][1]["text"] == messages[6].blocks[1].text


 @pytest.fixture()
@@ -682,6 +722,13 @@ def response_output() -> List[ResponseOutputItem]:
             encrypted_content=None,
             status=None,
         ),
+        ResponseFunctionToolCall(
+            arguments="{'hello': 'world'}",
+            call_id="1",
+            name="test",
+            type="function_call",
+            status="completed",
+        ),
         ResponseOutputMessage(
             id="1",
             content=[
@@ -715,6 +762,22 @@ def test__parse_response_output(response_output: List[ResponseOutputItem]):
         len([block for block in result.message.blocks if isinstance(block, TextBlock)])
         == 1
     )
+    assert (
+        len(
+            [
+                block
+                for block in result.message.blocks
+                if isinstance(block, ToolCallBlock)
+            ]
+        )
+        == 1
+    )
+    tool_call = [
+        block for block in result.message.blocks if isinstance(block, ToolCallBlock)
+    ][0]
+    assert tool_call.tool_call_id == "1"
+    assert tool_call.tool_name == "test"
+    assert tool_call.tool_kwargs == "{'hello': 'world'}"
     assert [
         block for block in result.message.blocks if isinstance(block, ThinkingBlock)
     ][0].content == "hello world\nthis is a test"
diff --git a/llama-index-integrations/llms/llama-index-llms-openai/tests/test_openai_utils.py b/llama-index-integrations/llms/llama-index-llms-openai/tests/test_openai_utils.py
index e0da0e3021..09784cd2d0 100644
--- a/llama-index-integrations/llms/llama-index-llms-openai/tests/test_openai_utils.py
+++ b/llama-index-integrations/llms/llama-index-llms-openai/tests/test_openai_utils.py
@@ -20,6 +20,7 @@
     LogProb,
     MessageRole,
     TextBlock,
+    ToolCallBlock,
 )
 from llama_index.core.bridge.pydantic import BaseModel
 from llama_index.llms.openai import OpenAI
@@ -117,7 +118,14 @@ def azure_chat_messages_with_function_calling() -> List[ChatMessage]:
     return [
         ChatMessage(
             role=MessageRole.ASSISTANT,
-            content=None,
+            blocks=[
+                ToolCallBlock(
+                    block_type="tool_call",
+                    tool_call_id="0123",
+                    tool_name="search_hotels",
+                    tool_kwargs='{\n "location": "San Diego",\n "max_price": 300,\n "features": "beachfront,free breakfast"\n}',
+                )
+            ],
             additional_kwargs={
                 "tool_calls": [
                     ChatCompletionMessageToolCall(