From 7e4b90dc48925e6a677e152156792b0cc14f68ff Mon Sep 17 00:00:00 2001 From: Ilayaperumal Gopinathan Date: Tue, 11 Feb 2025 18:04:20 +0000 Subject: [PATCH] AzureOpenAI - Adopt ToolCallingManager API - Use the new ToolCallingManager API for AzureOpenAI chat model - Add Builder to construct AzureOpenAI chat model instance - Deprecate existing constructors - Update documentation about the change Signed-off-by: Ilayaperumal Gopinathan --- .../ai/azure/openai/AzureOpenAiChatModel.java | 271 +++++++++++++++--- .../azure/openai/AzureOpenAiChatOptions.java | 171 ++++++++--- .../AzureChatCompletionsOptionsTests.java | 5 +- .../azure/openai/AzureOpenAiChatClientIT.java | 7 +- .../azure/openai/AzureOpenAiChatModelIT.java | 7 +- .../AzureOpenAiChatModelObservationIT.java | 8 +- .../openai/AzureOpenAiChatModelTests.java | 97 ------- .../MockAzureOpenAiTestConfiguration.java | 2 +- .../AzureOpenAiChatModelFunctionCallIT.java | 6 +- .../azure-open-ai-chat-functions.adoc | 4 +- .../openai/AzureOpenAiAutoConfiguration.java | 21 +- 11 files changed, 400 insertions(+), 199 deletions(-) delete mode 100644 models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelTests.java diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index d527115da4f..060c8f305da 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -19,11 +19,9 @@ import java.util.ArrayList; import java.util.Base64; import java.util.Collections; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; import com.azure.ai.openai.OpenAIAsyncClient; @@ -58,6 +56,8 @@ import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; import org.springframework.ai.chat.messages.AssistantMessage; @@ -88,7 +88,13 @@ import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackResolver; import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.LegacyToolCallingManager; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionResult; import org.springframework.ai.observation.conventions.AiProvider; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -115,12 +121,16 @@ */ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements ChatModel { + private static final Logger logger = LoggerFactory.getLogger(AzureOpenAiChatModel.class); + private static final String DEFAULT_DEPLOYMENT_NAME = "gpt-4o"; private static final Double DEFAULT_TEMPERATURE = 0.7; private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); + private static final ToolCallingManager DEFAULT_TOOL_CALLING_MANAGER = ToolCallingManager.builder().build(); + /** * The {@link OpenAIClient} used to interact with the Azure OpenAI service. */ @@ -146,6 +156,12 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements Cha */ private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; + /** + * ToolCalling manager used for ToolCalling support. + */ + private final ToolCallingManager toolCallingManager; + + @Deprecated public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder) { this(openAIClientBuilder, AzureOpenAiChatOptions.builder() @@ -154,29 +170,52 @@ public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder) { .build()); } + @Deprecated public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options) { this(openAIClientBuilder, options, null); } + @Deprecated public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options, FunctionCallbackResolver functionCallbackResolver) { this(openAIClientBuilder, options, functionCallbackResolver, List.of()); } + @Deprecated public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options, - FunctionCallbackResolver functionCallbackResolver, List toolFunctionCallbacks) { + @Nullable FunctionCallbackResolver functionCallbackResolver, + @Nullable List toolFunctionCallbacks) { this(openAIClientBuilder, options, functionCallbackResolver, toolFunctionCallbacks, ObservationRegistry.NOOP); } + @Deprecated public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options, - FunctionCallbackResolver functionCallbackResolver, List toolFunctionCallbacks, - ObservationRegistry observationRegistry) { - super(functionCallbackResolver, options, toolFunctionCallbacks); + @Nullable FunctionCallbackResolver functionCallbackResolver, + @Nullable List toolFunctionCallbacks, ObservationRegistry observationRegistry) { + this(openAIClientBuilder, options, + LegacyToolCallingManager.builder() + .functionCallbackResolver(functionCallbackResolver) + .functionCallbacks(toolFunctionCallbacks) + .build(), + observationRegistry); + logger.warn("This constructor is deprecated and will be removed in the next milestone. " + + "Please use the AzureOpenAiChatModel.Builder or the new constructor accepting ToolCallingManager instead."); + } + + public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions defaultOptions, + ToolCallingManager toolCallingManager, ObservationRegistry observationRegistry) { + // We do not pass the 'defaultOptions' to the AbstractToolSupport, + // because it modifies them. We are using ToolCallingManager instead, + // so we just pass empty options here. + super(null, AzureOpenAiChatOptions.builder().build(), List.of()); Assert.notNull(openAIClientBuilder, "com.azure.ai.openai.OpenAIClient must not be null"); - Assert.notNull(options, "AzureOpenAiChatOptions must not be null"); + Assert.notNull(defaultOptions, "defaultOptions cannot be null"); + Assert.notNull(toolCallingManager, "toolCallingManager cannot be null"); + Assert.notNull(observationRegistry, "observationRegistry cannot be null"); this.openAIClient = openAIClientBuilder.buildClient(); this.openAIAsyncClient = openAIClientBuilder.buildAsyncClient(); - this.defaultOptions = options; + this.defaultOptions = defaultOptions; + this.toolCallingManager = toolCallingManager; this.observationRegistry = observationRegistry; } @@ -228,7 +267,10 @@ public AzureOpenAiChatOptions getDefaultOptions() { @Override public ChatResponse call(Prompt prompt) { - return this.internalCall(prompt, null); + // Before moving any further, build the final request Prompt, + // merging runtime and default options. + Prompt requestPrompt = buildRequestPrompt(prompt); + return this.internalCall(requestPrompt, null); } public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) { @@ -252,12 +294,21 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons return chatResponse; }); - if (!isProxyToolCalls(prompt, this.defaultOptions) - && isToolCall(response, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) { - var toolCallConversation = handleToolCalls(prompt, response); - // Recursively call the call method with the tool call message - // conversation that contains the call responses. - return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), response); + if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && response != null + && response.hasToolCalls()) { + var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); + if (toolExecutionResult.returnDirect()) { + // Return tool execution result directly to the client. + return ChatResponse.builder() + .from(response) + .generations(ToolExecutionResult.buildGenerations(toolExecutionResult)) + .build(); + } + else { + // Send the tool execution result back to the model. + return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), + response); + } } return response; @@ -265,7 +316,10 @@ && isToolCall(response, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS @Override public Flux stream(Prompt prompt) { - return this.internalStream(prompt, null); + // Before moving any further, build the final request Prompt, + // merging runtime and default options. + Prompt requestPrompt = buildRequestPrompt(prompt); + return this.internalStream(requestPrompt, null); } public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse) { @@ -344,12 +398,22 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha }); return chatResponseFlux.flatMap(chatResponse -> { - if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(chatResponse, - Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) { - var toolCallConversation = handleToolCalls(prompt, chatResponse); - // Recursively call the call method with the tool call message - // conversation that contains the call responses. - return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), chatResponse); + if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) + && chatResponse.hasToolCalls()) { + var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse); + if (toolExecutionResult.returnDirect()) { + // Return tool execution result directly to the client. + return Flux.just(ChatResponse.builder() + .from(chatResponse) + .generations(ToolExecutionResult.buildGenerations(toolExecutionResult)) + .build()); + } + else { + // Send the tool execution result back to the model. + return this.internalStream( + new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), + chatResponse); + } } Flux flux = Flux.just(chatResponse) @@ -447,7 +511,7 @@ private Generation buildGeneration(ChatChoice choice, Map metada */ ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) { - Set functionsForThisRequest = new HashSet<>(); + List functionsForThisRequest = new ArrayList<>(); List azureMessages = prompt.getInstructions() .stream() @@ -459,27 +523,27 @@ ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) { options = this.merge(options, this.defaultOptions); - if (!CollectionUtils.isEmpty(this.defaultOptions.getFunctions())) { - functionsForThisRequest.addAll(this.defaultOptions.getFunctions()); - } - if (prompt.getOptions() != null) { AzureOpenAiChatOptions updatedRuntimeOptions; if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) { updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions, FunctionCallingOptions.class, AzureOpenAiChatOptions.class); } + if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) { + updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, + ToolCallingChatOptions.class, AzureOpenAiChatOptions.class); + } else { updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class, AzureOpenAiChatOptions.class); } options = this.merge(updatedRuntimeOptions, options); - functionsForThisRequest.addAll(this.runtimeFunctionCallbackConfigurations(updatedRuntimeOptions)); + // Add the tool definitions to the request's tools parameter. + functionsForThisRequest.addAll(this.toolCallingManager.resolveToolDefinitions(updatedRuntimeOptions)); } // Add the enabled functions definitions to the request's tools parameter. - if (!CollectionUtils.isEmpty(functionsForThisRequest)) { List tools = this.getFunctionTools(functionsForThisRequest); List tools2 = tools.stream() @@ -491,14 +555,12 @@ ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) { return options; } - private List getFunctionTools(Set functionNames) { - return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> { - + private List getFunctionTools(List toolDefinitions) { + return toolDefinitions.stream().map(toolDefinition -> { ChatCompletionsFunctionToolDefinitionFunction functionDefinition = new ChatCompletionsFunctionToolDefinitionFunction( - functionCallback.getName()); - functionDefinition.setDescription(functionCallback.getDescription()); - BinaryData parameters = BinaryData - .fromObject(ModelOptionsUtils.jsonToMap(functionCallback.getInputTypeSchema())); + toolDefinition.name()); + functionDefinition.setDescription(toolDefinition.description()); + BinaryData parameters = BinaryData.fromObject(ModelOptionsUtils.jsonToMap(toolDefinition.inputSchema())); functionDefinition.setParameters(parameters); return new ChatCompletionsFunctionToolDefinition(functionDefinition); }).toList(); @@ -589,6 +651,53 @@ private List nullSafeList(List list) { return list != null ? list : Collections.emptyList(); } + Prompt buildRequestPrompt(Prompt prompt) { + // Process runtime options + AzureOpenAiChatOptions runtimeOptions = null; + if (prompt.getOptions() != null) { + if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) { + runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class, + AzureOpenAiChatOptions.class); + } + else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) { + runtimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions, FunctionCallingOptions.class, + AzureOpenAiChatOptions.class); + } + else { + runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class, + AzureOpenAiChatOptions.class); + } + } + + // Define request options by merging runtime options and default options + AzureOpenAiChatOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, + AzureOpenAiChatOptions.class); + + // Merge @JsonIgnore-annotated options explicitly since they are ignored by + // Jackson, used by ModelOptionsUtils. + if (runtimeOptions != null) { + requestOptions.setInternalToolExecutionEnabled( + ModelOptionsUtils.mergeOption(runtimeOptions.isInternalToolExecutionEnabled(), + this.defaultOptions.isInternalToolExecutionEnabled())); + requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(), + this.defaultOptions.getToolNames())); + requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(), + this.defaultOptions.getToolCallbacks())); + requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(), + this.defaultOptions.getToolContext())); + } + else { + requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.isInternalToolExecutionEnabled()); + requestOptions.setToolNames(this.defaultOptions.getToolNames()); + requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks()); + requestOptions.setToolContext(this.defaultOptions.getToolContext()); + } + + ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks()); + + return new Prompt(prompt.getInstructions(), requestOptions); + } + /** * Merges the Azure's {@link ChatCompletionsOptions} (fromAzureOptions) into the * Spring AI's {@link AzureOpenAiChatOptions} (toSpringAiOptions) and return a new @@ -842,4 +951,94 @@ public void setObservationConvention(ChatModelObservationConvention observationC this.observationConvention = observationConvention; } + public static Builder builder() { + return new Builder(); + } + + /** + * Builder to construct {@link AzureOpenAiChatModel}. + */ + public static class Builder { + + private OpenAIClientBuilder openAIClientBuilder; + + private AzureOpenAiChatOptions defaultOptions = AzureOpenAiChatOptions.builder() + .deploymentName(DEFAULT_DEPLOYMENT_NAME) + .temperature(DEFAULT_TEMPERATURE) + .build(); + + private ToolCallingManager toolCallingManager; + + private FunctionCallbackResolver functionCallbackResolver; + + private List toolFunctionCallbacks; + + private ObservationRegistry observationRegistry = ObservationRegistry.NOOP; + + private Builder() { + } + + public Builder openAIClientBuilder(OpenAIClientBuilder openAIClientBuilder) { + this.openAIClientBuilder = openAIClientBuilder; + return this; + } + + public Builder defaultOptions(AzureOpenAiChatOptions defaultOptions) { + this.defaultOptions = defaultOptions; + return this; + } + + public Builder toolCallingManager(ToolCallingManager toolCallingManager) { + this.toolCallingManager = toolCallingManager; + return this; + } + + @Deprecated + public Builder functionCallbackResolver(FunctionCallbackResolver functionCallbackResolver) { + this.functionCallbackResolver = functionCallbackResolver; + return this; + } + + @Deprecated + public Builder toolFunctionCallbacks(List toolFunctionCallbacks) { + this.toolFunctionCallbacks = toolFunctionCallbacks; + return this; + } + + public Builder observationRegistry(ObservationRegistry observationRegistry) { + this.observationRegistry = observationRegistry; + return this; + } + + public AzureOpenAiChatModel build() { + if (toolCallingManager != null) { + Assert.isNull(functionCallbackResolver, + "functionCallbackResolver cannot be set when toolCallingManager is set"); + Assert.isNull(toolFunctionCallbacks, + "toolFunctionCallbacks cannot be set when toolCallingManager is set"); + + return new AzureOpenAiChatModel(openAIClientBuilder, defaultOptions, toolCallingManager, + observationRegistry); + } + + if (functionCallbackResolver != null) { + Assert.isNull(toolCallingManager, + "toolCallingManager cannot be set when functionCallbackResolver is set"); + List toolCallbacks = this.toolFunctionCallbacks != null ? this.toolFunctionCallbacks + : List.of(); + + return new Builder().openAIClientBuilder(openAIClientBuilder) + .defaultOptions(defaultOptions) + .functionCallbackResolver(functionCallbackResolver) + .toolFunctionCallbacks(toolCallbacks) + .observationRegistry(observationRegistry) + .build(); + } + + return new AzureOpenAiChatModel(openAIClientBuilder, defaultOptions, DEFAULT_TOOL_CALLING_MANAGER, + observationRegistry); + } + + } + } diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java index d34ca5dfbd1..887c0ad6e74 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java @@ -17,6 +17,8 @@ package org.springframework.ai.azure.openai; import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -30,7 +32,9 @@ import com.fasterxml.jackson.annotation.JsonProperty; import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; /** @@ -44,7 +48,7 @@ * @author Ilayaperumal Gopinathan */ @JsonInclude(Include.NON_NULL) -public class AzureOpenAiChatOptions implements FunctionCallingOptions { +public class AzureOpenAiChatOptions implements ToolCallingChatOptions { /** * The maximum number of tokens to generate. @@ -138,33 +142,6 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions { @JsonProperty("response_format") private AzureOpenAiResponseFormat responseFormat; - /** - * OpenAI Tool Function Callbacks to register with the ChatModel. For Prompt Options - * the functionCallbacks are automatically enabled for the duration of the prompt - * execution. For Default Options the functionCallbacks are registered but disabled by - * default. Use the enableFunctions to set the functions from the registry to be used - * by the ChatModel chat completion requests. - */ - @JsonIgnore - private List functionCallbacks = new ArrayList<>(); - - /** - * List of functions, identified by their names, to configure for function calling in - * the chat completion requests. Functions with those names must exist in the - * functionCallbacks registry. The {@link #functionCallbacks} from the PromptOptions - * are automatically enabled for the duration of the prompt execution. - * - * Note that function enabled with the default options are enabled for all chat - * completion requests. This could impact the token count and the billing. If the - * functions is set in a prompt options, then the enabled functions are only active - * for the duration of this prompt execution. - */ - @JsonIgnore - private Set functions = new HashSet<>(); - - @JsonIgnore - private Boolean proxyToolCalls; - /** * Seed value for deterministic sampling such that the same seed and parameters return * the same result. @@ -199,7 +176,68 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions { private ChatCompletionStreamOptions streamOptions; @JsonIgnore - private Map toolContext; + private Map toolContext = new HashMap<>(); + + /** + * Collection of {@link ToolCallback}s to be used for tool calling in the chat + * completion requests. + */ + @JsonIgnore + private List toolCallbacks = new ArrayList<>(); + + /** + * Collection of tool names to be resolved at runtime and used for tool calling in the + * chat completion requests. + */ + @JsonIgnore + private Set toolNames = new HashSet<>(); + + /** + * Whether to enable the tool execution lifecycle internally in ChatModel. + */ + @JsonIgnore + private Boolean internalToolExecutionEnabled; + + @Override + @JsonIgnore + public List getToolCallbacks() { + return this.toolCallbacks; + } + + @Override + @JsonIgnore + public void setToolCallbacks(List toolCallbacks) { + Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); + Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements"); + this.toolCallbacks = toolCallbacks; + } + + @Override + @JsonIgnore + public Set getToolNames() { + return this.toolNames; + } + + @Override + @JsonIgnore + public void setToolNames(Set toolNames) { + Assert.notNull(toolNames, "toolNames cannot be null"); + Assert.noNullElements(toolNames, "toolNames cannot contain null elements"); + this.toolNames = toolNames; + } + + @Override + @Nullable + @JsonIgnore + public Boolean isInternalToolExecutionEnabled() { + return internalToolExecutionEnabled; + } + + @Override + @JsonIgnore + public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) { + this.internalToolExecutionEnabled = internalToolExecutionEnabled; + } public static Builder builder() { return new Builder(); @@ -224,7 +262,10 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti .topLogprobs(fromOptions.getTopLogProbs()) .enhancements(fromOptions.getEnhancements()) .toolContext(fromOptions.getToolContext()) + .internalToolExecutionEnabled(fromOptions.isInternalToolExecutionEnabled()) .streamOptions(fromOptions.getStreamOptions()) + .toolCallbacks(fromOptions.getToolCallbacks()) + .toolNames(fromOptions.getToolNames()) .build(); } @@ -336,21 +377,28 @@ public void setTopP(Double topP) { } @Override + @Deprecated + @JsonIgnore public List getFunctionCallbacks() { - return this.functionCallbacks; + return this.getToolCallbacks(); } + @Override + @Deprecated + @JsonIgnore public void setFunctionCallbacks(List functionCallbacks) { - this.functionCallbacks = functionCallbacks; + this.setToolCallbacks(functionCallbacks); } @Override + @Deprecated + @JsonIgnore public Set getFunctions() { - return this.functions; + return this.getToolNames(); } public void setFunctions(Set functions) { - this.functions = functions; + this.setToolNames(functions); } public AzureOpenAiResponseFormat getResponseFormat() { @@ -400,12 +448,16 @@ public void setEnhancements(AzureChatEnhancementConfiguration enhancements) { } @Override + @Deprecated + @JsonIgnore public Boolean getProxyToolCalls() { - return this.proxyToolCalls; + return this.internalToolExecutionEnabled != null ? !this.internalToolExecutionEnabled : null; } + @Deprecated + @JsonIgnore public void setProxyToolCalls(Boolean proxyToolCalls) { - this.proxyToolCalls = proxyToolCalls; + this.internalToolExecutionEnabled = proxyToolCalls != null ? !proxyToolCalls : null; } @Override @@ -493,21 +545,19 @@ public Builder user(String user) { return this; } + @Deprecated public Builder functionCallbacks(List functionCallbacks) { - this.options.functionCallbacks = functionCallbacks; - return this; + return toolCallbacks(functionCallbacks); } + @Deprecated public Builder functions(Set functionNames) { - Assert.notNull(functionNames, "Function names must not be null"); - this.options.functions = functionNames; - return this; + return toolNames(functionNames); } + @Deprecated public Builder function(String functionName) { - Assert.hasText(functionName, "Function name must not be empty"); - this.options.functions.add(functionName); - return this; + return toolNames(functionName); } public Builder responseFormat(AzureOpenAiResponseFormat responseFormat) { @@ -515,8 +565,11 @@ public Builder responseFormat(AzureOpenAiResponseFormat responseFormat) { return this; } + @Deprecated public Builder proxyToolCalls(Boolean proxyToolCalls) { - this.options.proxyToolCalls = proxyToolCalls; + if (proxyToolCalls != null) { + this.options.setInternalToolExecutionEnabled(!proxyToolCalls); + } return this; } @@ -555,6 +608,34 @@ public Builder streamOptions(ChatCompletionStreamOptions streamOptions) { return this; } + public Builder toolCallbacks(List toolCallbacks) { + this.options.setToolCallbacks(toolCallbacks); + return this; + } + + public Builder toolCallbacks(FunctionCallback... toolCallbacks) { + Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); + this.options.toolCallbacks.addAll(Arrays.asList(toolCallbacks)); + return this; + } + + public Builder toolNames(Set toolNames) { + Assert.notNull(toolNames, "toolNames cannot be null"); + this.options.setToolNames(toolNames); + return this; + } + + public Builder toolNames(String... toolNames) { + Assert.notNull(toolNames, "toolNames cannot be null"); + this.options.toolNames.addAll(Set.of(toolNames)); + return this; + } + + public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) { + this.options.setInternalToolExecutionEnabled(internalToolExecutionEnabled); + return this; + } + public AzureOpenAiChatOptions build() { return this.options; } diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java index e96544a1168..46dcf5547d4 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java @@ -71,7 +71,10 @@ public void createRequestWithChatOptions() { .responseFormat(AzureOpenAiResponseFormat.TEXT) .build(); - var client = new AzureOpenAiChatModel(mockClient, defaultOptions); + var client = AzureOpenAiChatModel.builder() + .openAIClientBuilder(mockClient) + .defaultOptions(defaultOptions) + .build(); var requestOptions = client.toAzureChatCompletionsOptions(new Prompt("Test message content")); diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java index 2d2b644cdd9..189a51b235c 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java @@ -161,9 +161,10 @@ public OpenAIClientBuilder openAIClient() { @Bean public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder) { - return new AzureOpenAiChatModel(openAIClientBuilder, - AzureOpenAiChatOptions.builder().deploymentName("gpt-4o").maxTokens(1000).build()); - + return AzureOpenAiChatModel.builder() + .openAIClientBuilder(openAIClientBuilder) + .defaultOptions(AzureOpenAiChatOptions.builder().deploymentName("gpt-4o").maxTokens(1000).build()) + .build(); } @Bean diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java index 34c1bb26b53..e6e17c26f0f 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java @@ -269,9 +269,10 @@ public OpenAIClientBuilder openAIClientBuilder() { @Bean public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder) { - return new AzureOpenAiChatModel(openAIClientBuilder, - AzureOpenAiChatOptions.builder().deploymentName("gpt-4o").maxTokens(1000).build()); - + return AzureOpenAiChatModel.builder() + .openAIClientBuilder(openAIClientBuilder) + .defaultOptions(AzureOpenAiChatOptions.builder().deploymentName("gpt-4o").maxTokens(1000).build()) + .build(); } } diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelObservationIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelObservationIT.java index 9310c746dc6..fbe010eb9ee 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelObservationIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelObservationIT.java @@ -194,9 +194,11 @@ public OpenAIClientBuilder openAIClient() { @Bean public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, TestObservationRegistry observationRegistry) { - return new AzureOpenAiChatModel(openAIClientBuilder, - AzureOpenAiChatOptions.builder().deploymentName("gpt-4o").maxTokens(1000).build(), null, List.of(), - observationRegistry); + return AzureOpenAiChatModel.builder() + .openAIClientBuilder(openAIClientBuilder) + .defaultOptions(AzureOpenAiChatOptions.builder().deploymentName("gpt-4o").maxTokens(1000).build()) + .observationRegistry(observationRegistry) + .build(); } } diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelTests.java deleted file mode 100644 index cb98213b331..00000000000 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelTests.java +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Copyright 2023-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.azure.openai; - -import java.util.List; - -import com.azure.ai.openai.OpenAIClientBuilder; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Mock; -import org.mockito.junit.jupiter.MockitoExtension; - -import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallbackResolver; - -/** - * @author Jihoon Kim - */ -@ExtendWith(MockitoExtension.class) -public class AzureOpenAiChatModelTests { - - @Mock - OpenAIClientBuilder mockClient; - - @Mock - FunctionCallbackResolver functionCallbackResolver; - - @Test - public void createAzureOpenAiChatModelTest() { - String callbackFromChatOptions = "callbackFromChatOptions"; - String callbackFromConstructorParam = "callbackFromConstructorParam"; - - AzureOpenAiChatOptions chatOptions = AzureOpenAiChatOptions.builder() - .functionCallbacks(List.of(new TestFunctionCallback(callbackFromChatOptions))) - .build(); - - List functionCallbacks = List.of(new TestFunctionCallback(callbackFromConstructorParam)); - - AzureOpenAiChatModel openAiChatModel = new AzureOpenAiChatModel(this.mockClient, chatOptions, - this.functionCallbackResolver, functionCallbacks); - - assert 2 == openAiChatModel.getFunctionCallbackRegister().size(); - - assert callbackFromChatOptions == openAiChatModel.getFunctionCallbackRegister() - .get(callbackFromChatOptions) - .getName(); - - assert callbackFromConstructorParam == openAiChatModel.getFunctionCallbackRegister() - .get(callbackFromConstructorParam) - .getName(); - } - - private class TestFunctionCallback implements FunctionCallback { - - private final String name; - - TestFunctionCallback(String name) { - this.name = name; - } - - @Override - public String getName() { - return this.name; - } - - @Override - public String getDescription() { - return null; - } - - @Override - public String getInputTypeSchema() { - return null; - } - - @Override - public String call(String functionInput) { - return null; - } - - } - -} diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAzureOpenAiTestConfiguration.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAzureOpenAiTestConfiguration.java index 1c0a84cade3..f70bef03dc1 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAzureOpenAiTestConfiguration.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAzureOpenAiTestConfiguration.java @@ -59,7 +59,7 @@ OpenAIClientBuilder microsoftAzureOpenAiClient(MockWebServer webServer) { @Bean AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClientBuilder microsoftAzureOpenAiClient) { - return new AzureOpenAiChatModel(microsoftAzureOpenAiClient); + return AzureOpenAiChatModel.builder().openAIClientBuilder(microsoftAzureOpenAiClient).build(); } } diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java index e88573d4fc5..f5c7cfc6330 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java @@ -228,8 +228,10 @@ public OpenAIClientBuilder openAIClient() { @Bean public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClientBuilder openAIClient, String selectedModel) { - return new AzureOpenAiChatModel(openAIClient, - AzureOpenAiChatOptions.builder().deploymentName(selectedModel).maxTokens(500).build()); + return AzureOpenAiChatModel.builder() + .openAIClientBuilder(openAIClient) + .defaultOptions(AzureOpenAiChatOptions.builder().deploymentName(selectedModel).maxTokens(500).build()) + .build(); } @Bean diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/azure-open-ai-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/azure-open-ai-chat-functions.adoc index d36337fa02a..882ec73e349 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/azure-open-ai-chat-functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/azure-open-ai-chat-functions.adoc @@ -1,4 +1,6 @@ -= Azure OpenAI Function Calling += Azure OpenAI Function Calling (Deprecated) + +WARNING: This page describes the previous version of the Function Calling API, which has been deprecated and marked for remove in the next release. The current version is available at xref:api/tools.adoc[Tool Calling]. See the xref:api/tools-migration.adoc[Migration Guide] for more information. Function calling lets developers create a description of a function in their code, then pass that description to a language model in a request. The response from the model includes the name of a function that matches the description and the arguments to call it with. diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java index 08d038c9d2b..ae3b77efcc7 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java @@ -28,6 +28,7 @@ import com.azure.core.util.Header; import io.micrometer.observation.ObservationRegistry; +import org.springframework.ai.autoconfigure.chat.model.ToolCallingAutoConfiguration; import org.springframework.ai.azure.openai.AzureOpenAiAudioTranscriptionModel; import org.springframework.ai.azure.openai.AzureOpenAiChatModel; import org.springframework.ai.azure.openai.AzureOpenAiEmbeddingModel; @@ -35,10 +36,11 @@ import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; import org.springframework.ai.model.function.DefaultFunctionCallbackResolver; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackResolver; +import org.springframework.ai.model.tool.ToolCallingManager; import org.springframework.beans.factory.ObjectProvider; import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.ImportAutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; @@ -55,12 +57,14 @@ * @author Piotr Olaszewski * @author Soby Chacko * @author Manuel Andreo Garcia + * @author Ilayaperumal Gopinathan */ -@AutoConfiguration +@AutoConfiguration(after = { ToolCallingAutoConfiguration.class }) @ConditionalOnClass({ OpenAIClientBuilder.class, AzureOpenAiChatModel.class }) @EnableConfigurationProperties({ AzureOpenAiChatProperties.class, AzureOpenAiEmbeddingProperties.class, AzureOpenAiConnectionProperties.class, AzureOpenAiImageOptionsProperties.class, AzureOpenAiAudioTranscriptionProperties.class }) +@ImportAutoConfiguration(classes = { ToolCallingAutoConfiguration.class }) public class AzureOpenAiAutoConfiguration { private static final String APPLICATION_ID = "spring-ai"; @@ -121,13 +125,16 @@ public OpenAIClientBuilder openAIClientWithTokenCredential(AzureOpenAiConnection @ConditionalOnProperty(prefix = AzureOpenAiChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", matchIfMissing = true) public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, - AzureOpenAiChatProperties chatProperties, List toolFunctionCallbacks, - FunctionCallbackResolver functionCallbackResolver, ObjectProvider observationRegistry, + AzureOpenAiChatProperties chatProperties, ToolCallingManager toolCallingManager, + ObjectProvider observationRegistry, ObjectProvider observationConvention) { - var chatModel = new AzureOpenAiChatModel(openAIClientBuilder, chatProperties.getOptions(), - functionCallbackResolver, toolFunctionCallbacks, - observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)); + var chatModel = AzureOpenAiChatModel.builder() + .openAIClientBuilder(openAIClientBuilder) + .defaultOptions(chatProperties.getOptions()) + .toolCallingManager(toolCallingManager) + .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) + .build(); observationConvention.ifAvailable(chatModel::setObservationConvention); return chatModel;