Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package org.springframework.ai.azure.openai;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
Expand All @@ -30,7 +32,9 @@
import com.fasterxml.jackson.annotation.JsonProperty;

import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

/**
Expand All @@ -44,7 +48,7 @@
* @author Ilayaperumal Gopinathan
*/
@JsonInclude(Include.NON_NULL)
public class AzureOpenAiChatOptions implements FunctionCallingOptions {
public class AzureOpenAiChatOptions implements ToolCallingChatOptions {

/**
* The maximum number of tokens to generate.
Expand Down Expand Up @@ -138,33 +142,6 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions {
@JsonProperty("response_format")
private AzureOpenAiResponseFormat responseFormat;

/**
* OpenAI Tool Function Callbacks to register with the ChatModel. For Prompt Options
* the functionCallbacks are automatically enabled for the duration of the prompt
* execution. For Default Options the functionCallbacks are registered but disabled by
* default. Use the enableFunctions to set the functions from the registry to be used
* by the ChatModel chat completion requests.
*/
@JsonIgnore
private List<FunctionCallback> functionCallbacks = new ArrayList<>();

/**
* List of functions, identified by their names, to configure for function calling in
* the chat completion requests. Functions with those names must exist in the
* functionCallbacks registry. The {@link #functionCallbacks} from the PromptOptions
* are automatically enabled for the duration of the prompt execution.
*
* Note that function enabled with the default options are enabled for all chat
* completion requests. This could impact the token count and the billing. If the
* functions is set in a prompt options, then the enabled functions are only active
* for the duration of this prompt execution.
*/
@JsonIgnore
private Set<String> functions = new HashSet<>();

@JsonIgnore
private Boolean proxyToolCalls;

/**
* Seed value for deterministic sampling such that the same seed and parameters return
* the same result.
Expand Down Expand Up @@ -199,7 +176,68 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions {
private ChatCompletionStreamOptions streamOptions;

@JsonIgnore
private Map<String, Object> toolContext;
private Map<String, Object> toolContext = new HashMap<>();

/**
* Collection of {@link ToolCallback}s to be used for tool calling in the chat
* completion requests.
*/
@JsonIgnore
private List<FunctionCallback> toolCallbacks = new ArrayList<>();

/**
* Collection of tool names to be resolved at runtime and used for tool calling in the
* chat completion requests.
*/
@JsonIgnore
private Set<String> toolNames = new HashSet<>();

/**
* Whether to enable the tool execution lifecycle internally in ChatModel.
*/
@JsonIgnore
private Boolean internalToolExecutionEnabled;

@Override
@JsonIgnore
public List<FunctionCallback> getToolCallbacks() {
return this.toolCallbacks;
}

@Override
@JsonIgnore
public void setToolCallbacks(List<FunctionCallback> toolCallbacks) {
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
this.toolCallbacks = toolCallbacks;
}

@Override
@JsonIgnore
public Set<String> getToolNames() {
return this.toolNames;
}

@Override
@JsonIgnore
public void setToolNames(Set<String> toolNames) {
Assert.notNull(toolNames, "toolNames cannot be null");
Assert.noNullElements(toolNames, "toolNames cannot contain null elements");
this.toolNames = toolNames;
}

@Override
@Nullable
@JsonIgnore
public Boolean isInternalToolExecutionEnabled() {
return internalToolExecutionEnabled;
}

@Override
@JsonIgnore
public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) {
this.internalToolExecutionEnabled = internalToolExecutionEnabled;
}

public static Builder builder() {
return new Builder();
Expand All @@ -224,7 +262,10 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti
.topLogprobs(fromOptions.getTopLogProbs())
.enhancements(fromOptions.getEnhancements())
.toolContext(fromOptions.getToolContext())
.internalToolExecutionEnabled(fromOptions.isInternalToolExecutionEnabled())
.streamOptions(fromOptions.getStreamOptions())
.toolCallbacks(fromOptions.getToolCallbacks())
.toolNames(fromOptions.getToolNames())
.build();
}

Expand Down Expand Up @@ -336,21 +377,28 @@ public void setTopP(Double topP) {
}

@Override
@Deprecated
@JsonIgnore
public List<FunctionCallback> getFunctionCallbacks() {
return this.functionCallbacks;
return this.getToolCallbacks();
}

@Override
@Deprecated
@JsonIgnore
public void setFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
this.functionCallbacks = functionCallbacks;
this.setToolCallbacks(functionCallbacks);
}

@Override
@Deprecated
@JsonIgnore
public Set<String> getFunctions() {
return this.functions;
return this.getToolNames();
}

public void setFunctions(Set<String> functions) {
this.functions = functions;
this.setToolNames(functions);
}

