diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java index 09c864fb18e..e5ef62996c9 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java @@ -16,8 +16,12 @@ package org.springframework.ai.openai; import java.time.Duration; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import org.slf4j.Logger; @@ -33,9 +37,12 @@ import org.springframework.ai.chat.metadata.RateLimit; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.ToolFunctionCallback; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.Role; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ToolCall; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; import org.springframework.ai.openai.api.OpenAiApi.OpenAiApiException; import org.springframework.ai.openai.metadata.OpenAiChatResponseMetadata; @@ -46,6 +53,7 @@ import org.springframework.retry.RetryListener; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; /** * {@link ChatClient} implementation for {@literal OpenAI} backed by {@link OpenAiApi}. @@ -66,6 +74,8 @@ public class OpenAiChatClient implements ChatClient, StreamingChatClient { private OpenAiChatOptions defaultOptions; + private Map toolCallbackRegister = new ConcurrentHashMap<>(); + public final RetryTemplate retryTemplate = RetryTemplate.builder() .maxAttempts(10) .retryOn(OpenAiApiException.class) @@ -108,18 +118,18 @@ public ChatResponse call(Prompt prompt) { ChatCompletionRequest request = createRequest(prompt, false); - ResponseEntity completionEntity = this.openAiApi.chatCompletionEntity(request); + ResponseEntity completionEntity = this.chatCompletionWithTools(request); var chatCompletion = completionEntity.getBody(); if (chatCompletion == null) { - logger.warn("No chat completion returned for request: {}", prompt); + logger.warn("No chat completion returned for prompt: {}", prompt); return new ChatResponse(List.of()); } RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(completionEntity); List generations = chatCompletion.choices().stream().map(choice -> { - return new Generation(choice.message().content(), Map.of("role", choice.message().role().name())) + return new Generation(choice.message().content(), toMap(choice.message())) .withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null)); }).toList(); @@ -162,6 +172,8 @@ public Flux stream(Prompt prompt) { */ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { + Set enabledFunctionsForRequest = new HashSet<>(); + List chatCompletionMessages = prompt.getInstructions() .stream() .map(m -> new ChatCompletionMessage(m.getContent(), @@ -170,14 +182,15 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { ChatCompletionRequest request = new ChatCompletionRequest(chatCompletionMessages, stream); - if (this.defaultOptions != null) { - request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class); - } - if (prompt.getOptions() != null) { if (prompt.getOptions() instanceof ChatOptions runtimeOptions) { OpenAiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions, ChatOptions.class, OpenAiChatOptions.class); + + Set promptEnabledFunctions = handleToolFunctionConfigurations(updatedRuntimeOptions, true, + true); + enabledFunctionsForRequest.addAll(promptEnabledFunctions); + request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, ChatCompletionRequest.class); } else { @@ -186,7 +199,180 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { } } + if (this.defaultOptions != null) { + + Set defaultEnabledFunctions = handleToolFunctionConfigurations(this.defaultOptions, false, false); + + enabledFunctionsForRequest.addAll(defaultEnabledFunctions); + + request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class); + } + + // Add the enabled functions definitions to the request's tools parameter. + if (!CollectionUtils.isEmpty(enabledFunctionsForRequest)) { + + if (stream) { + throw new IllegalArgumentException("Currently tool functions are not supported in streaming mode"); + } + + request = ModelOptionsUtils.merge( + OpenAiChatOptions.builder().withTools(this.getFunctionTools(enabledFunctionsForRequest)).build(), + request, ChatCompletionRequest.class); + } + return request; } + private Set handleToolFunctionConfigurations(OpenAiChatOptions options, boolean autoEnableCallbackFunctions, + boolean overrideCallbackFunctionsRegister) { + + Set enabledFunctions = new HashSet<>(); + + if (options != null) { + if (!CollectionUtils.isEmpty(options.getToolCallbacks())) { + options.getToolCallbacks().stream().forEach(toolCallback -> { + + // Register the tool callback. + if (overrideCallbackFunctionsRegister) { + this.toolCallbackRegister.put(toolCallback.getName(), toolCallback); + } + else { + this.toolCallbackRegister.putIfAbsent(toolCallback.getName(), toolCallback); + } + + // Automatically enable the function, usually from prompt callback. + if (autoEnableCallbackFunctions) { + enabledFunctions.add(toolCallback.getName()); + } + }); + } + + // Add the explicitly enabled functions. + if (!CollectionUtils.isEmpty(options.getEnabledFunctions())) { + enabledFunctions.addAll(options.getEnabledFunctions()); + } + } + + return enabledFunctions; + } + + /** + * @return returns the registered tool callbacks. + */ + Map getToolCallbackRegister() { + return toolCallbackRegister; + } + + public List getFunctionTools(Set functionNames) { + + List functionTools = new ArrayList<>(); + for (String functionName : functionNames) { + if (!this.toolCallbackRegister.containsKey(functionName)) { + throw new IllegalStateException("No function callback found for function name: " + functionName); + } + ToolFunctionCallback functionCallback = this.toolCallbackRegister.get(functionName); + + var function = new OpenAiApi.FunctionTool.Function(functionCallback.getDescription(), + functionCallback.getName(), functionCallback.getInputTypeSchema()); + functionTools.add(new OpenAiApi.FunctionTool(function)); + } + + return functionTools; + } + + /** + * Function Call handling. If the model calls a function, the function is called and + * the response is added to the conversation history. The conversation history is then + * sent back to the model. + * @param request the chat completion request + * @return the chat completion response. + */ + @SuppressWarnings("null") + private ResponseEntity chatCompletionWithTools(OpenAiApi.ChatCompletionRequest request) { + + ResponseEntity chatCompletion = this.openAiApi.chatCompletionEntity(request); + + // Return the result if the model is not calling a function. + if (!this.isToolCall(chatCompletion)) { + return chatCompletion; + } + + // The OpenAI chat completion tool call API requires the complete conversation + // history. Including the initial user message. + List conversationMessages = new ArrayList<>(request.messages()); + + // We assume that the tool calling information is inside the response's first + // choice. + ChatCompletionMessage responseMessage = chatCompletion.getBody().choices().iterator().next().message(); + + if (chatCompletion.getBody().choices().size() > 1) { + logger.warn("More than one choice returned. Only the first choice is processed."); + } + + // Add the assistant response to the message conversation history. + conversationMessages.add(responseMessage); + + // Every tool-call item requires a separate function call and a response (TOOL) + // message. + for (ToolCall toolCall : responseMessage.toolCalls()) { + + var functionName = toolCall.function().name(); + String functionArguments = toolCall.function().arguments(); + + if (!this.toolCallbackRegister.containsKey(functionName)) { + throw new IllegalStateException("No function callback found for function name: " + functionName); + } + + String functionResponse = this.toolCallbackRegister.get(functionName).call(functionArguments); + + // Add the function response to the conversation. + conversationMessages.add(new ChatCompletionMessage(functionResponse, Role.TOOL, null, toolCall.id(), null)); + } + + // Recursively call chatCompletionWithTools until the model doesn't call a + // functions anymore. + ChatCompletionRequest newRequest = new ChatCompletionRequest(conversationMessages, request.stream()); + newRequest = ModelOptionsUtils.merge(newRequest, request, ChatCompletionRequest.class); + + return this.chatCompletionWithTools(newRequest); + } + + private Map toMap(ChatCompletionMessage message) { + Map map = new HashMap<>(); + + // The tool_calls and tool_call_id are not used by the OpenAiChatClient functions + // call support! Useful only for users that want to use the tool_calls and + // tool_call_id in their applications. + if (message.toolCalls() != null) { + map.put("tool_calls", message.toolCalls()); + } + if (message.toolCallId() != null) { + map.put("tool_call_id", message.toolCallId()); + } + + if (message.role() != null) { + map.put("role", message.role().name()); + } + return map; + } + + /** + * Check if it is a model calls function response. + * @param chatCompletion the chat completion response. + * @return true if the model expects a function call. + */ + private Boolean isToolCall(ResponseEntity chatCompletion) { + var body = chatCompletion.getBody(); + if (body == null) { + return false; + } + + var choices = body.choices(); + if (CollectionUtils.isEmpty(choices)) { + return false; + } + + return choices.get(0).message().toolCalls() != null; + } + } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java index a7b0c7e0d5e..eedcdeb2961 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java @@ -16,8 +16,11 @@ package org.springframework.ai.openai; +import java.util.ArrayList; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; @@ -25,8 +28,10 @@ import com.fasterxml.jackson.annotation.JsonProperty; import org.springframework.ai.chat.ChatOptions; +import org.springframework.ai.model.ToolFunctionCallback; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ResponseFormat; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ToolChoice; +import org.springframework.util.Assert; import org.springframework.ai.openai.api.OpenAiApi.FunctionTool; /** @@ -114,6 +119,27 @@ public class OpenAiChatOptions implements ChatOptions { * A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. */ private @JsonProperty("user") String user; + + /** + * OpenAI Tool Function Callbacks to register with the ChatClient. + * For Prompt Options the toolCallbacks are automatically enabled for the duration of the prompt execution. + * For Default Options the toolCallbacks are registered but disabled by default. Use the enableFunctions to set the functions + * from the registry to be used by the ChatClient chat completion requests. + */ + @JsonIgnore + private List toolCallbacks = new ArrayList<>(); + + /** + * List of functions, identified by their names, to configure for function calling in + * the chat completion requests. + * Functions with those names must exist in the toolCallbacks registry. + * The {@link #toolCallbacks} from the PromptOptions are automatically enabled for the duration of the prompt execution. + * + * Note that function enabled with the default options are enabled for all chat completion requests. This could impact the token count and the billing. + * If the enabledFunctions is set in a prompt options, then the enabled functions are only active for the duration of this prompt execution. + */ + @JsonIgnore + private Set enabledFunctions = new HashSet<>(); // @formatter:on public static Builder builder() { @@ -202,6 +228,23 @@ public Builder withUser(String user) { return this; } + public Builder withToolCallbacks(List toolCallbacks) { + this.options.toolCallbacks = toolCallbacks; + return this; + } + + public Builder withEnabledFunctions(Set functionNames) { + Assert.notNull(functionNames, "Function names must not be null"); + this.options.enabledFunctions = functionNames; + return this; + } + + public Builder withEnabledFunction(String functionName) { + Assert.hasText(functionName, "Function name must not be empty"); + this.options.enabledFunctions.add(functionName); + return this; + } + public OpenAiChatOptions build() { return this.options; } @@ -280,18 +323,22 @@ public void setStop(List stop) { this.stop = stop; } + @Override public Float getTemperature() { return this.temperature; } + @Override public void setTemperature(Float temperature) { this.temperature = temperature; } + @Override public Float getTopP() { return this.topP; } + @Override public void setTopP(Float topP) { this.topP = topP; } @@ -320,6 +367,24 @@ public void setUser(String user) { this.user = user; } + @Override + public List getToolCallbacks() { + return this.toolCallbacks; + } + + @Override + public void setToolCallbacks(List toolCallbacks) { + this.toolCallbacks = toolCallbacks; + } + + public Set getEnabledFunctions() { + return enabledFunctions; + } + + public void setEnabledFunctions(Set functionNames) { + this.enabledFunctions = functionNames; + } + @Override public int hashCode() { final int prime = 31; diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index 097234a04ff..cfae84b3a20 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -400,8 +400,6 @@ public record ResponseFormat( * and null otherwise. * @param toolCalls The tool calls generated by the model, such as function calls. Applicable only for * {@link Role#ASSISTANT} role and null otherwise. - * @param functionCall Deprecated and replaced by tool_calls. The name and arguments of a function that should be - * called, as generated by the model. */ @JsonInclude(Include.NON_NULL) public record ChatCompletionMessage( @@ -409,8 +407,7 @@ public record ChatCompletionMessage( @JsonProperty("role") Role role, @JsonProperty("name") String name, @JsonProperty("tool_call_id") String toolCallId, - @JsonProperty("tool_calls") List toolCalls, - @JsonProperty("function_call") ChatCompletionFunction functionCall) { + @JsonProperty("tool_calls") List toolCalls) { /** * Create a chat completion message with the given content and role. All other fields are null. @@ -418,7 +415,7 @@ public record ChatCompletionMessage( * @param role The role of the author of this message. */ public ChatCompletionMessage(String content, Role role) { - this(content, role, null, null, null, null); + this(content, role, null, null, null); } /** @@ -798,7 +795,7 @@ public ResponseEntity> embeddings(EmbeddingRequest< }); } - private static Map parseJson(String jsonSchema) { + public static Map parseJson(String jsonSchema) { try { return new ObjectMapper().readValue(jsonSchema, new TypeReference>() { diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java index 467c11451a7..1856022021d 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java @@ -16,10 +16,16 @@ package org.springframework.ai.openai; +import java.util.List; + import org.junit.jupiter.api.Test; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.AbstractToolFunctionCallback; import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.openai.chat.api.tool.MockWeatherService; +import org.springframework.ai.openai.chat.api.tool.MockWeatherService.Request; +import org.springframework.ai.openai.chat.api.tool.MockWeatherService.Response; import static org.assertj.core.api.Assertions.assertThat; @@ -52,4 +58,99 @@ public void createRequestWithChatOptions() { assertThat(request.temperature()).isEqualTo(99.9f); } + @Test + public void promptOptionsTools() { + + final String TOOL_FUNCTION_NAME = "CurrentWeather"; + + var client = new OpenAiChatClient(new OpenAiApi("TEST")) + .withDefaultOptions(OpenAiChatOptions.builder().withModel("DEFAULT_MODEL").build()); + + var request = client.createRequest(new Prompt("Test message content", OpenAiChatOptions.builder() + .withModel("PROMPT_MODEL") + .withToolCallbacks( + List.of(new AbstractToolFunctionCallback( + TOOL_FUNCTION_NAME, "Get the weather in location", MockWeatherService.Request.class) { + @Override + public Response apply(Request request) { + return new MockWeatherService().apply(request); + } + })) + .build()), false); + + assertThat(client.getToolCallbackRegister()).hasSize(1); + assertThat(client.getToolCallbackRegister()).containsKeys(TOOL_FUNCTION_NAME); + + assertThat(request.messages()).hasSize(1); + assertThat(request.stream()).isFalse(); + assertThat(request.model()).isEqualTo("PROMPT_MODEL"); + + assertThat(request.tools()).hasSize(1); + assertThat(request.tools().get(0).function().name()).isEqualTo(TOOL_FUNCTION_NAME); + } + + @Test + public void defaultOptionsTools() { + + final String TOOL_FUNCTION_NAME = "CurrentWeather"; + + var client = new OpenAiChatClient(new OpenAiApi("TEST")).withDefaultOptions(OpenAiChatOptions.builder() + .withModel("DEFAULT_MODEL") + .withToolCallbacks( + List.of(new AbstractToolFunctionCallback( + TOOL_FUNCTION_NAME, "Get the weather in location", MockWeatherService.Request.class) { + @Override + public Response apply(Request request) { + return new MockWeatherService().apply(request); + } + })) + .build()); + + var request = client.createRequest(new Prompt("Test message content"), false); + + assertThat(client.getToolCallbackRegister()).hasSize(1); + assertThat(client.getToolCallbackRegister()).containsKeys(TOOL_FUNCTION_NAME); + assertThat(client.getToolCallbackRegister().get(TOOL_FUNCTION_NAME).getDescription()) + .isEqualTo("Get the weather in location"); + + assertThat(request.messages()).hasSize(1); + assertThat(request.stream()).isFalse(); + assertThat(request.model()).isEqualTo("DEFAULT_MODEL"); + + assertThat(request.tools()).as("Default Options callback functions are not automatically enabled!") + .isNullOrEmpty(); + + // Explicitly enable the function + request = client.createRequest(new Prompt("Test message content", + OpenAiChatOptions.builder().withEnabledFunction(TOOL_FUNCTION_NAME).build()), false); + + assertThat(request.tools()).hasSize(1); + assertThat(request.tools().get(0).function().name()).as("Explicitly enabled function") + .isEqualTo(TOOL_FUNCTION_NAME); + + // Override the default options function with one from the prompt + request = client + .createRequest(new Prompt("Test message content", + OpenAiChatOptions.builder() + .withToolCallbacks(List + .of(new AbstractToolFunctionCallback(TOOL_FUNCTION_NAME, + "Overridden function description", MockWeatherService.Request.class) { + @Override + public String apply(Request request) { + return "Mock response"; + } + })) + .build()), + false); + + assertThat(request.tools()).hasSize(1); + assertThat(request.tools().get(0).function().name()).as("Explicitly enabled function") + .isEqualTo(TOOL_FUNCTION_NAME); + + assertThat(client.getToolCallbackRegister()).hasSize(1); + assertThat(client.getToolCallbackRegister()).containsKeys(TOOL_FUNCTION_NAME); + assertThat(client.getToolCallbackRegister().get(TOOL_FUNCTION_NAME).getDescription()) + .isEqualTo("Overridden function description"); + } + } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientIT.java index 92d8a3b7cd5..8d558adb902 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientIT.java @@ -1,5 +1,6 @@ package org.springframework.ai.openai.chat; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; @@ -13,16 +14,19 @@ import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.Generation; import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.model.AbstractToolFunctionCallback; +import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.OpenAiTestConfiguration; +import org.springframework.ai.openai.chat.api.tool.MockWeatherService; import org.springframework.ai.openai.testutils.AbstractIT; import org.springframework.ai.parser.BeanOutputParser; import org.springframework.ai.parser.ListOutputParser; import org.springframework.ai.parser.MapOutputParser; -import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.chat.prompt.PromptTemplate; -import org.springframework.ai.chat.prompt.SystemPromptTemplate; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.UserMessage; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.core.convert.support.DefaultConversionService; @@ -160,4 +164,36 @@ void beanStreamOutputParserRecords() { assertThat(actorsFilms.movies()).hasSize(5); } + @Test + void functionCallTest() { + + UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = OpenAiChatOptions.builder() + .withModel("gpt-4-1106-preview") + .withToolCallbacks( + List.of(new AbstractToolFunctionCallback( + "getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, + (response) -> "" + response.temp() + response.unit()) { + + private final MockWeatherService weatherService = new MockWeatherService(); + + @Override + public MockWeatherService.Response apply(MockWeatherService.Request request) { + return weatherService.apply(request); + } + + })) + .build(); + + ChatResponse response = openAiChatClient.call(new Prompt(messages, promptOptions)); + + logger.info("Response: {}", response); + + assertThat(response.getResult().getOutput().getContent()).contains("30.0", "10.0", "15.0"); + + } + } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/RestClientBuilderTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/RestClientBuilderTests.java index f19ce9c5c80..883ce3a85b8 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/RestClientBuilderTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/RestClientBuilderTests.java @@ -19,7 +19,6 @@ import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; import org.springframework.http.client.SimpleClientHttpRequestFactory; import org.springframework.web.client.RestClient; import org.springframework.web.client.RestClient.Builder; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/tool/MockWeatherService.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/tool/MockWeatherService.java new file mode 100644 index 00000000000..62c2afe1354 --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/tool/MockWeatherService.java @@ -0,0 +1,93 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.chat.api.tool; + +import java.util.function.Function; + +import com.fasterxml.jackson.annotation.JsonClassDescription; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; + +/** + * @author Christian Tzolov + */ +public class MockWeatherService implements Function { + + /** + * Weather Function request. + */ + @JsonInclude(Include.NON_NULL) + @JsonClassDescription("Weather API request") + public record Request(@JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, + @JsonProperty(required = true, value = "lat") @JsonPropertyDescription("The city latitude") double lat, + @JsonProperty(required = true, value = "lon") @JsonPropertyDescription("The city longitude") double lon, + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + } + + /** + * Temperature units. + */ + public enum Unit { + + /** + * Celsius. + */ + c("metric"), + /** + * Fahrenheit. + */ + f("imperial"); + + /** + * Human readable unit name. + */ + public final String unitName; + + private Unit(String text) { + this.unitName = text; + } + + } + + /** + * Weather Function response. + */ + public record Response(double temp, double feels_like, double temp_min, double temp_max, int pressure, int humidity, + Unit unit) { + } + + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.c); + } + +} \ No newline at end of file diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/tool/OpenAiApiToolFunctionCallIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/tool/OpenAiApiToolFunctionCallIT.java new file mode 100644 index 00000000000..f1f6d874c93 --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/tool/OpenAiApiToolFunctionCallIT.java @@ -0,0 +1,166 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.chat.api.tool; + +import java.util.ArrayList; +import java.util.List; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.Role; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ToolCall; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; +import org.springframework.ai.openai.api.OpenAiApi.FunctionTool.Type; +import org.springframework.http.ResponseEntity; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Based on the OpenAI Function Calling tutorial: + * https://platform.openai.com/docs/guides/function-calling/parallel-function-calling + * + * @author Christian Tzolov + */ +@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") +public class OpenAiApiToolFunctionCallIT { + + private final Logger logger = LoggerFactory.getLogger(OpenAiApiToolFunctionCallIT.class); + + MockWeatherService weatherService = new MockWeatherService(); + + OpenAiApi completionApi = new OpenAiApi(System.getenv("OPENAI_API_KEY")); + + @Test + public void toolFunctionCall() { + + // Step 1: send the conversation and available functions to the model + var message = new ChatCompletionMessage("What's the weather like in San Francisco, Tokyo, and Paris?", + Role.USER); + + var functionTool = new OpenAiApi.FunctionTool(Type.FUNCTION, + new OpenAiApi.FunctionTool.Function( + "Get the weather in location. Return temperature in 30°F or 30°C format.", "getCurrentWeather", + OpenAiApi.parseJson(""" + { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state e.g. San Francisco, CA" + }, + "lat": { + "type": "number", + "description": "The city latitude" + }, + "lon": { + "type": "number", + "description": "The city longitude" + }, + "unit": { + "type": "string", + "enum": ["c", "f"] + } + }, + "required": ["location", "lat", "lon", "unit"] + } + """))); + + // Or you can use the + // ModelOptionsUtils.getJsonSchema(FakeWeatherService.Request.class))) to + // auto-generate the JSON schema like: + // var functionTool = new OpenAiApi.FunctionTool(Type.FUNCTION, new + // OpenAiApi.FunctionTool.Function( + // "Get the weather in location. Return temperature in 30°F or 30°C format.", + // "getCurrentWeather", + // ModelOptionsUtils.getJsonSchema(FakeWeatherService.Request.class))); + + List messages = new ArrayList<>(List.of(message)); + + ChatCompletionRequest chatCompletionRequest = new ChatCompletionRequest(messages, "gpt-4-1106-preview", + List.of(functionTool), null); + + ResponseEntity chatCompletion = completionApi.chatCompletionEntity(chatCompletionRequest); + + assertThat(chatCompletion.getBody()).isNotNull(); + assertThat(chatCompletion.getBody().choices()).isNotEmpty(); + + ChatCompletionMessage responseMessage = chatCompletion.getBody().choices().get(0).message(); + + assertThat(responseMessage.role()).isEqualTo(Role.ASSISTANT); + assertThat(responseMessage.toolCalls()).isNotNull(); + + // Check if the model wanted to call a function + if (responseMessage.toolCalls() != null) { + + // extend conversation with assistant's reply. + messages.add(responseMessage); + + // Send the info for each function call and function response to the model. + for (ToolCall toolCall : responseMessage.toolCalls()) { + var functionName = toolCall.function().name(); + if ("getCurrentWeather".equals(functionName)) { + MockWeatherService.Request weatherRequest = fromJson(toolCall.function().arguments(), + MockWeatherService.Request.class); + + MockWeatherService.Response weatherResponse = weatherService.apply(weatherRequest); + + // extend conversation with function response. + messages.add(new ChatCompletionMessage("" + weatherResponse.temp() + weatherRequest.unit(), + Role.TOOL, null, toolCall.id(), null)); + } + } + + var functionResponseRequest = new ChatCompletionRequest(messages, "gpt-4-1106-preview", 0.8f); + + ResponseEntity chatCompletion2 = completionApi + .chatCompletionEntity(functionResponseRequest); + + logger.info("Final response: " + chatCompletion2.getBody()); + + assertThat(chatCompletion2.getBody().choices()).isNotEmpty(); + + assertThat(chatCompletion2.getBody().choices().get(0).message().role()).isEqualTo(Role.ASSISTANT); + assertThat(chatCompletion2.getBody().choices().get(0).message().content()).contains("San Francisco") + .containsAnyOf("30.0°F", "30°F"); + assertThat(chatCompletion2.getBody().choices().get(0).message().content()).contains("Tokyo") + .containsAnyOf("10.0°C", "10°C"); + ; + assertThat(chatCompletion2.getBody().choices().get(0).message().content()).contains("Paris") + .containsAnyOf("15.0°C", "15°C"); + ; + } + + } + + private static T fromJson(String json, Class targetClass) { + try { + return new ObjectMapper().readValue(json, targetClass); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + +} \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptions.java index 3e27fc4cade..292a7340292 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptions.java @@ -16,7 +16,10 @@ package org.springframework.ai.chat; +import java.util.List; + import org.springframework.ai.model.ModelOptions; +import org.springframework.ai.model.ToolFunctionCallback; /** * The ChatOptions represent the common options, portable across different chat models. @@ -35,4 +38,12 @@ public interface ChatOptions extends ModelOptions { void setTopK(Integer topK); + default List getToolCallbacks() { + throw new UnsupportedOperationException("ToolCallbacks is not supported by this model"); + } + + default void setToolCallbacks(List toolCallbacks) { + throw new UnsupportedOperationException("ToolCallbacks is not supported by this model"); + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatResponse.java index c28630999dd..2ebbe42a155 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatResponse.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatResponse.java @@ -82,4 +82,9 @@ public ChatResponseMetadata getMetadata() { return this.chatResponseMetadata; } + @Override + public String toString() { + return "ChatResponse [metadata=" + chatResponseMetadata + ", generations=" + generations + "]"; + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/AbstractToolFunctionCallback.java b/spring-ai-core/src/main/java/org/springframework/ai/model/AbstractToolFunctionCallback.java new file mode 100644 index 00000000000..b2187c87b23 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/AbstractToolFunctionCallback.java @@ -0,0 +1,194 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model; + +import java.util.function.Function; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; + +import org.springframework.util.Assert; + +/** + * Abstract implementation of the {@link ToolFunctionCallback} for interacting with the + * Model's function calling protocol and a {@link Function} wrapping the interaction with + * the 3rd party service/function. + * + * Implement the {@code O apply(I request) } method to implement the interaction with the + * 3rd party service/function. + * + * The {@link #responseConverter} function is responsible to convert the 3rd party + * function's output type into a string expected by the LLM model. + * + * @param the 3rd party service input type. + * @param the 3rd party service output type. + * @author Christian Tzolov + */ +public abstract class AbstractToolFunctionCallback implements Function, ToolFunctionCallback { + + private final String name; + + private final String description; + + private final Class inputType; + + private final String inputTypeSchema; + + private final ObjectMapper objectMapper; + + private final Function responseConverter; + + /** + * Constructs a new {@link AbstractToolFunctionCallback} with the given name, + * description, input type and object mapper. + * @param name Function name. Should be unique within the ChatClient's function + * registry. + * @param description Function description. Used as a "system prompt" by the model to + * decide if the function should be called. + * @param inputType Used to compute, the argument's JSON schema required by the + * Model's function calling protocol. + */ + public AbstractToolFunctionCallback(String name, String description, Class inputType) { + this(name, description, inputType, (response) -> response.toString()); + } + + /** + * Constructs a new {@link AbstractToolFunctionCallback} with the given name, + * description, input type and object mapper. + * @param name Function name. Should be unique within the ChatClient's function + * registry. + * @param description Function description. Used as a "system prompt" by the model to + * decide if the function should be called. + * @param inputType Used to compute, the argument's JSON schema required by the + * Model's function calling protocol. + * @param responseConverter Used to convert the function's output type to a string. + */ + public AbstractToolFunctionCallback(String name, String description, Class inputType, + Function responseConverter) { + this(name, description, inputType, responseConverter, + new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)); + } + + /** + * Constructs a new {@link AbstractToolFunctionCallback} with the given name, + * description, input type and default object mapper. + * @param name Function name. Should be unique within the ChatClient's function + * registry. + * @param description Function description. Used as a "system prompt" by the model to + * decide if the function should be called. + * @param inputType Used to compute, the argument's JSON schema required by the + * Model's function calling protocol. + * @param responseConverter Used to convert the function's output type to a string. + * @param objectMapper Used to convert the function's input and output types to and + * from JSON. + */ + public AbstractToolFunctionCallback(String name, String description, Class inputType, + Function responseConverter, ObjectMapper objectMapper) { + Assert.notNull(name, "Name must not be null"); + Assert.notNull(description, "Description must not be null"); + Assert.notNull(inputType, "InputType must not be null"); + Assert.notNull(responseConverter, "ResponseConverter must not be null"); + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.name = name; + this.description = description; + this.inputType = inputType; + this.inputTypeSchema = ModelOptionsUtils.getJsonSchema(inputType); + this.responseConverter = responseConverter; + this.objectMapper = objectMapper; + } + + @Override + public String getName() { + return this.name; + } + + @Override + public String getDescription() { + return this.description; + } + + @Override + public String getInputTypeSchema() { + return this.inputTypeSchema; + } + + @Override + public String call(String functionArguments) { + + // Convert the tool calls JSON arguments into a Java function request object. + I request = fromJson(functionArguments, inputType); + + // extend conversation with function response. + return this.andThen(this.responseConverter).apply(request); + } + + /** + * Implements the interaction with the 3rd party service/function. + */ + abstract public O apply(I request); + + private T fromJson(String json, Class targetClass) { + try { + return this.objectMapper.readValue(json, targetClass); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((name == null) ? 0 : name.hashCode()); + result = prime * result + ((description == null) ? 0 : description.hashCode()); + result = prime * result + ((inputType == null) ? 0 : inputType.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + AbstractToolFunctionCallback other = (AbstractToolFunctionCallback) obj; + if (name == null) { + if (other.name != null) + return false; + } + else if (!name.equals(other.name)) + return false; + if (description == null) { + if (other.description != null) + return false; + } + else if (!description.equals(other.description)) + return false; + if (inputType == null) { + if (other.inputType != null) + return false; + } + else if (!inputType.equals(other.inputType)) + return false; + return true; + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java index 6235e043927..5346a053a38 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java @@ -24,6 +24,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; import com.fasterxml.jackson.annotation.JsonProperty; @@ -31,6 +32,13 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.SerializationFeature; +import com.github.victools.jsonschema.generator.OptionPreset; +import com.github.victools.jsonschema.generator.SchemaGenerator; +import com.github.victools.jsonschema.generator.SchemaGeneratorConfig; +import com.github.victools.jsonschema.generator.SchemaGeneratorConfigBuilder; +import com.github.victools.jsonschema.generator.SchemaVersion; +import com.github.victools.jsonschema.module.jackson.JacksonModule; +import com.github.victools.jsonschema.module.jackson.JacksonOption; import org.springframework.beans.BeanWrapper; import org.springframework.beans.BeanWrapperImpl; @@ -52,6 +60,8 @@ public final class ModelOptionsUtils { private static ConcurrentHashMap, List> REQUEST_FIELD_NAMES_PER_CLASS = new ConcurrentHashMap, List>(); + private static AtomicReference SCHEMA_GENERATOR_CACHE = new AtomicReference<>(); + private ModelOptionsUtils() { } @@ -106,12 +116,10 @@ public static T merge(Object source, Object target, Class clazz, List e.getValue() != null) .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue()))); - if (!CollectionUtils.isEmpty(requestFieldNames)) { - targetMap = targetMap.entrySet() - .stream() - .filter(e -> requestFieldNames.contains(e.getKey())) - .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue())); - } + targetMap = targetMap.entrySet() + .stream() + .filter(e -> requestFieldNames.contains(e.getKey())) + .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue())); return ModelOptionsUtils.mapToClass(targetMap, clazz); } @@ -280,4 +288,27 @@ private static String toGetName(String name) { return "get" + name.substring(0, 1).toUpperCase() + name.substring(1); } + /** + * Generates JSON Schema (version 2020_12) for the given class. + * @param clazz the class to generate JSON Schema for. + * @return the generated JSON Schema as a String. + */ + public static String getJsonSchema(Class clazz) { + + if (SCHEMA_GENERATOR_CACHE.get() == null) { + + JacksonModule jacksonModule = new JacksonModule(JacksonOption.RESPECT_JSONPROPERTY_REQUIRED); + + SchemaGeneratorConfigBuilder configBuilder = new SchemaGeneratorConfigBuilder(SchemaVersion.DRAFT_2020_12, + OptionPreset.PLAIN_JSON) + .with(jacksonModule); + + SchemaGeneratorConfig config = configBuilder.build(); + SchemaGenerator generator = new SchemaGenerator(config); + SCHEMA_GENERATOR_CACHE.compareAndSet(null, generator); + } + + return SCHEMA_GENERATOR_CACHE.get().generateSchema(clazz).toPrettyString(); + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ToolFunctionCallback.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ToolFunctionCallback.java new file mode 100644 index 00000000000..6a73c67986f --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ToolFunctionCallback.java @@ -0,0 +1,53 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model; + +/** + * Represents a model function call handler. Implementations are registered with the + * Models and called on prompts that trigger the function call. + * + * @author Christian Tzolov + */ +public interface ToolFunctionCallback { + + /** + * @return Returns the Function name. Unique within the model. + */ + public String getName(); + + /** + * @return Returns the function description. This description is used by the model do + * decide if the function should be called or not. + */ + public String getDescription(); + + /** + * @return Returns the JSON schema of the function input type. + */ + public String getInputTypeSchema(); + + /** + * Called when a model detects and triggers a function call. The model is responsible + * to pass the function arguments in the pre-configured JSON schema format. + * @param functionInput JSON string with the function arguments to be passed to the + * function. The arguments are defined as JSON schema usually registered with the the + * model. + * @return String containing the function call response. + */ + public String call(String functionInput); + +} \ No newline at end of file diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/images/openai-chatclient-function-call.png b/spring-ai-docs/src/main/antora/modules/ROOT/images/openai-chatclient-function-call.png new file mode 100644 index 00000000000..887c14f31dc Binary files /dev/null and b/spring-ai-docs/src/main/antora/modules/ROOT/images/openai-chatclient-function-call.png differ diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/images/openai-function-calling-flow.png b/spring-ai-docs/src/main/antora/modules/ROOT/images/openai-function-calling-flow.png new file mode 100644 index 00000000000..ff33bd95cb5 Binary files /dev/null and b/spring-ai-docs/src/main/antora/modules/ROOT/images/openai-function-calling-flow.png differ diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc index fcfecda6cfe..a005ac0f6ee 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc @@ -14,6 +14,7 @@ *** xref:api/embeddings/onnx.adoc[] ** xref:api/chatclient.adoc[] *** xref:api/clients/openai-chat.adoc[] +**** xref:api/clients/functions/openai-chat-functions.adoc[] *** xref:api/clients/azure-openai-chat.adoc[] *** xref:api/clients/ollama-chat.adoc[] *** xref:api/bedrock.adoc[Amazon Bedrock Chat] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/functions/openai-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/functions/openai-chat-functions.adoc new file mode 100644 index 00000000000..69bc22db20f --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/functions/openai-chat-functions.adoc @@ -0,0 +1,152 @@ += Function Calling + +You can register custom Java functions with the `OpenAiChatClient` and have the OpenAI model intelligently choose to output a JSON object containing arguments to call one or many of the registered functions. +This is a powerful technique to connect the LLM capabilities with external tools and APIs. +The models have been trained to detect when a function should to be called and to respond with JSON that adheres to the function signature. + +Note that the OpenAI API does not call the function directly; instead, the model generates JSON that you can use to call the function in your code and return the result back to the model to complete the conversation. + +To register your custom function you need to specify a function `name`, function `description` that helps the model to understand when to call the function, and the function call `signature` (as JSON schema) to let the model know what arguments the function expects. +Then you can implement a function that takes the function call arguments from the model interacts with the external, 3rd party, services and returns the result back to the model. + +Spring AI offers a generic link:../../../spring-ai-core/src/main/java/org/springframework/ai/model/ToolFunctionCallback.java[ToolFunctionCallback.java] interface and the companion link:../../../spring-ai-core/src/main/java/org/springframework/ai/model/AbstractToolFunctionCallback.java[AbstractToolFunctionCallback.java] utility class to simplify the implementation and registration of Java callback functions. + +== Quick Start + +Lets create a chatbot that answer questions by calling external tools. +For example lets register a custom function that takes a location and returns the current weather in that location. +Question such as "What’s the weather like in Boston?" should trigger the model to call the function providing the location as an argument. +The function uses some weather service API and returns the weather response back to the model to complete the conversation. + +Let the link:../../../models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/tool/MockWeatherService.java[MockWeatherService.java] represent the 3-rd party weather service API: + +[source,java] +---- +public class MockWeatherService implements Function { + + public enum Unit { C, F } + public record Request(String location, Unit unit) {} + public record Response(double temp, Unit unit) {} + + public Response apply(Request request) { + return new Response("30", Unit.C); + } +} +---- + +Then extend link:../../../spring-ai-core/src/main/java/org/springframework/ai/model/AbstractToolFunctionCallback.java[AbstractToolFunctionCallback] to implement our weather function like this: + +[source,java] +---- +public class WeatherFunctionCallback + extends AbstractToolFunctionCallback { + + private final MockWeatherService weatherService = new MockWeatherService(); + + public WeatherFunctionCallback(String name, String description, Class inputType) { + super(name, // (1) + description, // (2) + inputType, // (3) + (response) -> "" + response.temp() + response.unit()); // (4) + } + + @Override + public Response apply(Request request) { + return this.weatherService.apply(request); + } +}; +---- + +The constructor takes a function name (1), description (2), input type signature (3) and a converter (4) to convert the `Response` into a text. +The Spring AI auto-generates the JSON Scheme for the `MockWeatherService.Request.class` signature. + +=== Registering Functions as Beans + +If you enable the link:../openai-chat.html#_openaichatclient_auto_configuration[OpenAiChatClient Auto-Configuration], the easiest way to register a function is to created it as a bean in the Spring context: + +[source,java,linenums] +---- +@Configuration +static class Config { + @Bean + public WeatherFunctionCallback weatherFunctionInfo() { + return new WeatherFunctionCallback( + "CurrentWeather", // (1) name + "Get the weather in location", // (2) description + MockWeatherService.Request.class); // (3) signature + } + ... +} +---- + +Now you can enable the `CurrentWeather` function in your prompt calls: + +[source,java] +---- +OpenAiChatClient chatClient = ... + +UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); + +ChatResponse response = chatClient.call(new Prompt(List.of(userMessage), + OpenAiChatOptions.builder().withEnabledFunction("CurrentWeather").build())); // (1) Enable the function + +logger.info("Response: {}", response); +---- + +NOTE: you must enable, explicitly, the functions to be used in the prompt request using the `OpenAiChatOptions.builder().withEnabledFunction(...)` method (1). + +Above user question will trigger 3 calls to `CurrentWeather` function (one for each city) and the final response will be something like this: + +---- +Here is the current weather for the requested cities: +- San Francisco, CA: 30.0°C +- Tokyo, Japan: 10.0°C +- Paris, France: 15.0°C +---- + +The [ToolCallWithPromptFunctionRegistrationIT.java] integration test provides a complete example of how to register a function with the `OpenAiChatClient` using the auto-configuration. + +=== Register/Call Functions with Prompt Options + +In addition to the auto-configuration you can register callback functions, dynamically, with your Prompt requests: + +[source,java] +---- +OpenAiChatClient chatClient = ... + +UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); + +var promptOptions = OpenAiChatOptions.builder() + .withToolCallbacks(List.of(new WeatherFunctionCallback( + "CurrentWeather", + "Get the weather in location", + MockWeatherService.Request.class))) + .build(); + +ChatResponse response = chatClient.call(new Prompt(List.of(userMessage), promptOptions)); + +logger.info("Response: {}", response); +---- + +NOTE: The in-prompt registered functions are enabled by default for the duration of this request. + +This approach allows to dynamically chose different functions to be called based on the user input. + +The [ToolCallWithPromptFunctionRegistrationIT.java] integration test provides a complete example of how to register a function with the `OpenAiChatClient` and use it in a prompt request. + +=== Function Calling Flow + +The following diagram illustrates the flow of the OpenAiChatClient Function Calling: + +image:openai-chatclient-function-call.png[Chat Client Function Calling Flow] + +== Appendices: + +=== OpenAI API Function Calling Flow + +The following diagram illustrates the flow of the OpenAI API https://platform.openai.com/docs/guides/function-calling[Function Calling]: + +image:openai-function-calling-flow.png[OpenAI API Function Calling Flow] + +[org.springframework.ai.openai.chat.api.tool.OpenAiApiToolFunctionCallTests] provides a complete example of how to call a function using the OpenAI API. +It is based on the https://platform.openai.com/docs/guides/function-calling/parallel-function-calling[OpenAI Function Calling tutorial]. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/openai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/openai-chat.adoc index 619c1c6be3f..e0c3c16c2b9 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/openai-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/openai-chat.adoc @@ -104,6 +104,11 @@ ChatResponse response = chatClient.call( TIP: In addition to the model specific https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java[OpenAiChatOptions] you can use a portable https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptions.java[ChatOptions] instance, created with the https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptionsBuilder.java[ChatOptionsBuilder#builder()]. +=== Function Calling + +You can register custom Java functions with the OpenAiChatClient and have the OpenAI model intelligently choose to output a JSON object containing arguments to call one or many of the registered functions. This is a powerful technique to connect the LLM capabilities with external tools and APIs. +link:functions/openai-chat-functions.html[Read more about Function Calling]. + === Sample Controller (Auto-configuration) https://start.spring.io/[Create] a new Spring Boot project and add the `spring-ai-openai-spring-boot-starter` to your pom (or gradle) dependencies. @@ -224,8 +229,11 @@ Flux streamResponse = openAiApi.chatCompletionStream( new ChatCompletionRequest(List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8f, true)); ---- -Check the link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/OpenAiApiIT.java[OpenAiApiIT.java] integration test for more examples. - Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java[OpenAiApi.java]'s JavaDoc for further information. +==== OpenAiApi Samples +* The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/OpenAiApiIT.java[OpenAiApiIT.java] test provides some general examples how to use the lightweight library. + +* The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/tool/OpenAiApiToolFunctionCallIT.java[OpenAiApiToolFunctionCallIT.java] test shows how to use the low-level API to call tool functions. +Based on the link:https://platform.openai.com/docs/guides/function-calling/parallel-function-calling[OpenAI Function Calling] tutorial. diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java index 95abcc248cf..3bfffc53d91 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java @@ -16,8 +16,11 @@ package org.springframework.ai.autoconfigure.openai; +import java.util.List; + import org.springframework.ai.autoconfigure.NativeHints; import org.springframework.ai.embedding.EmbeddingClient; +import org.springframework.ai.model.ToolFunctionCallback; import org.springframework.ai.openai.OpenAiChatClient; import org.springframework.ai.openai.OpenAiEmbeddingClient; import org.springframework.ai.openai.OpenAiImageClient; @@ -31,6 +34,7 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.ImportRuntimeHints; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; import org.springframework.web.client.RestClient; @@ -47,7 +51,8 @@ public class OpenAiAutoConfiguration { @Bean @ConditionalOnMissingBean public OpenAiChatClient openAiChatClient(OpenAiConnectionProperties commonProperties, - OpenAiChatProperties chatProperties, RestClient.Builder restClientBuilder) { + OpenAiChatProperties chatProperties, RestClient.Builder restClientBuilder, + List toolFunctionCallbacks) { String apiKey = StringUtils.hasText(chatProperties.getApiKey()) ? chatProperties.getApiKey() : commonProperties.getApiKey(); @@ -60,6 +65,10 @@ public OpenAiChatClient openAiChatClient(OpenAiConnectionProperties commonProper var openAiApi = new OpenAiApi(baseUrl, apiKey, restClientBuilder); + if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) { + chatProperties.getOptions().getToolCallbacks().addAll(toolFunctionCallbacks); + } + OpenAiChatClient openAiChatClient = new OpenAiChatClient(openAiApi, chatProperties.getOptions()); return openAiChatClient; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/MockWeatherService.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/MockWeatherService.java new file mode 100644 index 00000000000..e40085651d0 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/MockWeatherService.java @@ -0,0 +1,93 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.openai.tool; + +import java.util.function.Function; + +import com.fasterxml.jackson.annotation.JsonClassDescription; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; + +/** + * @author Christian Tzolov + */ +public class MockWeatherService implements Function { + + /** + * Weather Function request. + */ + @JsonInclude(Include.NON_NULL) + @JsonClassDescription("Weather API request") + public record Request(@JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, + @JsonProperty(required = true, value = "lat") @JsonPropertyDescription("The city latitude") double lat, + @JsonProperty(required = true, value = "lon") @JsonPropertyDescription("The city longitude") double lon, + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + } + + /** + * Temperature units. + */ + public enum Unit { + + /** + * Celsius. + */ + c("metric"), + /** + * Fahrenheit. + */ + f("imperial"); + + /** + * Human readable unit name. + */ + public final String unitName; + + private Unit(String text) { + this.unitName = text; + } + + } + + /** + * Weather Function response. + */ + public record Response(double temp, double feels_like, double temp_min, double temp_max, int pressure, int humidity, + Unit unit) { + } + + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.c); + } + +} \ No newline at end of file diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/ToolCallWithBeanFunctionRegistrationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/ToolCallWithBeanFunctionRegistrationIT.java new file mode 100644 index 00000000000..9cd0a456dca --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/ToolCallWithBeanFunctionRegistrationIT.java @@ -0,0 +1,98 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.openai.tool; + +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; +import org.springframework.ai.autoconfigure.openai.tool.MockWeatherService.Request; +import org.springframework.ai.autoconfigure.openai.tool.MockWeatherService.Response; +import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.AbstractToolFunctionCallback; +import org.springframework.ai.openai.OpenAiChatClient; +import org.springframework.ai.openai.OpenAiChatOptions; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +import static org.assertj.core.api.Assertions.assertThat; + +@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*") +public class ToolCallWithBeanFunctionRegistrationIT { + + private final Logger logger = LoggerFactory.getLogger(ToolCallWithBeanFunctionRegistrationIT.class); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY")) + .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) + .withUserConfiguration(Config.class); + + @Test + void functionCallTest() { + contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=gpt-4-1106-preview").run(context -> { + + OpenAiChatClient chatClient = context.getBean(OpenAiChatClient.class); + + UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); + + ChatResponse response = chatClient.call(new Prompt(List.of(userMessage), + OpenAiChatOptions.builder().withEnabledFunction("WeatherInfo").build())); + + logger.info("Response: {}", response); + + assertThat(response.getResult().getOutput().getContent()).contains("30.0", "10.0", "15.0"); + + }); + } + + @Configuration + static class Config { + + @Bean + public WeatherFunctionCallback weatherFunctionInfo() { + return new WeatherFunctionCallback("WeatherInfo", "Get the weather in location", + MockWeatherService.Request.class); + } + + public static class WeatherFunctionCallback + extends AbstractToolFunctionCallback { + + public WeatherFunctionCallback(String name, String description, Class inputType) { + super(name, description, inputType, (response) -> "" + response.temp() + response.unit()); + } + + private final MockWeatherService weatherService = new MockWeatherService(); + + @Override + public Response apply(Request request) { + return weatherService.apply(request); + } + + }; + + } + +} \ No newline at end of file diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/ToolCallWithPromptFunctionRegistrationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/ToolCallWithPromptFunctionRegistrationIT.java new file mode 100644 index 00000000000..7f2ae3ef44e --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/ToolCallWithPromptFunctionRegistrationIT.java @@ -0,0 +1,79 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.openai.tool; + +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; +import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.AbstractToolFunctionCallback; +import org.springframework.ai.openai.OpenAiChatClient; +import org.springframework.ai.openai.OpenAiChatOptions; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import static org.assertj.core.api.Assertions.assertThat; + +@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*") +public class ToolCallWithPromptFunctionRegistrationIT { + + private final Logger logger = LoggerFactory.getLogger(ToolCallWithPromptFunctionRegistrationIT.class); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY")) + .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)); + + @Test + void functionCallTest() { + contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=gpt-4-1106-preview").run(context -> { + + OpenAiChatClient chatClient = context.getBean(OpenAiChatClient.class); + + UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); + + var promptOptions = OpenAiChatOptions.builder() + .withToolCallbacks(List + .of(new AbstractToolFunctionCallback( + "CurrentWeatherService", "Get the weather in location", MockWeatherService.Request.class, + (response) -> "" + response.temp() + response.unit()) { + + private final MockWeatherService weatherService = new MockWeatherService(); + + @Override + public MockWeatherService.Response apply(MockWeatherService.Request request) { + return weatherService.apply(request); + } + })) + .build(); + + ChatResponse response = chatClient.call(new Prompt(List.of(userMessage), promptOptions)); + + logger.info("Response: {}", response); + + assertThat(response.getResult().getOutput().getContent()).contains("30.0", "10.0", "15.0"); + }); + } + +} \ No newline at end of file