From e7e2e929d1079b0bfe1a1907de95333b3bc85d62 Mon Sep 17 00:00:00 2001 From: kamosama <837080904@qq.com> Date: Tue, 23 Jul 2024 18:11:55 +0800 Subject: [PATCH] Fix a FunctionCallback inside a container that pollutes the global model's ChatOptions > The SpingBoot autoconfiguration class adds the container's FunctionCallback directly to the model's ChatOptions, which results in the FunctionCallback being included in the request each time it is called. > The modification registers the container's FunctionCallback directly to the model's functionCallbackRegister. --- .../anthropic/AnthropicAutoConfiguration.java | 12 +++++++++--- .../azure/openai/AzureOpenAiAutoConfiguration.java | 11 +++++++++-- .../minimax/MiniMaxAutoConfiguration.java | 11 +++++++++-- .../mistralai/MistralAiAutoConfiguration.java | 12 +++++++++--- .../openai/OpenAiAutoConfiguration.java | 10 +++++++--- .../gemini/VertexAiGeminiAutoConfiguration.java | 11 +++++++++-- .../zhipuai/ZhiPuAiAutoConfiguration.java | 12 +++++++++--- 7 files changed, 61 insertions(+), 18 deletions(-) diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/anthropic/AnthropicAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/anthropic/AnthropicAutoConfiguration.java index 0d84ba76fc8..eb84758e020 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/anthropic/AnthropicAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/anthropic/AnthropicAutoConfiguration.java @@ -16,6 +16,9 @@ package org.springframework.ai.autoconfigure.anthropic; import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; import org.springframework.ai.anthropic.AnthropicChatModel; import org.springframework.ai.anthropic.api.AnthropicApi; @@ -66,12 +69,15 @@ public AnthropicChatModel anthropicChatModel(AnthropicApi anthropicApi, Anthropi RetryTemplate retryTemplate, FunctionCallbackContext functionCallbackContext, List toolFunctionCallbacks) { + AnthropicChatModel chatModel = new AnthropicChatModel(anthropicApi, chatProperties.getOptions(), retryTemplate, functionCallbackContext); + if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) { - chatProperties.getOptions().getFunctionCallbacks().addAll(toolFunctionCallbacks); + Map toolFunctionCallbackMap = toolFunctionCallbacks.stream() + .collect(Collectors.toMap(FunctionCallback::getName, Function.identity(), (a, b) -> b)); + chatModel.getFunctionCallbackRegister().putAll(toolFunctionCallbackMap); } - return new AnthropicChatModel(anthropicApi, chatProperties.getOptions(), retryTemplate, - functionCallbackContext); + return chatModel; } @Bean diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java index e651115edac..3e31cbc8974 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java @@ -40,6 +40,9 @@ import org.springframework.util.StringUtils; import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; /** * @author Piotr Olaszewski @@ -101,11 +104,15 @@ public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClient openAIClient, AzureOpenAiChatProperties chatProperties, List toolFunctionCallbacks, FunctionCallbackContext functionCallbackContext) { + AzureOpenAiChatModel chatModel = new AzureOpenAiChatModel(openAIClient, chatProperties.getOptions(), functionCallbackContext); + if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) { - chatProperties.getOptions().getFunctionCallbacks().addAll(toolFunctionCallbacks); + Map toolFunctionCallbackMap = toolFunctionCallbacks.stream() + .collect(Collectors.toMap(FunctionCallback::getName, Function.identity(), (a, b) -> b)); + chatModel.getFunctionCallbackRegister().putAll(toolFunctionCallbackMap); } - return new AzureOpenAiChatModel(openAIClient, chatProperties.getOptions(), functionCallbackContext); + return chatModel; } @Bean diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxAutoConfiguration.java index 80cde6dfee3..ba2e6bee54b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxAutoConfiguration.java @@ -37,6 +37,9 @@ import org.springframework.web.client.RestClient; import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; /** * @author Geng Rong @@ -59,11 +62,15 @@ public MiniMaxChatModel miniMaxChatModel(MiniMaxConnectionProperties commonPrope var miniMaxApi = miniMaxApi(chatProperties.getBaseUrl(), commonProperties.getBaseUrl(), chatProperties.getApiKey(), commonProperties.getApiKey(), restClientBuilder, responseErrorHandler); + MiniMaxChatModel chatModel = new MiniMaxChatModel(miniMaxApi, chatProperties.getOptions(), functionCallbackContext, retryTemplate); + if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) { - chatProperties.getOptions().getFunctionCallbacks().addAll(toolFunctionCallbacks); + Map toolFunctionCallbackMap = toolFunctionCallbacks.stream() + .collect(Collectors.toMap(FunctionCallback::getName, Function.identity(), (a, b) -> b)); + chatModel.getFunctionCallbackRegister().putAll(toolFunctionCallbackMap); } - return new MiniMaxChatModel(miniMaxApi, chatProperties.getOptions(), functionCallbackContext, retryTemplate); + return chatModel; } @Bean diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfiguration.java index 0073311fede..2c05fec6458 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfiguration.java @@ -16,6 +16,9 @@ package org.springframework.ai.autoconfigure.mistralai; import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.mistralai.MistralAiChatModel; @@ -81,12 +84,15 @@ public MistralAiChatModel mistralAiChatModel(MistralAiCommonProperties commonPro var mistralAiApi = mistralAiApi(chatProperties.getApiKey(), commonProperties.getApiKey(), chatProperties.getBaseUrl(), commonProperties.getBaseUrl(), restClientBuilder, responseErrorHandler); + MistralAiChatModel chatModel = new MistralAiChatModel(mistralAiApi, chatProperties.getOptions(), functionCallbackContext, retryTemplate); + if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) { - chatProperties.getOptions().getFunctionCallbacks().addAll(toolFunctionCallbacks); + Map toolFunctionCallbackMap = toolFunctionCallbacks.stream() + .collect(Collectors.toMap(FunctionCallback::getName, Function.identity(), (a, b) -> b)); + chatModel.getFunctionCallbackRegister().putAll(toolFunctionCallbackMap); } - return new MistralAiChatModel(mistralAiApi, chatProperties.getOptions(), functionCallbackContext, - retryTemplate); + return chatModel; } private MistralAiApi mistralAiApi(String apiKey, String commonApiKey, String baseUrl, String commonBaseUrl, diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java index 6c334040432..c32216b486f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java @@ -16,6 +16,9 @@ package org.springframework.ai.autoconfigure.openai; import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.model.function.FunctionCallback; @@ -73,12 +76,13 @@ public OpenAiChatModel openAiChatModel(OpenAiConnectionProperties commonProperti var openAiApi = openAiApi(chatProperties.getBaseUrl(), commonProperties.getBaseUrl(), chatProperties.getApiKey(), commonProperties.getApiKey(), restClientBuilder, webClientBuilder, responseErrorHandler, "chat"); - + OpenAiChatModel chatModel = new OpenAiChatModel(openAiApi, chatProperties.getOptions(), functionCallbackContext, retryTemplate); if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) { - chatProperties.getOptions().getFunctionCallbacks().addAll(toolFunctionCallbacks); + Map toolFunctionCallbackMap = toolFunctionCallbacks.stream().collect(Collectors.toMap(FunctionCallback::getName, Function.identity(), (a, b) -> b)); + chatModel.getFunctionCallbackRegister().putAll(toolFunctionCallbackMap); } - return new OpenAiChatModel(openAiApi, chatProperties.getOptions(), functionCallbackContext, retryTemplate); + return chatModel; } @Bean diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfiguration.java index 81310360370..83f33b1db06 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfiguration.java @@ -17,6 +17,9 @@ import java.io.IOException; import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; import com.google.auth.oauth2.GoogleCredentials; import com.google.cloud.vertexai.VertexAI; @@ -79,11 +82,15 @@ public VertexAiGeminiChatModel vertexAiGeminiChat(VertexAI vertexAi, VertexAiGem FunctionCallbackContext functionCallbackContext = springAiFunctionManager(context); + VertexAiGeminiChatModel chatModel = new VertexAiGeminiChatModel(vertexAi, chatProperties.getOptions(), functionCallbackContext); + if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) { - chatProperties.getOptions().getFunctionCallbacks().addAll(toolFunctionCallbacks); + Map toolFunctionCallbackMap = toolFunctionCallbacks.stream() + .collect(Collectors.toMap(FunctionCallback::getName, Function.identity(), (a, b) -> b)); + chatModel.getFunctionCallbackRegister().putAll(toolFunctionCallbackMap); } - return new VertexAiGeminiChatModel(vertexAi, chatProperties.getOptions(), functionCallbackContext); + return chatModel; } /** diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiAutoConfiguration.java index 894b533b0bf..45696583f7b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiAutoConfiguration.java @@ -39,6 +39,9 @@ import org.springframework.web.client.RestClient; import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; /** * @author Geng Rong @@ -61,11 +64,14 @@ public ZhiPuAiChatModel zhiPuAiChatModel(ZhiPuAiConnectionProperties commonPrope var zhiPuAiApi = zhiPuAiApi(chatProperties.getBaseUrl(), commonProperties.getBaseUrl(), chatProperties.getApiKey(), commonProperties.getApiKey(), restClientBuilder, responseErrorHandler); + ZhiPuAiChatModel chatModel = new ZhiPuAiChatModel(zhiPuAiApi, chatProperties.getOptions(), functionCallbackContext, retryTemplate); + if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) { - chatProperties.getOptions().getFunctionCallbacks().addAll(toolFunctionCallbacks); - } + Map toolFunctionCallbackMap = toolFunctionCallbacks.stream() + .collect(Collectors.toMap(FunctionCallback::getName, Function.identity(), (a, b) -> b)); + chatModel.getFunctionCallbackRegister().putAll(toolFunctionCallbackMap); } - return new ZhiPuAiChatModel(zhiPuAiApi, chatProperties.getOptions(), functionCallbackContext, retryTemplate); + return chatModel; } @Bean