public AzureOpenAiResponseFormat getResponseFormat() {
Expand Down Expand Up @@ -400,12 +448,16 @@ public void setEnhancements(AzureChatEnhancementConfiguration enhancements) {
}

@Override
@Deprecated
@JsonIgnore
public Boolean getProxyToolCalls() {
return this.proxyToolCalls;
return this.internalToolExecutionEnabled != null ? !this.internalToolExecutionEnabled : null;
}

@Deprecated
@JsonIgnore
public void setProxyToolCalls(Boolean proxyToolCalls) {
this.proxyToolCalls = proxyToolCalls;
this.internalToolExecutionEnabled = proxyToolCalls != null ? !proxyToolCalls : null;
}

@Override
Expand Down Expand Up @@ -493,30 +545,31 @@ public Builder user(String user) {
return this;
}

@Deprecated
public Builder functionCallbacks(List<FunctionCallback> functionCallbacks) {
this.options.functionCallbacks = functionCallbacks;
return this;
return toolCallbacks(functionCallbacks);
}

@Deprecated
public Builder functions(Set<String> functionNames) {
Assert.notNull(functionNames, "Function names must not be null");
this.options.functions = functionNames;
return this;
return toolNames(functionNames);
}

@Deprecated
public Builder function(String functionName) {
Assert.hasText(functionName, "Function name must not be empty");
this.options.functions.add(functionName);
return this;
return toolNames(functionName);
}

public Builder responseFormat(AzureOpenAiResponseFormat responseFormat) {
this.options.responseFormat = responseFormat;
return this;
}

@Deprecated
public Builder proxyToolCalls(Boolean proxyToolCalls) {
this.options.proxyToolCalls = proxyToolCalls;
if (proxyToolCalls != null) {
this.options.setInternalToolExecutionEnabled(!proxyToolCalls);
}
return this;
}

Expand Down Expand Up @@ -555,6 +608,34 @@ public Builder streamOptions(ChatCompletionStreamOptions streamOptions) {
return this;
}

public Builder toolCallbacks(List<FunctionCallback> toolCallbacks) {
this.options.setToolCallbacks(toolCallbacks);
return this;
}

public Builder toolCallbacks(FunctionCallback... toolCallbacks) {
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
this.options.toolCallbacks.addAll(Arrays.asList(toolCallbacks));
return this;
}

public Builder toolNames(Set<String> toolNames) {
Assert.notNull(toolNames, "toolNames cannot be null");
this.options.setToolNames(toolNames);
return this;
}

public Builder toolNames(String... toolNames) {
Assert.notNull(toolNames, "toolNames cannot be null");
this.options.toolNames.addAll(Set.of(toolNames));
return this;
}

public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) {
this.options.setInternalToolExecutionEnabled(internalToolExecutionEnabled);
return this;
}

public AzureOpenAiChatOptions build() {
return this.options;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@ public void createRequestWithChatOptions() {
.responseFormat(AzureOpenAiResponseFormat.TEXT)
.build();

var client = new AzureOpenAiChatModel(mockClient, defaultOptions);
var client = AzureOpenAiChatModel.builder()
.openAIClientBuilder(mockClient)
.defaultOptions(defaultOptions)
.build();

var requestOptions = client.toAzureChatCompletionsOptions(new Prompt("Test message content"));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,10 @@ public OpenAIClientBuilder openAIClient() {

@Bean
public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder) {
return new AzureOpenAiChatModel(openAIClientBuilder,
AzureOpenAiChatOptions.builder().deploymentName("gpt-4o").maxTokens(1000).build());

return AzureOpenAiChatModel.builder()
.openAIClientBuilder(openAIClientBuilder)
.defaultOptions(AzureOpenAiChatOptions.builder().deploymentName("gpt-4o").maxTokens(1000).build())
.build();
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -269,9 +269,10 @@ public OpenAIClientBuilder openAIClientBuilder() {

@Bean
public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder) {
return new AzureOpenAiChatModel(openAIClientBuilder,
AzureOpenAiChatOptions.builder().deploymentName("gpt-4o").maxTokens(1000).build());

return AzureOpenAiChatModel.builder()
.openAIClientBuilder(openAIClientBuilder)
.defaultOptions(AzureOpenAiChatOptions.builder().deploymentName("gpt-4o").maxTokens(1000).build())
.build();
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,11 @@ public OpenAIClientBuilder openAIClient() {
@Bean
public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder,
TestObservationRegistry observationRegistry) {
return new AzureOpenAiChatModel(openAIClientBuilder,
AzureOpenAiChatOptions.builder().deploymentName("gpt-4o").maxTokens(1000).build(), null, List.of(),
observationRegistry);
return AzureOpenAiChatModel.builder()
.openAIClientBuilder(openAIClientBuilder)
.defaultOptions(AzureOpenAiChatOptions.builder().deploymentName("gpt-4o").maxTokens(1000).build())
.observationRegistry(observationRegistry)
.build();
}

}
Expand Down
Loading