Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@
package org.springframework.ai.openai;

import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import org.slf4j.Logger;
Expand All @@ -33,9 +37,12 @@
import org.springframework.ai.chat.metadata.RateLimit;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.ToolFunctionCallback;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.Role;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ToolCall;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
import org.springframework.ai.openai.api.OpenAiApi.OpenAiApiException;
import org.springframework.ai.openai.metadata.OpenAiChatResponseMetadata;
Expand All @@ -46,6 +53,7 @@
import org.springframework.retry.RetryListener;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

/**
* {@link ChatClient} implementation for {@literal OpenAI} backed by {@link OpenAiApi}.
Expand All @@ -66,6 +74,8 @@ public class OpenAiChatClient implements ChatClient, StreamingChatClient {

private OpenAiChatOptions defaultOptions;

private Map<String, ToolFunctionCallback> toolCallbackRegister = new ConcurrentHashMap<>();

public final RetryTemplate retryTemplate = RetryTemplate.builder()
.maxAttempts(10)
.retryOn(OpenAiApiException.class)
Expand Down Expand Up @@ -108,18 +118,18 @@ public ChatResponse call(Prompt prompt) {

ChatCompletionRequest request = createRequest(prompt, false);

ResponseEntity<ChatCompletion> completionEntity = this.openAiApi.chatCompletionEntity(request);
ResponseEntity<ChatCompletion> completionEntity = this.chatCompletionWithTools(request);

var chatCompletion = completionEntity.getBody();
if (chatCompletion == null) {
logger.warn("No chat completion returned for request: {}", prompt);
logger.warn("No chat completion returned for prompt: {}", prompt);
return new ChatResponse(List.of());
}

RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(completionEntity);

List<Generation> generations = chatCompletion.choices().stream().map(choice -> {
return new Generation(choice.message().content(), Map.of("role", choice.message().role().name()))
return new Generation(choice.message().content(), toMap(choice.message()))
.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null));
}).toList();

Expand Down Expand Up @@ -162,6 +172,8 @@ public Flux<ChatResponse> stream(Prompt prompt) {
*/
ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {

Set<String> enabledFunctionsForRequest = new HashSet<>();

List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions()
.stream()
.map(m -> new ChatCompletionMessage(m.getContent(),
Expand All @@ -170,14 +182,15 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {

ChatCompletionRequest request = new ChatCompletionRequest(chatCompletionMessages, stream);

if (this.defaultOptions != null) {
request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class);
}

if (prompt.getOptions() != null) {
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
OpenAiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
ChatOptions.class, OpenAiChatOptions.class);

Set<String> promptEnabledFunctions = handleToolFunctionConfigurations(updatedRuntimeOptions, true,
true);
enabledFunctionsForRequest.addAll(promptEnabledFunctions);

request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, ChatCompletionRequest.class);
}
else {
Expand All @@ -186,7 +199,180 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
}
}

if (this.defaultOptions != null) {

Set<String> defaultEnabledFunctions = handleToolFunctionConfigurations(this.defaultOptions, false, false);

enabledFunctionsForRequest.addAll(defaultEnabledFunctions);

request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class);
}

// Add the enabled functions definitions to the request's tools parameter.
if (!CollectionUtils.isEmpty(enabledFunctionsForRequest)) {

if (stream) {
throw new IllegalArgumentException("Currently tool functions are not supported in streaming mode");
}

request = ModelOptionsUtils.merge(
OpenAiChatOptions.builder().withTools(this.getFunctionTools(enabledFunctionsForRequest)).build(),
request, ChatCompletionRequest.class);
}

return request;
}

/**
 * Registers the tool callbacks found in the given options and computes the set of
 * function names that should be enabled for the current request.
 * @param options the chat options to inspect; may be {@code null}.
 * @param autoEnableCallbackFunctions when {@code true}, every callback present in the
 * options is also enabled (used for prompt-scoped options).
 * @param overrideCallbackFunctionsRegister when {@code true}, callbacks replace any
 * previously registered callback with the same name; otherwise existing entries win.
 * @return the names of the functions enabled by these options (never {@code null}).
 */
private Set<String> handleToolFunctionConfigurations(OpenAiChatOptions options, boolean autoEnableCallbackFunctions,
		boolean overrideCallbackFunctionsRegister) {

	Set<String> enabledFunctions = new HashSet<>();

	if (options == null) {
		return enabledFunctions;
	}

	if (!CollectionUtils.isEmpty(options.getToolCallbacks())) {
		for (ToolFunctionCallback toolCallback : options.getToolCallbacks()) {

			String functionName = toolCallback.getName();

			// Register the tool callback, either overriding or keeping existing entries.
			if (overrideCallbackFunctionsRegister) {
				this.toolCallbackRegister.put(functionName, toolCallback);
			}
			else {
				this.toolCallbackRegister.putIfAbsent(functionName, toolCallback);
			}

			// Automatically enable the function, usually from prompt callback.
			if (autoEnableCallbackFunctions) {
				enabledFunctions.add(functionName);
			}
		}
	}

	// Add the explicitly enabled functions.
	if (!CollectionUtils.isEmpty(options.getEnabledFunctions())) {
		enabledFunctions.addAll(options.getEnabledFunctions());
	}

	return enabledFunctions;
}

