diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index 8a3925f3b71..dfa40267414 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -225,7 +225,8 @@ public ChatResponse call(Prompt prompt) { return chatResponse; }); - if (response != null && this.isToolCall(response, Set.of("tool_use"))) { + if (!isProxyToolCalls(prompt, this.defaultOptions) && response != null + && this.isToolCall(response, Set.of("tool_use"))) { var toolCallConversation = handleToolCalls(prompt, response); return this.call(new Prompt(toolCallConversation, prompt.getOptions())); } @@ -256,7 +257,7 @@ public Flux stream(Prompt prompt) { Flux chatResponseFlux = response.switchMap(chatCompletionResponse -> { ChatResponse chatResponse = toChatResponse(chatCompletionResponse); - if (this.isToolCall(chatResponse, Set.of("tool_use"))) { + if (!isProxyToolCalls(prompt, this.defaultOptions) && this.isToolCall(chatResponse, Set.of("tool_use"))) { var toolCallConversation = handleToolCalls(prompt, chatResponse); return this.stream(new Prompt(toolCallConversation, prompt.getOptions())); } diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java index 03daa19524e..e79a8f760f0 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java @@ -77,6 +77,9 @@ public class AnthropicChatOptions implements ChatOptions, FunctionCallingOptions @NestedConfigurationProperty @JsonIgnore private Set functions = new HashSet<>(); + + @JsonIgnore + private Boolean proxyToolCalls; // @formatter:on public static Builder builder() { @@ -144,6 +147,11 @@ public Builder withFunction(String functionName) { return this; } + public Builder withProxyToolCalls(Boolean proxyToolCalls) { + this.options.proxyToolCalls = proxyToolCalls; + return this; + } + public AnthropicChatOptions build() { return this.options; } @@ -246,6 +254,15 @@ public Double getPresencePenalty() { return null; } + @Override + public Boolean getProxyToolCalls() { + return this.proxyToolCalls; + } + + public void setProxyToolCalls(Boolean proxyToolCalls) { + this.proxyToolCalls = proxyToolCalls; + } + @Override public AnthropicChatOptions copy() { return fromOptions(this); @@ -261,6 +278,7 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions) .withTopK(fromOptions.getTopK()) .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) .withFunctions(fromOptions.getFunctions()) + .withProxyToolCalls(fromOptions.getProxyToolCalls()) .build(); } diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index 1b2a7a72483..bffdbe2c744 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -151,7 
+151,8 @@ public ChatResponse call(Prompt prompt) { ChatResponse chatResponse = toChatResponse(chatCompletions); - if (isToolCall(chatResponse, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) { + if (!isProxyToolCalls(prompt, this.defaultOptions) + && isToolCall(chatResponse, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) { var toolCallConversation = handleToolCalls(prompt, chatResponse); // Recursively call the call method with the tool call message // conversation that contains the call responses. @@ -199,7 +200,8 @@ public Flux<ChatResponse> stream(Prompt prompt) { ChatResponse chatResponse = toChatResponse(chatCompletions); - if (isToolCall(chatResponse, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) { + if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(chatResponse, + Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) { var toolCallConversation = handleToolCalls(prompt, chatResponse); // Recursively call the call method with the tool call message // conversation that contains the call responses. diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java index fc4c0f1795c..6b85eeb9662 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java @@ -161,6 +162,9 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptions @JsonIgnore private Set<String> functions = new HashSet<>(); + @JsonIgnore + private Boolean proxyToolCalls; + public static Builder builder() { return new Builder(); } @@ -250,6 +254,11 @@ public Builder withResponseFormat(AzureOpenAiResponseFormat responseFormat) { return this; } + public Builder withProxyToolCalls(Boolean proxyToolCalls) { + this.options.proxyToolCalls = proxyToolCalls; + return this; + } + public AzureOpenAiChatOptions build() { return this.options; } @@ -395,6 +404,15 @@ public Integer getTopK() { return null; } + @Override + public Boolean getProxyToolCalls() { + return this.proxyToolCalls; + } + + public void setProxyToolCalls(Boolean proxyToolCalls) { + this.proxyToolCalls = proxyToolCalls; + } + @Override public AzureOpenAiChatOptions copy() { return fromOptions(this); } @@ -413,6 +431,8 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOptions) { .withUser(fromOptions.getUser()) .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) .withFunctions(fromOptions.getFunctions()) + .withResponseFormat(fromOptions.getResponseFormat()) + .withProxyToolCalls(fromOptions.getProxyToolCalls()) .build(); } diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java index ee3f758461a..c14f93cbe4a 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java @@ -190,7 +190,7 @@ public ChatResponse call(Prompt prompt) { ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody())); - if (isToolCall(chatResponse, + if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(chatResponse, Set.of(ChatCompletionFinishReason.TOOL_CALLS.name(), ChatCompletionFinishReason.STOP.name()))) { var toolCallConversation = handleToolCalls(prompt, chatResponse); // Recursively call the call method with the tool call message @@ -254,7 +254,7 @@ public Flux<ChatResponse> stream(Prompt prompt) { return chatResponse.flatMap(response -> { - if (isToolCall(response, + if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(response, Set.of(ChatCompletionFinishReason.TOOL_CALLS.name(), ChatCompletionFinishReason.STOP.name()))) { var toolCallConversation = handleToolCalls(prompt, response); // Recursively call the stream method with the tool call message diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java index 31cae5791fa..30426eb6c7c 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java @@ -142,6 +142,9 @@ public class MiniMaxChatOptions implements FunctionCallingOptions, ChatOptions { @NestedConfigurationProperty @JsonIgnore private Set<String> functions = new HashSet<>(); + + @JsonIgnore + private Boolean proxyToolCalls; // @formatter:on public static Builder builder() { @@ -242,6 +245,11 @@ public Builder withFunction(String functionName) { return this; } + public Builder withProxyToolCalls(Boolean proxyToolCalls) { + this.options.proxyToolCalls = proxyToolCalls; + return this; + } + public MiniMaxChatOptions build() {
return this.options; } @@ -394,6 +402,15 @@ public Integer getTopK() { return null; } + @Override + public Boolean getProxyToolCalls() { + return this.proxyToolCalls; + } + + public void setProxyToolCalls(Boolean proxyToolCalls) { + this.proxyToolCalls = proxyToolCalls; + } + @Override public int hashCode() { final int prime = 31; @@ -411,6 +428,7 @@ public int hashCode() { result = prime * result + ((maskSensitiveInfo == null) ? 0 : maskSensitiveInfo.hashCode()); result = prime * result + ((tools == null) ? 0 : tools.hashCode()); result = prime * result + ((toolChoice == null) ? 0 : toolChoice.hashCode()); + result = prime * result + ((proxyToolCalls == null) ? 0 : proxyToolCalls.hashCode()); return result; } @@ -501,6 +519,12 @@ else if (!tools.equals(other.tools)) } else if (!toolChoice.equals(other.toolChoice)) return false; + if (this.proxyToolCalls == null) { + if (other.proxyToolCalls != null) + return false; + } + else if (!proxyToolCalls.equals(other.proxyToolCalls)) + return false; return true; } @@ -525,6 +549,7 @@ public static MiniMaxChatOptions fromOptions(MiniMaxChatOptions fromOptions) { .withToolChoice(fromOptions.getToolChoice()) .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) .withFunctions(fromOptions.getFunctions()) + .withProxyToolCalls(fromOptions.getProxyToolCalls()) .build(); } diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java index ceea2e869df..edf1da14f44 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java @@ -183,8 +183,9 @@ public ChatResponse call(Prompt prompt) { return chatResponse; }); - if (response != null && isToolCall(response, Set.of(MistralAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), - MistralAiApi.ChatCompletionFinishReason.STOP.name()))) { + if (!isProxyToolCalls(prompt, this.defaultOptions) && response != null + && isToolCall(response, Set.of(MistralAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), + MistralAiApi.ChatCompletionFinishReason.STOP.name()))) { var toolCallConversation = handleToolCalls(prompt, response); // Recursively call the call method with the tool call message // conversation that contains the call responses. @@ -255,7 +256,7 @@ public Flux stream(Prompt prompt) { // @formatter:off Flux chatResponseFlux = chatResponse.flatMap(response -> { - if (isToolCall(response, Set.of(MistralAiApi.ChatCompletionFinishReason.TOOL_CALLS.name()))) { + if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(response, Set.of(MistralAiApi.ChatCompletionFinishReason.TOOL_CALLS.name()))) { var toolCallConversation = handleToolCalls(prompt, response); // Recursively call the stream method with the tool call message // conversation that contains the call responses. 
diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java index dc4fcdc6dc4..7053f5ab0b0 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java @@ -135,6 +135,9 @@ public class MistralAiChatOptions implements FunctionCallingOptions, ChatOptions @JsonIgnore private Set functions = new HashSet<>(); + @JsonIgnore + private Boolean proxyToolCalls; + public static Builder builder() { return new Builder(); } @@ -215,6 +218,11 @@ public Builder withFunction(String functionName) { return this; } + public Builder withProxyToolCalls(Boolean proxyToolCalls) { + this.options.proxyToolCalls = proxyToolCalls; + return this; + } + public MistralAiChatOptions build() { return this.options; } @@ -356,6 +364,15 @@ public Integer getTopK() { return null; } + @Override + public Boolean getProxyToolCalls() { + return this.proxyToolCalls; + } + + public void setProxyToolCalls(Boolean proxyToolCalls) { + this.proxyToolCalls = proxyToolCalls; + } + @Override public MistralAiChatOptions copy() { return fromOptions(this); @@ -374,7 +391,114 @@ public static MistralAiChatOptions fromOptions(MistralAiChatOptions fromOptions) .withToolChoice(fromOptions.getToolChoice()) .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) .withFunctions(fromOptions.getFunctions()) + .withProxyToolCalls(fromOptions.getProxyToolCalls()) .build(); } + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((model == null) ? 0 : model.hashCode()); + result = prime * result + ((temperature == null) ? 0 : temperature.hashCode()); + result = prime * result + ((topP == null) ? 0 : topP.hashCode()); + result = prime * result + ((maxTokens == null) ? 0 : maxTokens.hashCode()); + result = prime * result + ((safePrompt == null) ? 0 : safePrompt.hashCode()); + result = prime * result + ((randomSeed == null) ? 0 : randomSeed.hashCode()); + result = prime * result + ((responseFormat == null) ? 0 : responseFormat.hashCode()); + result = prime * result + ((stop == null) ? 0 : stop.hashCode()); + result = prime * result + ((tools == null) ? 0 : tools.hashCode()); + result = prime * result + ((toolChoice == null) ? 0 : toolChoice.hashCode()); + result = prime * result + ((functionCallbacks == null) ? 0 : functionCallbacks.hashCode()); + result = prime * result + ((functions == null) ? 0 : functions.hashCode()); + result = prime * result + ((proxyToolCalls == null) ? 
0 : proxyToolCalls.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + MistralAiChatOptions other = (MistralAiChatOptions) obj; + if (model == null) { + if (other.model != null) + return false; + } + else if (!model.equals(other.model)) + return false; + if (temperature == null) { + if (other.temperature != null) + return false; + } + else if (!temperature.equals(other.temperature)) + return false; + if (topP == null) { + if (other.topP != null) + return false; + } + else if (!topP.equals(other.topP)) + return false; + if (maxTokens == null) { + if (other.maxTokens != null) + return false; + } + else if (!maxTokens.equals(other.maxTokens)) + return false; + if (safePrompt == null) { + if (other.safePrompt != null) + return false; + } + else if (!safePrompt.equals(other.safePrompt)) + return false; + if (randomSeed == null) { + if (other.randomSeed != null) + return false; + } + else if (!randomSeed.equals(other.randomSeed)) + return false; + if (responseFormat == null) { + if (other.responseFormat != null) + return false; + } + else if (!responseFormat.equals(other.responseFormat)) + return false; + if (stop == null) { + if (other.stop != null) + return false; + } + else if (!stop.equals(other.stop)) + return false; + if (tools == null) { + if (other.tools != null) + return false; + } + else if (!tools.equals(other.tools)) + return false; + if (toolChoice != other.toolChoice) + return false; + if (functionCallbacks == null) { + if (other.functionCallbacks != null) + return false; + } + else if (!functionCallbacks.equals(other.functionCallbacks)) + return false; + if (functions == null) { + if (other.functions != null) + return false; + } + else if (!functions.equals(other.functions)) + return false; + if (proxyToolCalls == null) { + if (other.proxyToolCalls != null) + return false; + } + else if (!proxyToolCalls.equals(other.proxyToolCalls)) + return false; + return true; + } + } diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java index 99956e81e93..553eab3713d 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java @@ -164,8 +164,9 @@ public ChatResponse call(Prompt prompt) { ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody())); - if (isToolCall(chatResponse, Set.of(MoonshotApi.ChatCompletionFinishReason.TOOL_CALLS.name(), - MoonshotApi.ChatCompletionFinishReason.STOP.name()))) { + if (!isProxyToolCalls(prompt, this.defaultOptions) + && isToolCall(chatResponse, Set.of(MoonshotApi.ChatCompletionFinishReason.TOOL_CALLS.name(), + MoonshotApi.ChatCompletionFinishReason.STOP.name()))) { var toolCallConversation = handleToolCalls(prompt, chatResponse); // Recursively call the call method with the tool call message // conversation that contains the call responses. 
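The flag can also be flipped per request through any provider's options builder, not only as a model-level default. A hedged usage sketch against the Moonshot options from this file (the model id is an assumption for illustration):

```java
// Opt a single request into proxy mode: the response then carries the raw tool
// calls for the caller to execute instead of being handled by the framework.
var options = MoonshotChatOptions.builder()
	.withModel("moonshot-v1-8k") // assumed model id, illustration only
	.withProxyToolCalls(true)
	.build();
ChatResponse response = chatModel.call(new Prompt("What's the weather in Paris?", options));
List<AssistantMessage.ToolCall> toolCalls = response.getResult().getOutput().getToolCalls();
```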
@@ -228,7 +229,8 @@ public Flux stream(Prompt prompt) { return chatResponse.flatMap(response -> { - if (isToolCall(response, Set.of(ChatCompletionFinishReason.TOOL_CALLS.name(), "stop"))) { + if (!isProxyToolCalls(prompt, this.defaultOptions) + && isToolCall(response, Set.of(ChatCompletionFinishReason.TOOL_CALLS.name(), "stop"))) { var toolCallConversation = handleToolCalls(prompt, response); // Recursively call the stream method with the tool call message // conversation that contains the call responses. diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java index 4bf51bca52b..6eedc5ef1fb 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java @@ -137,6 +137,9 @@ public class MoonshotChatOptions implements FunctionCallingOptions, ChatOptions */ private @JsonProperty("user") String user; + @JsonIgnore + private Boolean proxyToolCalls; + @Override public List getFunctionCallbacks() { return this.functionCallbacks; @@ -244,6 +247,11 @@ public Builder withFunction(String functionName) { return this; } + public Builder withProxyToolCalls(Boolean proxyToolCalls) { + this.options.proxyToolCalls = proxyToolCalls; + return this; + } + public MoonshotChatOptions build() { return this.options; } @@ -345,6 +353,15 @@ public Integer getTopK() { return null; } + @Override + public Boolean getProxyToolCalls() { + return this.proxyToolCalls; + } + + public void setProxyToolCalls(Boolean proxyToolCalls) { + this.proxyToolCalls = proxyToolCalls; + } + @Override public MoonshotChatOptions copy() { return builder().withModel(this.model) @@ -360,6 +377,7 @@ public MoonshotChatOptions copy() { .withToolChoice(this.toolChoice) .withFunctionCallbacks(this.functionCallbacks) .withFunctions(this.functions) + .withProxyToolCalls(this.proxyToolCalls) .build(); } @@ -376,6 +394,7 @@ public int hashCode() { result = prime * result + ((temperature == null) ? 0 : temperature.hashCode()); result = prime * result + ((topP == null) ? 0 : topP.hashCode()); result = prime * result + ((user == null) ? 0 : user.hashCode()); + result = prime * result + ((proxyToolCalls == null) ? 
0 : proxyToolCalls.hashCode()); return result; } @@ -441,6 +460,11 @@ else if (!topP.equals(other.topP)) } else if (!this.user.equals(other.user)) return false; + if (this.proxyToolCalls == null) { + return other.proxyToolCalls == null; + } + else if (!this.proxyToolCalls.equals(other.proxyToolCalls)) + return false; return true; } diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index 79dced2155c..96e2a1267a4 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -162,7 +162,8 @@ public ChatResponse call(Prompt prompt) { }); - if (response != null && isToolCall(response, Set.of("stop"))) { + if (!isProxyToolCalls(prompt, this.defaultOptions) && response != null + && isToolCall(response, Set.of("stop"))) { var toolCallConversation = handleToolCalls(prompt, response); // Recursively call the call method with the tool call message // conversation that contains the call responses. diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java index 3fd22d03e4a..530e4361f18 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java @@ -270,7 +270,7 @@ public class OllamaOptions implements FunctionCallingOptions, ChatOptions, Embed /** - * Truncates the end of each input to fit within context length. Returns error if false and context length is exceeded. + * Truncates the end of each input to fit within context length. Returns error if false and context length is exceeded. * Defaults to true. */ @JsonProperty("truncate") private Boolean truncate; @@ -297,6 +297,8 @@ public class OllamaOptions implements FunctionCallingOptions, ChatOptions, Embed @JsonIgnore private Set functions = new HashSet<>(); + @JsonIgnore + private Boolean proxyToolCalls; public static OllamaOptions builder() { return new OllamaOptions(); @@ -495,6 +497,11 @@ public OllamaOptions withFunction(String functionName) { return this; } + public OllamaOptions withProxyToolCalls(Boolean proxyToolCalls) { + this.proxyToolCalls = proxyToolCalls; + return this; + } + // ------------------- // Getters and Setters // ------------------- @@ -816,6 +823,15 @@ public Integer getDimensions() { return null; } + @Override + public Boolean getProxyToolCalls() { + return this.proxyToolCalls; + } + + public void setProxyToolCalls(Boolean proxyToolCalls) { + this.proxyToolCalls = proxyToolCalls; + } + /** * Convert the {@link OllamaOptions} object to a {@link Map} of key/value pairs. * @return The {@link Map} of key/value pairs. 
@@ -884,6 +900,7 @@ public static OllamaOptions fromOptions(OllamaOptions fromOptions) { .withPenalizeNewline(fromOptions.getPenalizeNewline()) .withStop(fromOptions.getStop()) .withFunctions(fromOptions.getFunctions()) + .withProxyToolCalls(fromOptions.getProxyToolCalls()) .withFunctionCallbacks(fromOptions.getFunctionCallbacks()); } // @formatter:on @@ -913,7 +930,7 @@ public boolean equals(Object o) { && Objects.equals(mirostatTau, that.mirostatTau) && Objects.equals(mirostatEta, that.mirostatEta) && Objects.equals(penalizeNewline, that.penalizeNewline) && Objects.equals(stop, that.stop) && Objects.equals(functionCallbacks, that.functionCallbacks) - && Objects.equals(functions, that.functions); + && Objects.equals(proxyToolCalls, that.proxyToolCalls) && Objects.equals(functions, that.functions); } @Override @@ -923,7 +940,7 @@ public int hashCode() { this.useMMap, this.useMLock, this.numThread, this.numKeep, this.seed, this.numPredict, this.topK, this.topP, tfsZ, this.typicalP, this.repeatLastN, this.temperature, this.repeatPenalty, this.presencePenalty, this.frequencyPenalty, this.mirostat, this.mirostatTau, this.mirostatEta, - this.penalizeNewline, this.stop, this.functionCallbacks, this.functions); + this.penalizeNewline, this.stop, this.functionCallbacks, this.functions, this.proxyToolCalls); } } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index 55f999813b7..3c04ebbc0bb 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -15,9 +15,16 @@ */ package org.springframework.ai.openai; -import io.micrometer.observation.Observation; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; +import java.util.ArrayList; +import java.util.Base64; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Collectors; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.chat.messages.AssistantMessage; @@ -63,19 +70,13 @@ import org.springframework.util.MimeType; import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; + +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import java.util.ArrayList; -import java.util.Base64; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; -import java.util.stream.Collectors; - /** * {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal OpenAI} * backed by {@link OpenAiApi}. 
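The constructor hunk below is what makes the guard work for OpenAI: the default options are now handed up to `AbstractToolCallSupport` via `super(...)`, which gives `isProxyToolCalls(prompt, this.defaultOptions)` its fallback value. Constructing a model whose defaults enable proxying then looks like the integration test's `Config` further down:

```java
// Build an OpenAiChatModel whose default options enable proxying, so every
// request returns tool calls to the caller unless a prompt overrides the flag.
var defaults = OpenAiChatOptions.builder()
	.withModel("gpt-4o-mini")
	.withProxyToolCalls(true)
	.build();
var chatModel = new OpenAiChatModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")), defaults,
		null, List.of(), RetryUtils.DEFAULT_RETRY_TEMPLATE, ObservationRegistry.NOOP);
```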
@@ -189,6 +190,7 @@ public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options, public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options, FunctionCallbackContext functionCallbackContext, List<FunctionCallback> toolFunctionCallbacks, RetryTemplate retryTemplate, ObservationRegistry observationRegistry) { + super(functionCallbackContext, options, toolFunctionCallbacks); Assert.notNull(openAiApi, "OpenAiApi must not be null"); @@ -259,8 +261,9 @@ public ChatResponse call(Prompt prompt) { }); - if (response != null && isToolCall(response, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), - OpenAiApi.ChatCompletionFinishReason.STOP.name()))) { + if (!isProxyToolCalls(prompt, this.defaultOptions) && response != null + && isToolCall(response, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), + OpenAiApi.ChatCompletionFinishReason.STOP.name()))) { var toolCallConversation = handleToolCalls(prompt, response); // Recursively call the call method with the tool call message // conversation that contains the call responses. @@ -330,7 +333,7 @@ public Flux<ChatResponse> stream(Prompt prompt) { // @formatter:off Flux<ChatResponse> flux = chatResponse.flatMap(response -> { - if (isToolCall(response, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), + if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(response, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), OpenAiApi.ChatCompletionFinishReason.STOP.name()))) { var toolCallConversation = handleToolCalls(prompt, response); // Recursively call the stream method with the tool call message diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java index 3a2d8695b80..e89c560e196 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java @@ -20,6 +20,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import org.springframework.ai.chat.prompt.ChatOptions; @@ -171,6 +172,14 @@ public class OpenAiChatOptions implements FunctionCallingOptions, ChatOptions { @JsonIgnore private Set<String> functions = new HashSet<>(); + /** + * If true, Spring AI will not handle the function calls internally but will proxy them to the client. + * It is then the client's responsibility to handle the function calls, dispatch them to the appropriate function, and return the results. + * If false (the default), Spring AI handles the function calls internally. + */ + @JsonIgnore + private Boolean proxyToolCalls; + /** * Optional HTTP headers to be added to the chat completion request.
*/ @@ -307,8 +316,12 @@ public Builder withFunction(String functionName) { return this; } + public Builder withProxyToolCalls(Boolean proxyToolCalls) { + this.options.proxyToolCalls = proxyToolCalls; + return this; + } + public Builder withHttpHeaders(Map httpHeaders) { - Assert.notNull(httpHeaders, "HTTP headers must not be null"); this.options.httpHeaders = httpHeaders; return this; } @@ -468,6 +481,15 @@ public String getToolChoice() { return this.toolChoice; } + @Override + public Boolean getProxyToolCalls() { + return this.proxyToolCalls; + } + + public void setProxyToolCalls(Boolean proxyToolCalls) { + this.proxyToolCalls = proxyToolCalls; + } + public void setToolChoice(String toolChoice) { this.toolChoice = toolChoice; } @@ -521,152 +543,6 @@ public Integer getTopK() { return null; } - @Override - public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + ((model == null) ? 0 : model.hashCode()); - result = prime * result + ((frequencyPenalty == null) ? 0 : frequencyPenalty.hashCode()); - result = prime * result + ((logitBias == null) ? 0 : logitBias.hashCode()); - result = prime * result + ((logprobs == null) ? 0 : logprobs.hashCode()); - result = prime * result + ((topLogprobs == null) ? 0 : topLogprobs.hashCode()); - result = prime * result + ((maxTokens == null) ? 0 : maxTokens.hashCode()); - result = prime * result + ((n == null) ? 0 : n.hashCode()); - result = prime * result + ((presencePenalty == null) ? 0 : presencePenalty.hashCode()); - result = prime * result + ((responseFormat == null) ? 0 : responseFormat.hashCode()); - result = prime * result + ((streamOptions == null) ? 0 : streamOptions.hashCode()); - result = prime * result + ((seed == null) ? 0 : seed.hashCode()); - result = prime * result + ((stop == null) ? 0 : stop.hashCode()); - result = prime * result + ((temperature == null) ? 0 : temperature.hashCode()); - result = prime * result + ((topP == null) ? 0 : topP.hashCode()); - result = prime * result + ((tools == null) ? 0 : tools.hashCode()); - result = prime * result + ((toolChoice == null) ? 0 : toolChoice.hashCode()); - result = prime * result + ((user == null) ? 0 : user.hashCode()); - result = prime * result + ((parallelToolCalls == null) ? 
0 : parallelToolCalls.hashCode()); - return result; - } - - @Override - public boolean equals(Object obj) { - if (this == obj) - return true; - if (obj == null) - return false; - if (getClass() != obj.getClass()) - return false; - OpenAiChatOptions other = (OpenAiChatOptions) obj; - if (this.model == null) { - if (other.model != null) - return false; - } - else if (!model.equals(other.model)) - return false; - if (this.frequencyPenalty == null) { - if (other.frequencyPenalty != null) - return false; - } - else if (!this.frequencyPenalty.equals(other.frequencyPenalty)) - return false; - if (this.logitBias == null) { - if (other.logitBias != null) - return false; - } - else if (!this.logitBias.equals(other.logitBias)) - return false; - if (this.logprobs == null) { - if (other.logprobs != null) - return false; - } - else if (!this.logprobs.equals(other.logprobs)) - return false; - if (this.topLogprobs == null) { - if (other.topLogprobs != null) - return false; - } - else if (!this.topLogprobs.equals(other.topLogprobs)) - return false; - if (this.maxTokens == null) { - if (other.maxTokens != null) - return false; - } - else if (!this.maxTokens.equals(other.maxTokens)) - return false; - if (this.n == null) { - if (other.n != null) - return false; - } - else if (!this.n.equals(other.n)) - return false; - if (this.presencePenalty == null) { - if (other.presencePenalty != null) - return false; - } - else if (!this.presencePenalty.equals(other.presencePenalty)) - return false; - if (this.responseFormat == null) { - if (other.responseFormat != null) - return false; - } - else if (!this.responseFormat.equals(other.responseFormat)) - return false; - if (this.streamOptions == null) { - if (other.streamOptions != null) - return false; - } - else if (!this.streamOptions.equals(other.streamOptions)) - return false; - if (this.seed == null) { - if (other.seed != null) - return false; - } - else if (!this.seed.equals(other.seed)) - return false; - if (this.stop == null) { - if (other.stop != null) - return false; - } - else if (!stop.equals(other.stop)) - return false; - if (this.temperature == null) { - if (other.temperature != null) - return false; - } - else if (!this.temperature.equals(other.temperature)) - return false; - if (this.topP == null) { - if (other.topP != null) - return false; - } - else if (!topP.equals(other.topP)) - return false; - if (this.tools == null) { - if (other.tools != null) - return false; - } - else if (!tools.equals(other.tools)) - return false; - if (this.toolChoice == null) { - if (other.toolChoice != null) - return false; - } - else if (!toolChoice.equals(other.toolChoice)) - return false; - if (this.user == null) { - if (other.user != null) - return false; - } - else if (!this.user.equals(other.user)) - return false; - else if (this.parallelToolCalls == null) { - if (other.parallelToolCalls != null) - return false; - } - else if (!this.parallelToolCalls.equals(other.parallelToolCalls)) - return false; - - return true; - } - @Override public OpenAiChatOptions copy() { return OpenAiChatOptions.fromOptions(this); @@ -695,9 +571,42 @@ public static OpenAiChatOptions fromOptions(OpenAiChatOptions fromOptions) { .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) .withFunctions(fromOptions.getFunctions()) .withHttpHeaders(fromOptions.getHttpHeaders()) + .withProxyToolCalls(fromOptions.getProxyToolCalls()) .build(); } + @Override + public int hashCode() { + return Objects.hash(this.model, this.frequencyPenalty, this.logitBias, this.logprobs, this.topLogprobs, + 
this.maxTokens, this.n, this.presencePenalty, this.responseFormat, this.streamOptions, this.seed, + this.stop, this.temperature, this.topP, this.tools, this.toolChoice, this.user, this.parallelToolCalls, + this.functionCallbacks, this.functions, this.httpHeaders, this.proxyToolCalls); + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + OpenAiChatOptions other = (OpenAiChatOptions) o; + return Objects.equals(this.model, other.model) && Objects.equals(this.frequencyPenalty, other.frequencyPenalty) + && Objects.equals(this.logitBias, other.logitBias) && Objects.equals(this.logprobs, other.logprobs) + && Objects.equals(this.topLogprobs, other.topLogprobs) + && Objects.equals(this.maxTokens, other.maxTokens) && Objects.equals(this.n, other.n) + && Objects.equals(this.presencePenalty, other.presencePenalty) + && Objects.equals(this.responseFormat, other.responseFormat) + && Objects.equals(this.streamOptions, other.streamOptions) && Objects.equals(this.seed, other.seed) + && Objects.equals(this.stop, other.stop) && Objects.equals(this.temperature, other.temperature) + && Objects.equals(this.topP, other.topP) && Objects.equals(this.tools, other.tools) + && Objects.equals(this.toolChoice, other.toolChoice) && Objects.equals(this.user, other.user) + && Objects.equals(this.parallelToolCalls, other.parallelToolCalls) + && Objects.equals(this.functionCallbacks, other.functionCallbacks) + && Objects.equals(this.functions, other.functions) + && Objects.equals(this.httpHeaders, other.httpHeaders) + && Objects.equals(this.proxyToolCalls, other.proxyToolCalls); + } + @Override public String toString() { return "OpenAiChatOptions: " + ModelOptionsUtils.toJsonString(this); diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index 64e0d58cef6..b88f795063f 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -756,6 +756,8 @@ public MediaContent(ImageUrl imageUrl) { /** * The relevant tool call. * + * @param index The index of the tool call in the list of tool calls. Required in + * case of streaming. * @param id The ID of the tool call. This ID must be referenced when you submit * the tool outputs in using the Submit tool outputs to run endpoint. * @param type The type of tool call the output is required for. 
For now, this is @@ -764,9 +766,14 @@ public MediaContent(ImageUrl imageUrl) { */ @JsonInclude(Include.NON_NULL) public record ToolCall(// @formatter:off + @JsonProperty("index") Integer index, @JsonProperty("id") String id, @JsonProperty("type") String type, @JsonProperty("function") ChatCompletionFunction function) {// @formatter:on + + public ToolCall(String id, String type, ChatCompletionFunction function) { + this(null, id, type, function); + } } /** diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java index 6bfc6faa5eb..37a619ab924 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java @@ -64,7 +64,7 @@ @SpringBootTest(classes = OpenAiTestConfiguration.class) @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") -class OpenAiChatModelIT extends AbstractIT { +public class OpenAiChatModelIT extends AbstractIT { private static final Logger logger = LoggerFactory.getLogger(OpenAiChatModelIT.class); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelProxyToolCallsIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelProxyToolCallsIT.java new file mode 100644 index 00000000000..c4e1198ee44 --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelProxyToolCallsIT.java @@ -0,0 +1,372 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.openai.chat; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.function.ToolCallHelper; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.openai.OpenAiChatModel; +import org.springframework.ai.openai.OpenAiChatOptions; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import org.springframework.util.CollectionUtils; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonMappingException; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.micrometer.observation.ObservationRegistry; +import reactor.core.publisher.Flux; + +@SpringBootTest(classes = OpenAiChatModelProxyToolCallsIT.Config.class) +@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") +class OpenAiChatModelProxyToolCallsIT { + + private static final Logger logger = LoggerFactory.getLogger(OpenAiChatModelProxyToolCallsIT.class); + + private static final String DEFAULT_MODEL = "gpt-4o-mini"; + + @Autowired + private OpenAiChatModel chatModel; + + // Helper class that reuses some of the {@link AbstractToolCallSupport} functionality + // to help implement the function call handling logic on the client side. + private ToolCallHelper toolCallHelper = new ToolCallHelper(); + + // Function that will be called by the AI model. + private String getWeatherInLocation(String location, String unit) { + + double temperature = 0; + + if (location.contains("Paris")) { + temperature = 15; + } + else if (location.contains("Tokyo")) { + temperature = 10; + } + else if (location.contains("San Francisco")) { + temperature = 30; + } + + return String.format("The weather in %s is %s%s", location, temperature, unit); + } + + FunctionCallback functionDefinition = new ToolCallHelper.FunctionDefinition("getWeatherInLocation", + "Get the weather in location", """ + { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state e.g.
San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["C", "F"] + } + }, + "required": ["location", "unit"] + } + """); + + @Test + void functionCall() throws JsonMappingException, JsonProcessingException { + + List messages = List + .of(new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?")); + + var promptOptions = OpenAiChatOptions.builder().withFunctionCallbacks(List.of(functionDefinition)).build(); + + var prompt = new Prompt(messages, promptOptions); + + boolean isToolCall = false; + + ChatResponse chatResponse = null; + + do { + + chatResponse = chatModel.call(prompt); + + // We will have to convert the chatResponse into OpenAI assistant message. + + // Note that the tool call check could be platform specific because the finish + // reasons. + isToolCall = toolCallHelper.isToolCall(chatResponse, + Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), + OpenAiApi.ChatCompletionFinishReason.STOP.name())); + + if (isToolCall) { + + Optional toolCallGeneration = chatResponse.getResults() + .stream() + .filter(g -> !CollectionUtils.isEmpty(g.getOutput().getToolCalls())) + .findFirst(); + + assertThat(toolCallGeneration).isNotEmpty(); + + AssistantMessage assistantMessage = toolCallGeneration.get().getOutput(); + + List toolResponses = new ArrayList<>(); + + for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) { + + var functionName = toolCall.name(); + + assertThat(functionName).isEqualTo("getWeatherInLocation"); + + String functionArguments = toolCall.arguments(); + + @SuppressWarnings("unchecked") + Map argumentsMap = new ObjectMapper().readValue(functionArguments, Map.class); + + String functionResponse = getWeatherInLocation(argumentsMap.get("location").toString(), + argumentsMap.get("unit").toString()); + + toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), functionName, + ModelOptionsUtils.toJsonString(functionResponse))); + } + + ToolResponseMessage toolMessageResponse = new ToolResponseMessage(toolResponses, Map.of()); + + List toolCallConversation = toolCallHelper.buildToolCallConversation(prompt.getInstructions(), + assistantMessage, toolMessageResponse); + + assertThat(toolCallConversation).isNotEmpty(); + + prompt = new Prompt(toolCallConversation, prompt.getOptions()); + } + } + while (isToolCall); + + logger.info("Response: {}", chatResponse); + + assertThat(chatResponse.getResult().getOutput().getContent()).contains("30", "10", "15"); + } + + @Test + void functionStream() throws JsonMappingException, JsonProcessingException { + + List messages = List + .of(new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?")); + + var promptOptions = OpenAiChatOptions.builder().withFunctionCallbacks(List.of(functionDefinition)).build(); + + var prompt = new Prompt(messages, promptOptions); + + String response = processToolCall(prompt, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), + OpenAiApi.ChatCompletionFinishReason.STOP.name()), toolCall -> { + + var functionName = toolCall.name(); + + assertThat(functionName).isEqualTo("getWeatherInLocation"); + + String functionArguments = toolCall.arguments(); + + Map argumentsMap = getFunctionArguments(functionArguments); + + String functionResponse = getWeatherInLocation(argumentsMap.get("location").toString(), + argumentsMap.get("unit").toString()); + + return functionResponse; + }) + .collectList() + .block() + .stream() + .map(cr -> cr.getResult().getOutput().getContent()) + .collect(Collectors.joining()); + + 
logger.info("Response: {}", response); + + assertThat(response).contains("30", "10", "15"); + + } + + private Flux processToolCall(Prompt prompt, Set finishReasons, + Function customFunction) { + + Flux chatResponses = chatModel.stream(prompt); + + return chatResponses.flatMap(chatResponse -> { + + boolean isToolCall = toolCallHelper.isToolCall(chatResponse, finishReasons); + + if (isToolCall) { + + Optional toolCallGeneration = chatResponse.getResults() + .stream() + .filter(g -> !CollectionUtils.isEmpty(g.getOutput().getToolCalls())) + .findFirst(); + + assertThat(toolCallGeneration).isNotEmpty(); + + AssistantMessage assistantMessage = toolCallGeneration.get().getOutput(); + + List toolResponses = new ArrayList<>(); + + for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) { + + String functionResponse = customFunction.apply(toolCall); + + toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), toolCall.name(), + ModelOptionsUtils.toJsonString(functionResponse))); + } + + ToolResponseMessage toolMessageResponse = new ToolResponseMessage(toolResponses, Map.of()); + + List toolCallConversation = toolCallHelper.buildToolCallConversation(prompt.getInstructions(), + assistantMessage, toolMessageResponse); + + assertThat(toolCallConversation).isNotEmpty(); + + var prompt2 = new Prompt(toolCallConversation, prompt.getOptions()); + + return processToolCall(prompt2, finishReasons, customFunction); + } + + return Flux.just(chatResponse); + }); + } + + @Test + void functionCall2() throws JsonMappingException, JsonProcessingException { + + List messages = List + .of(new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?")); + + var promptOptions = OpenAiChatOptions.builder().withFunctionCallbacks(List.of(functionDefinition)).build(); + + var prompt = new Prompt(messages, promptOptions); + + ChatResponse chatResponse = toolCallHelper.processCall(chatModel, prompt, + Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), + OpenAiApi.ChatCompletionFinishReason.STOP.name()), + toolCall -> { + + var functionName = toolCall.name(); + + assertThat(functionName).isEqualTo("getWeatherInLocation"); + + String functionArguments = toolCall.arguments(); + + Map argumentsMap = getFunctionArguments(functionArguments); + + String functionResponse = getWeatherInLocation(argumentsMap.get("location").toString(), + argumentsMap.get("unit").toString()); + + return functionResponse; + }); + + logger.info("Response: {}", chatResponse); + + assertThat(chatResponse.getResult().getOutput().getContent()).contains("30", "10", "15"); + } + + @Test + void functionStream2() throws JsonMappingException, JsonProcessingException { + + List messages = List + .of(new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?")); + + var promptOptions = OpenAiChatOptions.builder().withFunctionCallbacks(List.of(functionDefinition)).build(); + + var prompt = new Prompt(messages, promptOptions); + + Flux responses = toolCallHelper.processStream(chatModel, prompt, + Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), + OpenAiApi.ChatCompletionFinishReason.STOP.name()), + toolCall -> { + + var functionName = toolCall.name(); + + assertThat(functionName).isEqualTo("getWeatherInLocation"); + + String functionArguments = toolCall.arguments(); + + Map argumentsMap = getFunctionArguments(functionArguments); + + String functionResponse = getWeatherInLocation(argumentsMap.get("location").toString(), + argumentsMap.get("unit").toString()); + + return 
functionResponse; + }); + + String response = responses.collectList() + .block() + .stream() + .map(cr -> cr.getResult().getOutput().getContent()) + .collect(Collectors.joining()); + + logger.info("Response: {}", response); + + assertThat(response).contains("30", "10", "15"); + + } + + @SuppressWarnings("unchecked") + private static Map<String, Object> getFunctionArguments(String functionArguments) { + try { + return new ObjectMapper().readValue(functionArguments, Map.class); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + @SpringBootConfiguration + static class Config { + + @Bean + public OpenAiApi chatCompletionApi() { + return new OpenAiApi(System.getenv("OPENAI_API_KEY")); + } + + @Bean + public OpenAiChatModel openAiClient(OpenAiApi openAiApi, List<FunctionCallback> toolFunctionCallbacks) { + // Enable the proxy tool calls option. + var options = OpenAiChatOptions.builder().withModel(DEFAULT_MODEL).withProxyToolCalls(true).build(); + + return new OpenAiChatModel(openAiApi, options, null, toolFunctionCallbacks, + RetryUtils.DEFAULT_RETRY_TEMPLATE, ObservationRegistry.NOOP); + } + + } + +} \ No newline at end of file diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/GroqWithOpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/GroqWithOpenAiChatModelIT.java similarity index 98% rename from models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/GroqWithOpenAiChatModelIT.java rename to models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/GroqWithOpenAiChatModelIT.java index 6d8cb41d17c..ac395cf66cb 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/GroqWithOpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/GroqWithOpenAiChatModelIT.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.openai.chat; +package org.springframework.ai.openai.chat.proxy; import static org.assertj.core.api.Assertions.assertThat; @@ -50,6 +50,8 @@ import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.tool.MockWeatherService; +import org.springframework.ai.openai.chat.ActorsFilms; +import org.springframework.ai.openai.chat.OpenAiChatModelIT; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.SpringBootConfiguration; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OllamaWithOpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/OllamaWithOpenAiChatModelIT.java similarity index 98% rename from models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OllamaWithOpenAiChatModelIT.java rename to models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/OllamaWithOpenAiChatModelIT.java index ea3f603755a..892c865c2f3 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OllamaWithOpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/OllamaWithOpenAiChatModelIT.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License.
*/ -package org.springframework.ai.openai.chat; +package org.springframework.ai.openai.chat.proxy; import static org.assertj.core.api.Assertions.assertThat; @@ -50,6 +50,8 @@ import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.tool.MockWeatherService; +import org.springframework.ai.openai.chat.ActorsFilms; +import org.springframework.ai.openai.chat.OpenAiChatModelIT; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.SpringBootConfiguration; diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java index 1be9935e672..b7ea5051d3d 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java @@ -180,7 +180,8 @@ public ChatResponse call(Prompt prompt) { ChatResponse chatResponse = new ChatResponse(generations, toChatResponseMetadata(response)); - if (isToolCall(chatResponse, Set.of(FinishReason.STOP.name()))) { + if (!isProxyToolCalls(prompt, this.defaultOptions) + && isToolCall(chatResponse, Set.of(FinishReason.STOP.name()))) { var toolCallConversation = handleToolCalls(prompt, chatResponse); // Recursively call the call method with the tool call message // conversation that contains the call responses. @@ -209,7 +210,7 @@ public Flux stream(Prompt prompt) { ChatResponse chatResponse = new ChatResponse(generations, toChatResponseMetadata(response)); - if (isToolCall(chatResponse, + if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(chatResponse, Set.of(FinishReason.STOP.name(), FinishReason.FINISH_REASON_UNSPECIFIED.name()))) { var toolCallConversation = handleToolCalls(prompt, chatResponse); // Recursively call the stream method with the tool call message diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java index 062089cb9d8..ce46d4f97c6 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java @@ -115,6 +115,8 @@ public enum TransportType { @JsonIgnore private boolean googleSearchRetrieval = false; + @JsonIgnore + private Boolean proxyToolCalls; // @formatter:on @@ -194,6 +196,11 @@ public Builder withGoogleSearchRetrieval(boolean googleSearch) { return this; } + public Builder withProxyToolCalls(boolean proxyToolCalls) { + this.options.proxyToolCalls = proxyToolCalls; + return this; + } + public VertexAiGeminiChatOptions build() { return this.options; } @@ -321,6 +328,15 @@ public void setGoogleSearchRetrieval(boolean googleSearchRetrieval) { this.googleSearchRetrieval = googleSearchRetrieval; } + @Override + public Boolean getProxyToolCalls() { + return this.proxyToolCalls; + } + + public void setProxyToolCalls(Boolean proxyToolCalls) { + this.proxyToolCalls = proxyToolCalls; + } + @Override public boolean equals(Object 
o) { if (this == o) @@ -333,13 +349,13 @@ public boolean equals(Object o) { && Objects.equals(maxOutputTokens, that.maxOutputTokens) && Objects.equals(model, that.model) && Objects.equals(responseMimeType, that.responseMimeType) && Objects.equals(functionCallbacks, that.functionCallbacks) - && Objects.equals(functions, that.functions); + && Objects.equals(functions, that.functions) && Objects.equals(proxyToolCalls, that.proxyToolCalls); } @Override public int hashCode() { return Objects.hash(stopSequences, temperature, topP, topK, candidateCount, maxOutputTokens, model, - responseMimeType, functionCallbacks, functions, googleSearchRetrieval); + responseMimeType, functionCallbacks, functions, googleSearchRetrieval, proxyToolCalls); } @Override @@ -370,6 +386,7 @@ public static VertexAiGeminiChatOptions fromOptions(VertexAiGeminiChatOptions fromOptions) { options.setFunctions(fromOptions.getFunctions()); options.setResponseMimeType(fromOptions.getResponseMimeType()); options.setGoogleSearchRetrieval(fromOptions.getGoogleSearchRetrieval()); + options.setProxyToolCalls(fromOptions.getProxyToolCalls()); return options; } diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java index 2a9293c355a..6bbe65eb9bd 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java @@ -177,7 +177,7 @@ public ChatResponse call(Prompt prompt) { ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody())); - if (isToolCall(chatResponse, + if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(chatResponse, Set.of(ChatCompletionFinishReason.TOOL_CALLS.name(), ChatCompletionFinishReason.STOP.name()))) { var toolCallConversation = handleToolCalls(prompt, chatResponse); // Recursively call the call method with the tool call message @@ -241,7 +241,7 @@ public Flux<ChatResponse> stream(Prompt prompt) { return chatResponse.flatMap(response -> { - if (isToolCall(response, + if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(response, Set.of(ChatCompletionFinishReason.TOOL_CALLS.name(), ChatCompletionFinishReason.STOP.name()))) { var toolCallConversation = handleToolCalls(prompt, response); // Recursively call the stream method with the tool call message diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java index b495eb6679a..d33f04d79f3 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java @@ -123,6 +123,9 @@ public class ZhiPuAiChatOptions implements FunctionCallingOptions, ChatOptions { @NestedConfigurationProperty @JsonIgnore private Set<String> functions = new HashSet<>(); + + @JsonIgnore + private Boolean proxyToolCalls; // @formatter:on public static Builder builder() { @@ -208,6 +211,11 @@ public Builder withFunction(String functionName) { return this; } + public Builder withProxyToolCalls(Boolean proxyToolCalls) { + this.options.proxyToolCalls = proxyToolCalls; + return this; + } + public ZhiPuAiChatOptions build() { return this.options; } @@ -346,6 +354,15 @@ public Integer getTopK() { return null; } + @Override +
public Boolean getProxyToolCalls() { + return this.proxyToolCalls; + } + + public void setProxyToolCalls(Boolean proxyToolCalls) { + this.proxyToolCalls = proxyToolCalls; + } + @Override public int hashCode() { final int prime = 31; @@ -358,6 +375,7 @@ public int hashCode() { result = prime * result + ((tools == null) ? 0 : tools.hashCode()); result = prime * result + ((toolChoice == null) ? 0 : toolChoice.hashCode()); result = prime * result + ((user == null) ? 0 : user.hashCode()); + result = prime * result + ((proxyToolCalls == null) ? 0 : proxyToolCalls.hashCode()); return result; } @@ -430,6 +448,12 @@ else if (!this.requestId.equals(other.requestId)) } else if (!this.doSample.equals(other.doSample)) return false; + if (this.proxyToolCalls == null) { + if (other.proxyToolCalls != null) + return false; + } + else if (!this.proxyToolCalls.equals(other.proxyToolCalls)) + return false; return true; } @@ -452,6 +476,7 @@ public static ZhiPuAiChatOptions fromOptions(ZhiPuAiChatOptions fromOptions) { .withDoSample(fromOptions.getDoSample()) .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) .withFunctions(fromOptions.getFunctions()) + .withProxyToolCalls(fromOptions.getProxyToolCalls()) .build(); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java index 67b81abeefe..dc7c7a5a2e5 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java @@ -74,9 +74,9 @@ static Builder builder(ChatModel chatModel, ObservationRegistry observationRegistry, ChatClientRequestSpec prompt(); - ChatClientPromptRequestSpec prompt(String content); + ChatClientRequestSpec prompt(String content); - ChatClientPromptRequestSpec prompt(Prompt prompt); + ChatClientRequestSpec prompt(Prompt prompt); /** * Return a {@link ChatClient.Builder} to create a new {@link ChatClient} whose diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index 21785d36d08..13f80769a69 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -32,9 +32,9 @@ import org.springframework.ai.chat.client.advisor.api.Advisor; import org.springframework.ai.chat.client.advisor.api.AroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.RequestAdvisor; import org.springframework.ai.chat.client.advisor.api.ResponseAdvisor; import org.springframework.ai.chat.client.advisor.api.ResponseAdvisor.StreamResponseMode;
-import org.springframework.ai.chat.client.advisor.api.RequestAdvisor; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; import org.springframework.ai.chat.client.advisor.observation.AdvisorObservableHelper; import org.springframework.ai.chat.client.observation.ChatClientObservationContext; @@ -85,12 +85,9 @@ public class DefaultChatClient implements ChatClient { private static final ChatClientObservationConvention DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION = new DefaultChatClientObservationConvention(); - private final ChatModel chatModel; - private final DefaultChatClientRequestSpec defaultChatClientRequest; - public DefaultChatClient(ChatModel chatModel, DefaultChatClientRequestSpec defaultChatClientRequest) { - this.chatModel = chatModel; + public DefaultChatClient(DefaultChatClientRequestSpec defaultChatClientRequest) { this.defaultChatClientRequest = defaultChatClientRequest; } @@ -100,13 +97,20 @@ public ChatClientRequestSpec prompt() { } @Override - public ChatClientPromptRequestSpec prompt(String content) { - return new DefaultChatClientPromptRequestSpec(this.chatModel, new Prompt(content)); + public ChatClientRequestSpec prompt(String content) { + return prompt(new Prompt(content)); } @Override - public ChatClientPromptRequestSpec prompt(Prompt prompt) { - return new DefaultChatClientPromptRequestSpec(this.chatModel, prompt); + public ChatClientRequestSpec prompt(Prompt prompt) { + + DefaultChatClientRequestSpec spec = new DefaultChatClientRequestSpec(this.defaultChatClientRequest); + spec.messages(prompt.getInstructions()); + if (prompt.getOptions() != null) { + spec.options(prompt.getOptions()); + } + + return spec; } /** @@ -997,25 +1000,4 @@ public Flux<String> content() { } - public static class DefaultChatClientPromptRequestSpec implements ChatClientPromptRequestSpec { - - private final ChatModel chatModel; - - private final Prompt prompt; - - public DefaultChatClientPromptRequestSpec(ChatModel chatModel, Prompt prompt) { - this.chatModel = chatModel; - this.prompt = prompt; - } - - public CallPromptResponseSpec call() { - return new DefaultCallPromptResponseSpec(this.chatModel, this.prompt); - } - - public StreamPromptResponseSpec stream() { - return new DefaultStreamPromptResponseSpec(this.chatModel, this.prompt); - } - - } - } \ No newline at end of file
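With this refactoring, prompt(String) and prompt(Prompt) return the same fluent ChatClientRequestSpec as prompt(), so a pre-built Prompt flows through the regular call chain. A minimal usage fragment (the chatModel bean is an illustrative assumption, not part of this patch; imports omitted):

    ChatClient chatClient = ChatClient.builder(chatModel).build();

    // Both entry points now yield the same request spec; messages and options carried
    // by the Prompt are copied onto a fresh copy of the default request spec.
    String direct = chatClient.prompt("Tell me a joke").call().content();
    String viaPrompt = chatClient.prompt(new Prompt("Tell me a joke")).call().content();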
diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java index 4f22ad2c800..411f19e8988 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java @@ -50,8 +50,6 @@ public class DefaultChatClientBuilder implements Builder { protected final DefaultChatClientRequestSpec defaultRequest; - private final ChatModel chatModel; - DefaultChatClientBuilder(ChatModel chatModel) { this(chatModel, ObservationRegistry.NOOP, null); } @@ -60,14 +58,13 @@ public DefaultChatClientBuilder(ChatModel chatModel, ObservationRegistry observationRegistry, ChatClientObservationConvention customObservationConvention) { Assert.notNull(chatModel, "the " + ChatModel.class.getName() + " must be non-null"); Assert.notNull(observationRegistry, "the " + ObservationRegistry.class.getName() + " must be non-null"); - this.chatModel = chatModel; this.defaultRequest = new DefaultChatClientRequestSpec(chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), observationRegistry, customObservationConvention); } public ChatClient build() { - return new DefaultChatClient(this.chatModel, this.defaultRequest); + return new DefaultChatClient(this.defaultRequest); } public Builder defaultAdvisors(Advisor... advisor) { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/AbstractToolCallSupport.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/AbstractToolCallSupport.java index 6085385646c..2b58ebf6144 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/AbstractToolCallSupport.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/AbstractToolCallSupport.java @@ -75,18 +75,18 @@ protected AbstractToolCallSupport(FunctionCallbackContext functionCallbackContext, } } - private static List<FunctionCallback> merge(FunctionCallingOptions funcitonOptions, + private static List<FunctionCallback> merge(FunctionCallingOptions functionOptions, List<FunctionCallback> toolFunctionCallbacks) { List<FunctionCallback> toolFunctionCallbacksCopy = new ArrayList<>(); if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) { toolFunctionCallbacksCopy.addAll(toolFunctionCallbacks); } - if (!CollectionUtils.isEmpty(funcitonOptions.getFunctionCallbacks())) { - toolFunctionCallbacksCopy.addAll(funcitonOptions.getFunctionCallbacks()); + if (!CollectionUtils.isEmpty(functionOptions.getFunctionCallbacks())) { + toolFunctionCallbacksCopy.addAll(functionOptions.getFunctionCallbacks()); // Make sure that function callbacks are registered directly to the // functionCallbackRegister and not passed in the default options. - funcitonOptions.setFunctionCallbacks(List.of()); + functionOptions.setFunctionCallbacks(List.of()); } return toolFunctionCallbacksCopy; } @@ -220,6 +220,13 @@ protected boolean isToolCall(ChatResponse chatResponse, Set<String> toolCallFinishReasons) { return generations.stream().anyMatch(g -> isToolCall(g, toolCallFinishReasons)); } + /** + * Check whether the generation represents a tool call, as indicated by its finish + * reason matching one of the given tool call finish reasons. + * @param generation the generation to check. + * @param toolCallFinishReasons the tool call finish reasons to check against. + * @return true if the generation is a tool call, false otherwise. + */ protected boolean isToolCall(Generation generation, Set<String> toolCallFinishReasons) { var finishReason = (generation.getMetadata().getFinishReason() != null) ? generation.getMetadata().getFinishReason() : ""; @@ -229,4 +236,26 @@ protected boolean isToolCall(Generation generation, Set<String> toolCallFinishReasons) { .contains(finishReason.toLowerCase()); } + /** + * Check whether proxyToolCalls is enabled for the given prompt or in the default + * function calling options. The prompt options take precedence over the default + * options. When proxyToolCalls is enabled, the ChatModel implementation does not + * handle the function calling internally; the tool call and tool response messages + * are instead exposed outside the ChatModel implementation. + * @param prompt the prompt to check. + * @param defaultOptions the default function calling options to check. + * @return true if proxyToolCalls is enabled, false otherwise.
+ */ + protected boolean isProxyToolCalls(Prompt prompt, FunctionCallingOptions defaultOptions) { + if (prompt.getOptions() instanceof FunctionCallingOptions functionCallOptions + && functionCallOptions.getProxyToolCalls() != null) { + return functionCallOptions.getProxyToolCalls(); + } + else if (defaultOptions.getProxyToolCalls() != null) { + return defaultOptions.getProxyToolCalls(); + } + + return false; + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/Prompt.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/Prompt.java index feb8e067569..3d36b1bfbf4 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/Prompt.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/Prompt.java @@ -50,7 +50,7 @@ public Prompt(Message message) { } public Prompt(List<Message> messages) { - this.messages = messages; + this(messages, null); } public Prompt(String contents, ChatOptions chatOptions) { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java index df603d2b38f..f953e907d33 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java @@ -54,6 +54,16 @@ */ void setFunctions(Set<String> functions); + default Boolean getProxyToolCalls() { + return false; + } + + default void setProxyToolCalls(Boolean proxyToolCalls) { + if (proxyToolCalls != null) { + throw new UnsupportedOperationException("Setting Proxy Tool Calls is not supported!"); + } + } + /** * @return a FunctionCallingOptionsBuilder to create a new instance of * FunctionCallingOptions.
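The default accessor above, together with the isProxyToolCalls check shown earlier, gives per-prompt options precedence over the model defaults. A minimal, hypothetical sketch of that rule (the chatModel bean and the weather question are illustrative assumptions, not part of this patch), assuming the model's default options leave proxyToolCalls unset:

    // Per-prompt options win over the model's defaults: this request opts in to
    // proxying even if the model was built without proxyToolCalls.
    FunctionCallingOptions runtimeOptions = new FunctionCallingOptionsBuilder()
        .withProxyToolCalls(true)
        .build();

    Prompt prompt = new Prompt("What is the weather in Paris?", runtimeOptions);

    // Inside the model, isProxyToolCalls(prompt, defaultOptions) now returns true, so
    // the raw tool-call response is returned to the caller instead of being dispatched.
    ChatResponse response = chatModel.call(prompt);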
diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java index 5c1d7c05201..04c66c13d44 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java @@ -102,6 +102,11 @@ public FunctionCallingOptionsBuilder withTopP(Double topP) { return this; } + public FunctionCallingOptionsBuilder withProxyToolCalls(Boolean proxyToolCalls) { + this.options.setProxyToolCalls(proxyToolCalls); + return this; + } + public PortableFunctionCallingOptions build() { return this.options; } @@ -128,6 +133,12 @@ public static class PortableFunctionCallingOptions implements FunctionCallingOptions { private Double topP; + private Boolean proxyToolCalls = false; + + public static FunctionCallingOptionsBuilder builder() { + return new FunctionCallingOptionsBuilder(); + } + @Override public List<FunctionCallback> getFunctionCallbacks() { return this.functionCallbacks; @@ -220,6 +231,15 @@ public void setTopP(Double topP) { this.topP = topP; } + @Override + public Boolean getProxyToolCalls() { + return this.proxyToolCalls; + } + + public void setProxyToolCalls(Boolean proxyToolCalls) { + this.proxyToolCalls = proxyToolCalls; + } + @Override public ChatOptions copy() { return new FunctionCallingOptionsBuilder().withModel(this.model) @@ -232,6 +252,7 @@ public ChatOptions copy() { .withTopP(this.topP) .withFunctions(this.functions) .withFunctionCallbacks(this.functionCallbacks) + .withProxyToolCalls(this.proxyToolCalls) .build(); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/ToolCallHelper.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/ToolCallHelper.java new file mode 100644 index 00000000000..f3f23868d5c --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/ToolCallHelper.java @@ -0,0 +1,166 @@ +package org.springframework.ai.model.function; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.function.Function; + +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.model.AbstractToolCallSupport; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions; +import org.springframework.util.CollectionUtils; + +import reactor.core.publisher.Flux; + +/** + * Helper class that reuses the {@link AbstractToolCallSupport} to implement the function + * call handling logic on the client side. Used when the withProxyToolCalls(true) option + * is enabled. + */ +public class ToolCallHelper extends AbstractToolCallSupport { + + /** + * Helper used to provide only the function definition, without the actual function + * call implementation.
+ */ + public static record FunctionDefinition(String name, String description, + String inputTypeSchema) implements FunctionCallback { + + @Override + public String getName() { + return this.name(); + } + + @Override + public String getDescription() { + return this.description(); + } + + @Override + public String getInputTypeSchema() { + return this.inputTypeSchema(); + } + + @Override + public String call(String functionInput) { + throw new UnsupportedOperationException( + "FunctionDefinition provides only metadata. It doesn't implement the call method."); + } + + } + + public ToolCallHelper() { + this(null, PortableFunctionCallingOptions.builder().build(), List.of()); + } + + public ToolCallHelper(FunctionCallbackContext functionCallbackContext, + FunctionCallingOptions functionCallingOptions, List<FunctionCallback> toolFunctionCallbacks) { + super(functionCallbackContext, functionCallingOptions, toolFunctionCallbacks); + } + + @Override + public boolean isToolCall(ChatResponse chatResponse, Set<String> toolCallFinishReasons) { + return super.isToolCall(chatResponse, toolCallFinishReasons); + } + + @Override + public List<Message> buildToolCallConversation(List<Message> previousMessages, AssistantMessage assistantMessage, + ToolResponseMessage toolResponseMessage) { + return super.buildToolCallConversation(previousMessages, assistantMessage, toolResponseMessage); + } + + @Override + public List<Message> handleToolCalls(Prompt prompt, ChatResponse response) { + return super.handleToolCalls(prompt, response); + } + + public Flux<ChatResponse> processStream(ChatModel chatModel, Prompt prompt, Set<String> finishReasons, + Function<AssistantMessage.ToolCall, String> customFunction) { + + Flux<ChatResponse> chatResponses = chatModel.stream(prompt); + + return chatResponses.flatMap(chatResponse -> { + + boolean isToolCall = this.isToolCall(chatResponse, finishReasons); + + if (isToolCall) { + + Optional<Generation> toolCallGeneration = chatResponse.getResults() + .stream() + .filter(g -> !CollectionUtils.isEmpty(g.getOutput().getToolCalls())) + .findFirst(); + + AssistantMessage assistantMessage = toolCallGeneration.get().getOutput(); + + List<ToolResponseMessage.ToolResponse> toolResponses = new ArrayList<>(); + + for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) { + + String functionResponse = customFunction.apply(toolCall); + + toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), toolCall.name(), + ModelOptionsUtils.toJsonString(functionResponse))); + } + + ToolResponseMessage toolMessageResponse = new ToolResponseMessage(toolResponses, Map.of()); + + List<Message> toolCallConversation = this.buildToolCallConversation(prompt.getInstructions(), + assistantMessage, toolMessageResponse); + + var prompt2 = new Prompt(toolCallConversation, prompt.getOptions()); + + return processStream(chatModel, prompt2, finishReasons, customFunction); + } + + return Flux.just(chatResponse); + }); + } + + public ChatResponse processCall(ChatModel chatModel, Prompt prompt, Set<String> finishReasons, + Function<AssistantMessage.ToolCall, String> customFunction) { + + ChatResponse chatResponse = chatModel.call(prompt); + + boolean isToolCall = this.isToolCall(chatResponse, finishReasons); + + if (!isToolCall) { + return chatResponse; + } + + Optional<Generation> toolCallGeneration = chatResponse.getResults() + .stream() + .filter(g -> !CollectionUtils.isEmpty(g.getOutput().getToolCalls())) + .findFirst(); + + AssistantMessage assistantMessage = toolCallGeneration.get().getOutput(); + + List<ToolResponseMessage.ToolResponse> toolResponses = new ArrayList<>(); + + for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) { + + String functionResponse = customFunction.apply(toolCall); + + toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), toolCall.name(), + ModelOptionsUtils.toJsonString(functionResponse))); + } + + ToolResponseMessage toolMessageResponse = new ToolResponseMessage(toolResponses, Map.of()); + + List<Message> toolCallConversation = this.buildToolCallConversation(prompt.getInstructions(), assistantMessage, + toolMessageResponse); + + var prompt2 = new Prompt(toolCallConversation, prompt.getOptions()); + + return processCall(chatModel, prompt2, finishReasons, customFunction); + } + +} \ No newline at end of file
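To show how ToolCallHelper is meant to be used, here is a small, hypothetical sketch; the chatModel parameter, the getWeather function, its JSON schema, and the canned weather answer are all illustrative assumptions, not part of this patch, and the "tool_calls" finish reason is provider-specific:

    import java.util.List;
    import java.util.Set;

    import org.springframework.ai.chat.model.ChatModel;
    import org.springframework.ai.chat.model.ChatResponse;
    import org.springframework.ai.chat.prompt.Prompt;
    import org.springframework.ai.model.function.FunctionCallingOptions;
    import org.springframework.ai.model.function.FunctionCallingOptionsBuilder;
    import org.springframework.ai.model.function.ToolCallHelper;

    public class ProxyToolCallsSketch {

        // chatModel is assumed to be any function-calling-capable ChatModel bean.
        public static ChatResponse askWeather(ChatModel chatModel) {

            ToolCallHelper toolCallHelper = new ToolCallHelper();

            // Register only the function *definition*; with proxyToolCalls=true the
            // model returns the tool call instead of invoking a callback itself.
            FunctionCallingOptions options = new FunctionCallingOptionsBuilder()
                .withProxyToolCalls(true)
                .withFunctionCallbacks(List.of(new ToolCallHelper.FunctionDefinition("getWeather",
                        "Get the weather in a given location",
                        "{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\"}}}")))
                .build();

            // processCall re-sends the conversation until the finish reason is no
            // longer a tool call; the lambda is the caller-side dispatch of each
            // proxied tool call.
            return toolCallHelper.processCall(chatModel,
                    new Prompt("What is the weather in Paris?", options),
                    Set.of("tool_calls"),
                    toolCall -> "15 degrees Celsius, clear sky");
        }

    }

processStream applies the same loop to a streaming Flux<ChatResponse>.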
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/anthropic-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/anthropic-chat.adoc index a5fcc2fd001..f4191093f8b 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/anthropic-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/anthropic-chat.adoc @@ -103,8 +103,9 @@ The prefix `spring.ai.anthropic.chat` is the property prefix that lets you configure | spring.ai.anthropic.chat.options.stop-sequence | Custom text sequences that will cause the model to stop generating. Our models will normally stop when they have naturally completed their turn, which will result in a response stop_reason of "end_turn". If you want the model to stop generating when it encounters custom strings of text, you can use the stop_sequences parameter. If the model encounters one of the custom sequences, the response stop_reason value will be "stop_sequence" and the response stop_sequence value will contain the matched stop sequence. | - | spring.ai.anthropic.chat.options.top-p | Use nucleus sampling. In nucleus sampling, we compute the cumulative distribution over all the options for each subsequent token in decreasing probability order and cut it off once it reaches a particular probability specified by top_p. You should either alter temperature or top_p, but not both. Recommended for advanced use cases only. You usually only need to use temperature. | - | spring.ai.anthropic.chat.options.top-k | Only sample from the top K options for each subsequent token. Used to remove "long tail" low probability responses. Learn more technical details here. Recommended for advanced use cases only. You usually only need to use temperature. | - -| spring.ai.mistralai.chat.options.functions | List of functions, identified by their names, to enable for function calling in a single prompt requests. Functions with those names must exist in the functionCallbacks registry. | - -| spring.ai.mistralai.chat.options.functionCallbacks | MistralAI Tool Function Callbacks to register with the ChatModel. | - +| spring.ai.anthropic.chat.options.functions | List of functions, identified by their names, to enable for function calling in a single prompt request. Functions with those names must exist in the functionCallbacks registry. | - +| spring.ai.anthropic.chat.options.functionCallbacks | Tool Function Callbacks to register with the ChatModel. | - +| spring.ai.anthropic.chat.options.proxy-tool-calls | If true, Spring AI will not handle the function calls internally but will proxy them to the client. It is then the client's responsibility to handle the function calls, dispatch them to the appropriate function, and return the results. If false (the default), Spring AI handles the function calls internally. Applicable only for chat models with function calling support. | false |==== TIP: All properties prefixed with `spring.ai.anthropic.chat.options` can be overridden at runtime by adding a request specific <> to the `Prompt` call. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/azure-openai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/azure-openai-chat.adoc index 92e7e3e5a72..63dafabb60e 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/azure-openai-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/azure-openai-chat.adoc @@ -143,6 +143,7 @@ Deployments model name to provide as part of this completions request. | gpt-4o | spring.ai.azure.openai.chat.options.presencePenalty | A value that influences the probability of generated tokens appearing based on their existing presence in generated text. Positive values will make tokens less likely to appear when they already exist and increase the model's likelihood to output new topics. | - | spring.ai.azure.openai.chat.options.responseFormat | An object specifying the format that the model must output. Using `AzureOpenAiResponseFormat.JSON` enables JSON mode, which guarantees the message the model generates is valid JSON. Using `AzureOpenAiResponseFormat.TEXT` enables TEXT mode. | - | spring.ai.azure.openai.chat.options.frequencyPenalty | A value that influences the probability of generated tokens appearing based on their cumulative frequency in generated text. Positive values will make tokens less likely to appear as their frequency increases and decrease the likelihood of the model repeating the same statements verbatim. | - +| spring.ai.azure.openai.chat.options.proxy-tool-calls | If true, Spring AI will not handle the function calls internally but will proxy them to the client. It is then the client's responsibility to handle the function calls, dispatch them to the appropriate function, and return the results. If false (the default), Spring AI handles the function calls internally. Applicable only for chat models with function calling support. | false |==== TIP: All properties prefixed with `spring.ai.azure.openai.chat.options` can be overridden at runtime by adding a request specific <> to the `Prompt` call. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/groq-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/groq-chat.adoc index 76a733f0fdc..c4d68799f11 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/groq-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/groq-chat.adoc @@ -125,6 +125,7 @@ The prefix `spring.ai.openai.chat` is the property prefix that lets you configure | spring.ai.openai.chat.options.user | A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. | - | spring.ai.openai.chat.options.functions | List of functions, identified by their names, to enable for function calling in a single prompt request. Functions with those names must exist in the functionCallbacks registry. | - | spring.ai.openai.chat.options.stream-usage | (For streaming only) Set to add an additional chunk with token usage statistics for the entire request. The `choices` field for this chunk is an empty array and all other chunks will also include a usage field, but with a null value.
| false +| spring.ai.openai.chat.options.proxy-tool-calls | If true, Spring AI will not handle the function calls internally but will proxy them to the client. It is then the client's responsibility to handle the function calls, dispatch them to the appropriate function, and return the results. If false (the default), Spring AI handles the function calls internally. Applicable only for chat models with function calling support. | false |==== TIP: All properties prefixed with `spring.ai.openai.chat.options` can be overridden at runtime by adding a request specific <> to the `Prompt` call. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/mistralai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/mistralai-chat.adoc index 9e0a6e0471f..5d2680a493d 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/mistralai-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/mistralai-chat.adoc @@ -102,6 +102,7 @@ The prefix `spring.ai.mistralai.chat` is the property prefix that lets you configure | spring.ai.mistralai.chat.options.toolChoice | Controls which (if any) function is called by the model. `none` means the model will not call a function and instead generates a message. `auto` means the model can pick between generating a message or calling a function. Specifying a particular function via `{"type": "function", "function": {"name": "my_function"}}` forces the model to call that function. `none` is the default when no functions are present. `auto` is the default if functions are present. | - | spring.ai.mistralai.chat.options.functions | List of functions, identified by their names, to enable for function calling in a single prompt request. Functions with those names must exist in the functionCallbacks registry. | - | spring.ai.mistralai.chat.options.functionCallbacks | Mistral AI Tool Function Callbacks to register with the ChatModel. | - +| spring.ai.mistralai.chat.options.proxy-tool-calls | If true, Spring AI will not handle the function calls internally but will proxy them to the client. It is then the client's responsibility to handle the function calls, dispatch them to the appropriate function, and return the results. If false (the default), Spring AI handles the function calls internally. Applicable only for chat models with function calling support. | false |==== NOTE: You can override the common `spring.ai.mistralai.base-url` and `spring.ai.mistralai.api-key` for the `ChatModel` and `EmbeddingModel` implementations. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/nvidia-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/nvidia-chat.adoc index 56b135a72b5..ee7a02033a2 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/nvidia-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/nvidia-chat.adoc @@ -101,6 +101,7 @@ The prefix `spring.ai.openai.chat` is the property prefix that lets you configure | spring.ai.openai.chat.options.user | A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. | - | spring.ai.openai.chat.options.functions | List of functions, identified by their names, to enable for function calling in a single prompt request. Functions with those names must exist in the functionCallbacks registry. | - | spring.ai.openai.chat.options.stream-usage | (For streaming only) Set to add an additional chunk with token usage statistics for the entire request.
The `choices` field for this chunk is an empty array and all other chunks will also include a usage field, but with a null value. | false +| spring.ai.openai.chat.options.proxy-tool-calls | If true, Spring AI will not handle the function calls internally but will proxy them to the client. It is then the client's responsibility to handle the function calls, dispatch them to the appropriate function, and return the results. If false (the default), Spring AI handles the function calls internally. Applicable only for chat models with function calling support. | false |==== TIP: All properties prefixed with `spring.ai.openai.chat.options` can be overridden at runtime by adding a request specific <> to the `Prompt` call. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc index ae73ee916cf..d06b148b248 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc @@ -105,6 +105,7 @@ The remaining `options` properties are based on the link:https://github.com/olla | spring.ai.ollama.chat.options.penalize-newline | - | true | spring.ai.ollama.chat.options.stop | Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return. Multiple stop patterns may be set by specifying multiple separate stop parameters in a modelfile. | - | spring.ai.ollama.chat.options.functions | List of functions, identified by their names, to enable for function calling in a single prompt request. Functions with those names must exist in the functionCallbacks registry. | - +| spring.ai.ollama.chat.options.proxy-tool-calls | If true, Spring AI will not handle the function calls internally but will proxy them to the client. It is then the client's responsibility to handle the function calls, dispatch them to the appropriate function, and return the results. If false (the default), Spring AI handles the function calls internally. Applicable only for chat models with function calling support. | false |==== TIP: All properties prefixed with `spring.ai.ollama.chat.options` can be overridden at runtime by adding request-specific <> to the `Prompt` call. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/openai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/openai-chat.adoc index 1e731a770e7..f5c783b4615 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/openai-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/openai-chat.adoc @@ -117,6 +117,7 @@ The `JSON_SCHEMA` type enables link:https://platform.openai.com/docs/guides/stru | spring.ai.openai.chat.options.stream-usage | (For streaming only) Set to add an additional chunk with token usage statistics for the entire request. The `choices` field for this chunk is an empty array and all other chunks will also include a usage field, but with a null value. | false | spring.ai.openai.chat.options.parallel-tool-calls | Whether to enable link:https://platform.openai.com/docs/guides/function-calling/parallel-function-calling[parallel function calling] during tool use. | true | spring.ai.openai.chat.options.http-headers | Optional HTTP headers to be added to the chat completion request.
To override the `api-key` you need to use an `Authorization` header key, and you have to prefix the key value with the `Bearer ` prefix. | - +| spring.ai.openai.chat.options.proxy-tool-calls | If true, Spring AI will not handle the function calls internally but will proxy them to the client. It is then the client's responsibility to handle the function calls, dispatch them to the appropriate function, and return the results. If false (the default), Spring AI handles the function calls internally. Applicable only for chat models with function calling support. | false |==== NOTE: You can override the common `spring.ai.openai.base-url` and `spring.ai.openai.api-key` for the `ChatModel` and `EmbeddingModel` implementations. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-gemini-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-gemini-chat.adoc index 73dc4b2a6b1..dbbcdabba7b 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-gemini-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-gemini-chat.adoc @@ -78,6 +78,7 @@ The prefix `spring.ai.vertex.ai.gemini.chat` is the property prefix that lets you | spring.ai.vertex.ai.gemini.chat.options.frequencyPenalty | | - | spring.ai.vertex.ai.gemini.chat.options.presencePenalty | | - | spring.ai.vertex.ai.gemini.chat.options.functions | List of functions, identified by their names, to enable for function calling in a single prompt request. Functions with those names must exist in the functionCallbacks registry. | - +| spring.ai.vertex.ai.gemini.chat.options.proxy-tool-calls | If true, Spring AI will not handle the function calls internally but will proxy them to the client. It is then the client's responsibility to handle the function calls, dispatch them to the appropriate function, and return the results. If false (the default), Spring AI handles the function calls internally. Applicable only for chat models with function calling support. | false |==== diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/zhipuai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/zhipuai-chat.adoc index d946274232d..328dc9a2bf1 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/zhipuai-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/zhipuai-chat.adoc @@ -98,6 +98,7 @@ The prefix `spring.ai.zhipuai.chat` is the property prefix that lets you configure | spring.ai.zhipuai.chat.options.user | A unique identifier representing your end-user, which can help ZhiPuAI to monitor and detect abuse. | - | spring.ai.zhipuai.chat.options.requestId | The parameter is passed by the client and must ensure uniqueness. It is used to distinguish the unique identifier for each request. If the client does not provide it, the platform will generate it by default. | - | spring.ai.zhipuai.chat.options.doSample | When do_sample is set to true, the sampling strategy is enabled. If do_sample is false, the sampling strategy parameters temperature and top_p will not take effect. | true +| spring.ai.zhipuai.chat.options.proxy-tool-calls | If true, Spring AI will not handle the function calls internally but will proxy them to the client. It is then the client's responsibility to handle the function calls, dispatch them to the appropriate function, and return the results. If false (the default), Spring AI handles the function calls internally.
Applicable only for chat models with function calling support. | false |==== NOTE: You can override the common `spring.ai.zhipuai.base-url` and `spring.ai.zhipuai.api-key` for the `ChatModel` implementations.
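For the property tables above, enabling the proxy mode from configuration might look like this (a hypothetical application.properties fragment; the OpenAI prefix stands in for any of the providers listed above):

    # Expose tool calls to the caller instead of executing them inside the ChatModel
    spring.ai.openai.chat.options.proxy-tool-calls=true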