diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java
index b0fffcf3327..5f8da416109 100644
--- a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java
+++ b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java
@@ -23,6 +23,7 @@
 
 import org.springframework.ai.chat.model.ToolContext;
 import org.springframework.ai.model.ModelOptionsUtils;
+import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
 import org.springframework.ai.tool.ToolCallback;
 import org.springframework.ai.tool.definition.DefaultToolDefinition;
 import org.springframework.ai.tool.definition.ToolDefinition;
@@ -120,7 +121,7 @@ public String call(String functionInput) {
 					new IllegalStateException("Error calling tool: " + response.content()));
 			}
 			return ModelOptionsUtils.toJsonString(response.content());
-		}).block();
+		}).contextWrite(ctx -> ctx.putAll(ToolCallReactiveContextHolder.getContext())).block();
 	}
 
 	@Override
diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java
index 270f3bef43d..6456120f325 100644
--- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java
+++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java
@@ -64,6 +64,7 @@
 import org.springframework.ai.content.Media;
 import org.springframework.ai.model.ModelOptionsUtils;
 import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate;
+import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
 import org.springframework.ai.model.tool.ToolCallingChatOptions;
 import org.springframework.ai.model.tool.ToolCallingManager;
 import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
@@ -263,8 +264,14 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
 				if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse) && chatResponse.hasFinishReasons(Set.of("tool_use"))) {
 					// FIXME: bounded elastic needs to be used since tool calling
 					// is currently only synchronous
-					return Flux.defer(() -> {
-						var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
+					return Flux.deferContextual((ctx) -> {
+						ToolExecutionResult toolExecutionResult;
+						try {
+							ToolCallReactiveContextHolder.setContext(ctx);
+							toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
+						} finally {
+							ToolCallReactiveContextHolder.clearContext();
+						}
 						if (toolExecutionResult.returnDirect()) {
 							// Return tool execution result directly to the client.
 							return Flux.just(ChatResponse.builder().from(chatResponse)
diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java
index 1933f575300..3f659671c4d 100644
--- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java
+++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java
@@ -95,6 +95,7 @@
 import org.springframework.ai.model.tool.ToolCallingManager;
 import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
 import org.springframework.ai.model.tool.ToolExecutionResult;
+import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
 import org.springframework.ai.observation.conventions.AiProvider;
 import org.springframework.ai.support.UsageCalculator;
 import org.springframework.ai.tool.definition.ToolDefinition;
@@ -380,8 +381,15 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
 				if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse)) {
 					// FIXME: bounded elastic needs to be used since tool calling
 					// is currently only synchronous
-					return Flux.defer(() -> {
-						var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
+					return Flux.deferContextual((ctx) -> {
+						ToolExecutionResult toolExecutionResult;
+						try {
+							ToolCallReactiveContextHolder.setContext(ctx);
+							toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
+						}
+						finally {
+							ToolCallReactiveContextHolder.clearContext();
+						}
 						if (toolExecutionResult.returnDirect()) {
 							// Return tool execution result directly to the client.
 							return Flux.just(ChatResponse.builder()
diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java
index d30f2517756..484e979385e 100644
--- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java
+++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java
@@ -101,6 +101,7 @@
 import org.springframework.ai.model.tool.ToolCallingManager;
 import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
 import org.springframework.ai.model.tool.ToolExecutionResult;
+import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
 import org.springframework.ai.observation.conventions.AiProvider;
 import org.springframework.ai.tool.definition.ToolDefinition;
 import org.springframework.util.Assert;
@@ -681,8 +682,15 @@ private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse perviousCh
 
 					// FIXME: bounded elastic needs to be used since tool calling
 					// is currently only synchronous
-					return Flux.defer(() -> {
-						var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
+					return Flux.deferContextual((ctx) -> {
+						ToolExecutionResult toolExecutionResult;
+						try {
+							ToolCallReactiveContextHolder.setContext(ctx);
+							toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
+						}
+						finally {
+							ToolCallReactiveContextHolder.clearContext();
+						}
 
 						if (toolExecutionResult.returnDirect()) {
 							// Return tool execution result directly to the client.
diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java
index 4b7607c6e38..6295666e07f 100644
--- a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java
+++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java
@@ -62,6 +62,7 @@
 import org.springframework.ai.model.tool.ToolCallingManager;
 import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
 import org.springframework.ai.model.tool.ToolExecutionResult;
+import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
 import org.springframework.ai.retry.RetryUtils;
 import org.springframework.ai.support.UsageCalculator;
 import org.springframework.ai.tool.definition.ToolDefinition;
@@ -286,10 +287,16 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
 			// @formatter:off
 			Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
 				if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
-					return Flux.defer(() -> {
-						// FIXME: bounded elastic needs to be used since tool calling
-						// is currently only synchronous
-						var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
+					// FIXME: bounded elastic needs to be used since tool calling
+					// is currently only synchronous
+					return Flux.deferContextual((ctx) -> {
+						ToolExecutionResult toolExecutionResult;
+						try {
+							ToolCallReactiveContextHolder.setContext(ctx);
+							toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
+						} finally {
+							ToolCallReactiveContextHolder.clearContext();
+						}
 						if (toolExecutionResult.returnDirect()) {
 							// Return tool execution result directly to the client.
 							return Flux.just(ChatResponse.builder().from(response)
diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java
index e5a774cacf9..19f821b7fb3 100644
--- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java
+++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java
@@ -65,6 +65,7 @@
 import org.springframework.ai.model.tool.ToolCallingManager;
 import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
 import org.springframework.ai.model.tool.ToolExecutionResult;
+import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
 import org.springframework.ai.retry.RetryUtils;
 import org.springframework.ai.tool.definition.ToolDefinition;
 import org.springframework.http.ResponseEntity;
@@ -370,10 +371,16 @@ public Flux<ChatResponse> stream(Prompt prompt) {
 
 			Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
 				if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(requestPrompt.getOptions(), response)) {
-					return Flux.defer(() -> {
-						// FIXME: bounded elastic needs to be used since tool calling
-						// is currently only synchronous
-						var toolExecutionResult = this.toolCallingManager.executeToolCalls(requestPrompt, response);
+					// FIXME: bounded elastic needs to be used since tool calling
+					// is currently only synchronous
+					return Flux.deferContextual((ctx) -> {
+						ToolExecutionResult toolExecutionResult;
+						try {
+							ToolCallReactiveContextHolder.setContext(ctx);
+							toolExecutionResult = this.toolCallingManager.executeToolCalls(requestPrompt, response);
+						} finally {
+							ToolCallReactiveContextHolder.clearContext();
+						}
 						if (toolExecutionResult.returnDirect()) {
 							// Return tool execution result directly to the client.
 							return Flux.just(ChatResponse.builder().from(response)
diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java
index b9838dcedf1..a11548e45b6 100644
--- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java
+++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java
@@ -64,6 +64,7 @@
 import org.springframework.ai.model.tool.ToolCallingManager;
 import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
 import org.springframework.ai.model.tool.ToolExecutionResult;
+import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
 import org.springframework.ai.retry.RetryUtils;
 import org.springframework.ai.support.UsageCalculator;
 import org.springframework.ai.tool.definition.ToolDefinition;
@@ -316,8 +317,14 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
 				if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
 					// FIXME: bounded elastic needs to be used since tool calling
 					// is currently only synchronous
-					return Flux.defer(() -> {
-						var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
+					return Flux.deferContextual((ctx) -> {
+						ToolExecutionResult toolExecutionResult;
+						try {
+							ToolCallReactiveContextHolder.setContext(ctx);
+							toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
+						} finally {
+							ToolCallReactiveContextHolder.clearContext();
+						}
 						if (toolExecutionResult.returnDirect()) {
 							// Return tool execution result directly to the client.
 							return Flux.just(ChatResponse.builder().from(response)
diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java
index 44dc45347b6..c6bd6c2676e 100644
--- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java
+++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java
@@ -54,6 +54,7 @@
 import org.springframework.ai.model.tool.ToolCallingManager;
 import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
 import org.springframework.ai.model.tool.ToolExecutionResult;
+import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
 import org.springframework.ai.ollama.api.OllamaApi;
 import org.springframework.ai.ollama.api.OllamaApi.ChatRequest;
 import org.springframework.ai.ollama.api.OllamaApi.Message.Role;
@@ -351,8 +352,14 @@ private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCh
 				if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
 					// FIXME: bounded elastic needs to be used since tool calling
 					// is currently only synchronous
-					return Flux.defer(() -> {
-						var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
+					return Flux.deferContextual((ctx) -> {
+						ToolExecutionResult toolExecutionResult;
+						try {
+							ToolCallReactiveContextHolder.setContext(ctx);
+							toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
+						} finally {
+							ToolCallReactiveContextHolder.clearContext();
+						}
 						if (toolExecutionResult.returnDirect()) {
 							// Return tool execution result directly to the client.
 							return Flux.just(ChatResponse.builder().from(response)
diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java
index a4a84a78054..2ad584fa82f 100644
--- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java
+++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java
@@ -61,6 +61,7 @@
 import org.springframework.ai.model.tool.ToolCallingManager;
 import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
 import org.springframework.ai.model.tool.ToolExecutionResult;
+import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
 import org.springframework.ai.openai.api.OpenAiApi;
 import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
 import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion.Choice;
@@ -363,10 +364,16 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
 			// @formatter:off
 			Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
 				if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
-					return Flux.defer(() -> {
-						// FIXME: bounded elastic needs to be used since tool calling
-						// is currently only synchronous
-						var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
+					// FIXME: bounded elastic needs to be used since tool calling
+					// is currently only synchronous
+					return Flux.deferContextual((ctx) -> {
+						ToolExecutionResult toolExecutionResult;
+						try {
+							ToolCallReactiveContextHolder.setContext(ctx);
+							toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
+						} finally {
+							ToolCallReactiveContextHolder.clearContext();
+						}
 						if (toolExecutionResult.returnDirect()) {
 							// Return tool execution result directly to the client.
 							return Flux.just(ChatResponse.builder().from(response)
diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java
index 01ab8b96c02..852678a1da3 100644
--- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java
+++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java
@@ -81,6 +81,7 @@
 import org.springframework.ai.model.tool.ToolCallingManager;
 import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
 import org.springframework.ai.model.tool.ToolExecutionResult;
+import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
 import org.springframework.ai.retry.RetryUtils;
 import org.springframework.ai.support.UsageCalculator;
 import org.springframework.ai.tool.definition.ToolDefinition;
@@ -540,9 +541,15 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
 			Flux<ChatResponse> flux = chatResponseFlux.flatMap(response -> {
 				if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
 					// FIXME: bounded elastic needs to be used since tool calling
-					//  is currently only synchronous
-					return Flux.defer(() -> {
-						var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
+					// is currently only synchronous
+					return Flux.deferContextual((ctx) -> {
+						ToolExecutionResult toolExecutionResult;
+						try {
+							ToolCallReactiveContextHolder.setContext(ctx);
+							toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
+						} finally {
+							ToolCallReactiveContextHolder.clearContext();
+						}
 						if (toolExecutionResult.returnDirect()) {
 							// Return tool execution result directly to the client.
 							return Flux.just(ChatResponse.builder().from(response)
diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java
index 408666fdc34..80690410273 100644
--- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java
+++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java
@@ -56,6 +56,7 @@
 import org.springframework.ai.model.tool.ToolCallingManager;
 import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
 import org.springframework.ai.model.tool.ToolExecutionResult;
+import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
 import org.springframework.ai.retry.RetryUtils;
 import org.springframework.ai.tool.definition.ToolDefinition;
 import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
@@ -357,10 +358,16 @@ public Flux<ChatResponse> stream(Prompt prompt) {
 			// @formatter:off
 			Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
 				if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(requestPrompt.getOptions(), response)) {
-					return Flux.defer(() -> {
-						// FIXME: bounded elastic needs to be used since tool calling
-						// is currently only synchronous
-						var toolExecutionResult = this.toolCallingManager.executeToolCalls(requestPrompt, response);
+					// FIXME: bounded elastic needs to be used since tool calling
+					// is currently only synchronous
+					return Flux.deferContextual((ctx) -> {
+						ToolExecutionResult toolExecutionResult;
+						try {
+							ToolCallReactiveContextHolder.setContext(ctx);
+							toolExecutionResult = this.toolCallingManager.executeToolCalls(requestPrompt, response);
+						} finally {
+							ToolCallReactiveContextHolder.clearContext();
+						}
 						if (toolExecutionResult.returnDirect()) {
 							// Return tool execution result directly to the client.
 							return Flux.just(ChatResponse.builder().from(response)
diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/internal/ToolCallReactiveContextHolder.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/internal/ToolCallReactiveContextHolder.java
new file mode 100644
index 00000000000..73d5764e667
--- /dev/null
+++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/internal/ToolCallReactiveContextHolder.java
@@ -0,0 +1,50 @@
+package org.springframework.ai.model.tool.internal;
+
+import reactor.util.context.Context;
+import reactor.util.context.ContextView;
+
+/**
+ * This class bridges blocking tool calls and the reactive context. When calling tools, it
+ * captures the context in a thread local, making it available to re-inject in a nested
+ * reactive call.
+ *
+ * <p>
+ * The context is bound to the calling thread only: callers are expected to pair
+ * {@link #setContext(ContextView)} with {@link #clearContext()} in a {@code finally}
+ * block around the blocking tool invocation.
+ *
+ * @author Daniel Garnier-Moiroux
+ * @since 1.1.0
+ */
+public final class ToolCallReactiveContextHolder {
+
+	private static final ThreadLocal<ContextView> context = ThreadLocal.withInitial(Context::empty);
+
+	private ToolCallReactiveContextHolder() {
+		// Static holder; not meant to be instantiated.
+	}
+
+	/**
+	 * Store the given reactive context for the current thread.
+	 * @param contextView the context to expose to code running on this thread
+	 */
+	public static void setContext(ContextView contextView) {
+		context.set(contextView);
+	}
+
+	/**
+	 * Return the reactive context stored for the current thread.
+	 * @return the stored context, or an empty context if none has been set
+	 */
+	public static ContextView getContext() {
+		return context.get();
+	}
+
+	/**
+	 * Remove the reactive context stored for the current thread.
+	 */
+	public static void clearContext() {
+		context.remove();
+	}
+
+}