/**
 * Returns the live register of tool callbacks, keyed by function name.
 * @return returns the registered tool callbacks.
 */
Map<String, ToolFunctionCallback> getToolCallbackRegister() {
	return this.toolCallbackRegister;
}

/**
 * Resolves the OpenAI function-tool definitions for the given function names from the
 * tool callback register.
 * @param functionNames the names of the functions to resolve; each must already be
 * registered via a {@link ToolFunctionCallback}.
 * @return the function tool definitions to include in a chat completion request.
 * @throws IllegalStateException if a function name has no registered callback.
 */
public List<OpenAiApi.FunctionTool> getFunctionTools(Set<String> functionNames) {

	List<OpenAiApi.FunctionTool> functionTools = new ArrayList<>();
	for (String functionName : functionNames) {
		// Single lookup instead of containsKey() followed by get().
		ToolFunctionCallback functionCallback = this.toolCallbackRegister.get(functionName);
		if (functionCallback == null) {
			throw new IllegalStateException("No function callback found for function name: " + functionName);
		}

		var function = new OpenAiApi.FunctionTool.Function(functionCallback.getDescription(),
				functionCallback.getName(), functionCallback.getInputTypeSchema());
		functionTools.add(new OpenAiApi.FunctionTool(function));
	}

	return functionTools;
}

/**
 * Function Call handling. While the model responds with tool calls, each requested
 * function is invoked through its registered callback and the function response is
 * added to the conversation history, which is then sent back to the model. Uses a loop
 * rather than recursion so a long tool-call exchange cannot overflow the call stack;
 * the observable request/response sequence is unchanged.
 * @param request the chat completion request
 * @return the chat completion response.
 */
@SuppressWarnings("null")
private ResponseEntity<ChatCompletion> chatCompletionWithTools(OpenAiApi.ChatCompletionRequest request) {

	ResponseEntity<ChatCompletion> chatCompletion = this.openAiApi.chatCompletionEntity(request);

	// Keep resolving tool calls until the model stops calling functions.
	// NOTE(review): this loop is unbounded — an uncooperative model could keep
	// requesting tool calls indefinitely; consider a max-iterations guard.
	while (this.isToolCall(chatCompletion)) {

		// The OpenAI chat completion tool call API requires the complete conversation
		// history. Including the initial user message.
		List<ChatCompletionMessage> conversationMessages = new ArrayList<>(request.messages());

		// We assume that the tool calling information is inside the response's first
		// choice.
		ChatCompletionMessage responseMessage = chatCompletion.getBody().choices().iterator().next().message();

		if (chatCompletion.getBody().choices().size() > 1) {
			logger.warn("More than one choice returned. Only the first choice is processed.");
		}

		// Add the assistant response to the message conversation history.
		conversationMessages.add(responseMessage);

		// Every tool-call item requires a separate function call and a response (TOOL)
		// message.
		for (ToolCall toolCall : responseMessage.toolCalls()) {

			var functionName = toolCall.function().name();
			String functionArguments = toolCall.function().arguments();

			// Single lookup instead of containsKey() followed by get().
			ToolFunctionCallback functionCallback = this.toolCallbackRegister.get(functionName);
			if (functionCallback == null) {
				throw new IllegalStateException("No function callback found for function name: " + functionName);
			}

			String functionResponse = functionCallback.call(functionArguments);

			// Add the function response to the conversation.
			conversationMessages.add(new ChatCompletionMessage(functionResponse, Role.TOOL, null, toolCall.id(), null));
		}

		// Re-send the extended conversation, preserving the original request options.
		ChatCompletionRequest newRequest = new ChatCompletionRequest(conversationMessages, request.stream());
		request = ModelOptionsUtils.merge(newRequest, request, ChatCompletionRequest.class);

		chatCompletion = this.openAiApi.chatCompletionEntity(request);
	}

	return chatCompletion;
}

/**
 * Builds the generation metadata map from a chat completion message, including only
 * the non-null fields.
 */
private Map<String, Object> toMap(ChatCompletionMessage message) {

	Map<String, Object> metadata = new HashMap<>();

	// The tool_calls and tool_call_id are not used by the OpenAiChatClient functions
	// call support! Useful only for users that want to use the tool_calls and
	// tool_call_id in their applications.
	var toolCalls = message.toolCalls();
	if (toolCalls != null) {
		metadata.put("tool_calls", toolCalls);
	}

	var toolCallId = message.toolCallId();
	if (toolCallId != null) {
		metadata.put("tool_call_id", toolCallId);
	}

	var role = message.role();
	if (role != null) {
		metadata.put("role", role.name());
	}

	return metadata;
}

/**
 * Check if it is a model calls function response.
 * @param chatCompletion the chat completion response.
 * @return true if the model expects a function call.
 */
private boolean isToolCall(ResponseEntity<ChatCompletion> chatCompletion) {
	// Primitive boolean return: the method never yields null, and the sole caller
	// immediately branches on the result, so boxing was unnecessary.
	var body = chatCompletion.getBody();
	if (body == null) {
		return false;
	}

	var choices = body.choices();
	if (CollectionUtils.isEmpty(choices)) {
		return false;
	}

	return choices.get(0).message().toolCalls() != null;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,22 @@

package org.springframework.ai.openai;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;

import org.springframework.ai.chat.ChatOptions;
import org.springframework.ai.model.ToolFunctionCallback;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ResponseFormat;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ToolChoice;
import org.springframework.util.Assert;
import org.springframework.ai.openai.api.OpenAiApi.FunctionTool;

/**
Expand Down Expand Up @@ -114,6 +119,27 @@ public class OpenAiChatOptions implements ChatOptions {
* A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
*/
private @JsonProperty("user") String user;

/**
* OpenAI Tool Function Callbacks to register with the ChatClient.
* For Prompt Options the toolCallbacks are automatically enabled for the duration of the prompt execution.
* For Default Options the toolCallbacks are registered but disabled by default. Use the enableFunctions to set the functions
* from the registry to be used by the ChatClient chat completion requests.
*/
@JsonIgnore
private List<ToolFunctionCallback> toolCallbacks = new ArrayList<>();

/**
* List of functions, identified by their names, to configure for function calling in
* the chat completion requests.
* Functions with those names must exist in the toolCallbacks registry.
* The {@link #toolCallbacks} from the PromptOptions are automatically enabled for the duration of the prompt execution.
*
* Note that function enabled with the default options are enabled for all chat completion requests. This could impact the token count and the billing.
* If the enabledFunctions is set in a prompt options, then the enabled functions are only active for the duration of this prompt execution.
*/
@JsonIgnore
private Set<String> enabledFunctions = new HashSet<>();
// @formatter:on

public static Builder builder() {
Expand Down Expand Up @@ -202,6 +228,23 @@ public Builder withUser(String user) {
return this;
}

/**
 * Sets the tool function callbacks to register with the chat client.
 * @param toolCallbacks the callbacks to register; must not be {@code null}.
 * @return this builder.
 */
public Builder withToolCallbacks(List<ToolFunctionCallback> toolCallbacks) {
	// Validate like the sibling withEnabledFunctions(..) does; a null list would
	// otherwise fail later with a less helpful error.
	Assert.notNull(toolCallbacks, "Tool callbacks must not be null");
	this.options.toolCallbacks = toolCallbacks;
	return this;
}

/**
 * Sets the function names to enable for chat completion requests. The names are
 * expected to match callbacks in the tool callback register.
 * @param functionNames the function names to enable; must not be {@code null}.
 * @return this builder.
 */
public Builder withEnabledFunctions(Set<String> functionNames) {
	Assert.notNull(functionNames, "Function names must not be null");
	this.options.enabledFunctions = functionNames;
	return this;
}

/**
 * Adds a single function name to the set of enabled functions.
 * @param functionName the function name to enable; must not be empty.
 * @return this builder.
 */
public Builder withEnabledFunction(String functionName) {
	Assert.hasText(functionName, "Function name must not be empty");
	this.options.enabledFunctions.add(functionName);
	return this;
}

/**
 * @return the configured options instance (the builder's own instance, not a copy).
 */
public OpenAiChatOptions build() {
	return this.options;
}
Expand Down Expand Up @@ -280,18 +323,22 @@ public void setStop(List<String> stop) {
this.stop = stop;
}

// Accessors implementing the ChatOptions contract for temperature and topP.
@Override
public Float getTemperature() {
	return this.temperature;
}

@Override
public void setTemperature(Float temperature) {
	this.temperature = temperature;
}

@Override
public Float getTopP() {
	return this.topP;
}

@Override
public void setTopP(Float topP) {
	this.topP = topP;
}
Expand Down Expand Up @@ -320,6 +367,24 @@ public void setUser(String user) {
this.user = user;
}

// Accessors for the tool callback list (ChatOptions contract) and the set of
// function names enabled for chat completion requests.
@Override
public List<ToolFunctionCallback> getToolCallbacks() {
	return this.toolCallbacks;
}

@Override
public void setToolCallbacks(List<ToolFunctionCallback> toolCallbacks) {
	this.toolCallbacks = toolCallbacks;
}

public Set<String> getEnabledFunctions() {
	return enabledFunctions;
}

public void setEnabledFunctions(Set<String> functionNames) {
	this.enabledFunctions = functionNames;
}

@Override
public int hashCode() {
final int prime = 31;
Expand Down
Loading