diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java
index e4e3af09e61..dc39910303a 100644
--- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java
+++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java
@@ -16,21 +16,31 @@
 package org.springframework.ai.ollama;
 
 import java.util.Base64;
+import java.util.HashSet;
 import java.util.List;
+import java.util.Map;
+import java.util.Set;
 
-import org.springframework.ai.chat.messages.Message;
+import org.springframework.ai.chat.messages.AssistantMessage;
 import org.springframework.ai.chat.messages.MessageType;
+import org.springframework.ai.chat.messages.SystemMessage;
+import org.springframework.ai.chat.messages.ToolResponseMessage;
 import org.springframework.ai.chat.messages.UserMessage;
 import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
 import org.springframework.ai.chat.metadata.ChatResponseMetadata;
+import org.springframework.ai.chat.model.AbstractToolCallSupport;
 import org.springframework.ai.chat.model.ChatModel;
 import org.springframework.ai.chat.model.ChatResponse;
 import org.springframework.ai.chat.model.Generation;
 import org.springframework.ai.chat.prompt.ChatOptions;
 import org.springframework.ai.chat.prompt.Prompt;
 import org.springframework.ai.model.ModelOptionsUtils;
+import org.springframework.ai.model.function.FunctionCallbackContext;
 import org.springframework.ai.ollama.api.OllamaApi;
+import org.springframework.ai.ollama.api.OllamaApi.ChatRequest;
 import org.springframework.ai.ollama.api.OllamaApi.Message.Role;
+import org.springframework.ai.ollama.api.OllamaApi.Message.ToolCall;
+import org.springframework.ai.ollama.api.OllamaApi.Message.ToolCallFunction;
 import org.springframework.ai.ollama.api.OllamaOptions;
 import org.springframework.ai.ollama.metadata.OllamaUsage;
 import org.springframework.util.Assert;
@@ -54,7 +64,7 @@
  * @author luocongqiu
  * @since 0.8.0
  */
-public class OllamaChatModel implements ChatModel {
+public class OllamaChatModel extends AbstractToolCallSupport implements ChatModel {
 
 	/**
 	 * Low-level Ollama API library.
@@ -71,6 +81,12 @@ public OllamaChatModel(OllamaApi chatApi) {
 	}
 
 	public OllamaChatModel(OllamaApi chatApi, OllamaOptions defaultOptions) {
+		this(chatApi, defaultOptions, null);
+	}
+
+	public OllamaChatModel(OllamaApi chatApi, OllamaOptions defaultOptions,
+			FunctionCallbackContext functionCallbackContext) {
+		super(functionCallbackContext);
 		Assert.notNull(chatApi, "OllamaApi must not be null");
 		Assert.notNull(defaultOptions, "DefaultOptions must not be null");
 		this.chatApi = chatApi;
@@ -100,11 +116,32 @@ public ChatResponse call(Prompt prompt) {
 
 		OllamaApi.ChatResponse response = this.chatApi.chat(ollamaChatRequest(prompt, false));
 
-		var generator = new Generation(response.message().content());
+		List<AssistantMessage.ToolCall> toolCalls = response.message().toolCalls() == null ? List.of()
+				: response.message()
+					.toolCalls()
+					.stream()
+					.map(toolCall -> new AssistantMessage.ToolCall("", "function", toolCall.function().name(),
+							ModelOptionsUtils.toJsonString(toolCall.function().arguments())))
+					.toList();
+
+		var assistantMessage = new AssistantMessage(response.message().content(), Map.of(), toolCalls);
+
+		ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL;
 		if (response.promptEvalCount() != null && response.evalCount() != null) {
-			generator = generator.withGenerationMetadata(ChatGenerationMetadata.from("unknown", null));
+			generationMetadata = ChatGenerationMetadata.from("DONE", null);
+		}
+
+		var generator = new Generation(assistantMessage, generationMetadata);
+		var chatResponse = new ChatResponse(List.of(generator), from(response));
+
+		if (isToolCall(chatResponse, Set.of("DONE"))) {
+			var toolCallConversation = handleToolCalls(prompt, chatResponse);
+			// Recursively call the call method with the tool call message
+			// conversation that contains the call responses.
+			return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
 		}
-		return new ChatResponse(List.of(generator), from(response));
+
+		return chatResponse;
 	}
 
 	public static ChatResponseMetadata from(OllamaApi.ChatResponse response) {
@@ -126,15 +163,39 @@
 
 	@Override
 	public Flux<ChatResponse> stream(Prompt prompt) {
-		Flux<OllamaApi.ChatResponse> response = this.chatApi.streamingChat(ollamaChatRequest(prompt, true));
+		Flux<OllamaApi.ChatResponse> ollamaResponse = this.chatApi.streamingChat(ollamaChatRequest(prompt, true));
+
+		Flux<ChatResponse> chatResponse = ollamaResponse.map(chunk -> {
+			String content = (chunk.message() != null) ? chunk.message().content() : "";
+			List<AssistantMessage.ToolCall> toolCalls = chunk.message().toolCalls() == null ? List.of()
+					: chunk.message()
+						.toolCalls()
+						.stream()
+						.map(toolCall -> new AssistantMessage.ToolCall("", "function", toolCall.function().name(),
+								ModelOptionsUtils.toJsonString(toolCall.function().arguments())))
+						.toList();
+
+			var assistantMessage = new AssistantMessage(content, Map.of(), toolCalls);
+
+			ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL;
+			if (chunk.promptEvalCount() != null && chunk.evalCount() != null) {
+				generationMetadata = ChatGenerationMetadata.from("DONE", null);
+			}
+
+			var generator = new Generation(assistantMessage, generationMetadata);
+			return new ChatResponse(List.of(generator), from(chunk));
+		});
 
-		return response.map(chunk -> {
-			Generation generation = (chunk.message() != null) ? new Generation(chunk.message().content())
-					: new Generation("");
-			if (Boolean.TRUE.equals(chunk.done())) {
-				generation = generation.withGenerationMetadata(ChatGenerationMetadata.from("unknown", null));
+		return chatResponse.flatMap(response -> {
+			if (isToolCall(response, Set.of("DONE"))) {
+				var toolCallConversation = handleToolCalls(prompt, response);
+				// Recursively call the stream method with the tool call message
+				// conversation that contains the call responses.
+				return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
+			}
+			else {
+				return Flux.just(response);
 			}
-			return new ChatResponse(List.of(generation), from(chunk));
 		});
 	}
 
@@ -147,28 +208,61 @@ OllamaApi.ChatRequest ollamaChatRequest(Prompt prompt, boolean stream) {
 			.stream()
 			.filter(message -> message.getMessageType() == MessageType.USER
 					|| message.getMessageType() == MessageType.ASSISTANT
-					|| message.getMessageType() == MessageType.SYSTEM)
-			.map(m -> {
-				var messageBuilder = OllamaApi.Message.builder(toRole(m)).withContent(m.getContent());
-				if (m instanceof UserMessage userMessage) {
+					|| message.getMessageType() == MessageType.SYSTEM || message.getMessageType() == MessageType.TOOL)
+			.map(message -> {
+				if (message instanceof UserMessage userMessage) {
+					var messageBuilder = OllamaApi.Message.builder(Role.USER).withContent(message.getContent());
 					if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
 						messageBuilder.withImages(userMessage.getMedia()
 							.stream()
 							.map(media -> this.fromMediaData(media.getData()))
 							.toList());
 					}
+					return List.of(messageBuilder.build());
+				}
+				else if (message instanceof SystemMessage systemMessage) {
+					return List
+						.of(OllamaApi.Message.builder(Role.SYSTEM).withContent(systemMessage.getContent()).build());
+				}
+				else if (message instanceof AssistantMessage assistantMessage) {
+					List<ToolCall> toolCalls = null;
+					if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
+						toolCalls = assistantMessage.getToolCalls().stream().map(toolCall -> {
+							var function = new ToolCallFunction(toolCall.name(),
+									ModelOptionsUtils.jsonToMap(toolCall.arguments()));
+							return new ToolCall(function);
+						}).toList();
+					}
+					return List.of(OllamaApi.Message.builder(Role.ASSISTANT)
+						.withContent(assistantMessage.getContent())
+						.withToolCalls(toolCalls)
+						.build());
+				}
+				else if (message instanceof ToolResponseMessage toolMessage) {
+
+					List<OllamaApi.Message> responseMessages = toolMessage.getResponses()
+						.stream()
+						.map(tr -> OllamaApi.Message.builder(Role.TOOL).withContent(tr.responseData()).build())
+						.toList();
+
+					return responseMessages;
 				}
-				return messageBuilder.build();
+				throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
 			})
+			.flatMap(List::stream)
 			.toList();
 
+		Set<String> functionsForThisRequest = new HashSet<>();
+
 		// runtime options
 		OllamaOptions runtimeOptions = null;
 		if (prompt.getOptions() != null) {
 			runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class, OllamaOptions.class);
+			functionsForThisRequest.addAll(this.handleFunctionCallbackConfigurations(runtimeOptions, IS_RUNTIME_CALL));
 		}
 
+		functionsForThisRequest.addAll(this.handleFunctionCallbackConfigurations(this.defaultOptions, !IS_RUNTIME_CALL));
+
 		OllamaOptions mergedOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, OllamaOptions.class);
 
 		// Override the model.
@@ -190,6 +284,11 @@ OllamaApi.ChatRequest ollamaChatRequest(Prompt prompt, boolean stream) {
 			requestBuilder.withKeepAlive(mergedOptions.getKeepAlive());
 		}
 
+		// Add the enabled functions definitions to the request's tools parameter.
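+		// Functions enabled via runtime prompt options are active only for this
+		// request; functions enabled on the default options apply to every request.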
+		if (!CollectionUtils.isEmpty(functionsForThisRequest)) {
+			requestBuilder.withTools(this.getFunctionTools(functionsForThisRequest));
+		}
+
+		return requestBuilder.build();
 	}
 
@@ -206,18 +305,12 @@ else if (mediaData instanceof String text) {
 		}
 
-	private OllamaApi.Message.Role toRole(Message message) {
-
-		switch (message.getMessageType()) {
-			case USER:
-				return Role.USER;
-			case ASSISTANT:
-				return Role.ASSISTANT;
-			case SYSTEM:
-				return Role.SYSTEM;
-			default:
-				throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
-		}
+	private List<ChatRequest.Tool> getFunctionTools(Set<String> functionNames) {
+		return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> {
+			var function = new ChatRequest.Tool.Function(functionCallback.getName(), functionCallback.getDescription(),
+					functionCallback.getInputTypeSchema());
+			return new ChatRequest.Tool(function);
+		}).toList();
 	}
 
 	@Override
diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java
index 0b66d46bddf..7a09474572b 100644
--- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java
+++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java
@@ -23,14 +23,10 @@
 import java.util.Objects;
 import java.util.function.Consumer;
 
-import com.fasterxml.jackson.annotation.JsonInclude;
-import com.fasterxml.jackson.annotation.JsonInclude.Include;
-import com.fasterxml.jackson.annotation.JsonProperty;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
-import reactor.core.publisher.Flux;
-import reactor.core.publisher.Mono;
-
+import org.springframework.ai.model.ModelOptionsUtils;
+import org.springframework.boot.context.properties.bind.ConstructorBinding;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.MediaType;
 import org.springframework.http.client.ClientHttpResponse;
@@ -40,6 +36,13 @@
 import org.springframework.web.client.RestClient;
 import org.springframework.web.reactive.function.client.WebClient;
 
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonInclude.Include;
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
+
 /**
  * Java Client for the Ollama API. https://ollama.ai/
  *
@@ -356,7 +359,8 @@ public record Message(
 			@JsonProperty("role") Role role,
 			@JsonProperty("content") String content,
-			@JsonProperty("images") List<String> images) {
+			@JsonProperty("images") List<String> images,
+			@JsonProperty("tool_calls") List<ToolCall> toolCalls) {
 
 		/**
 		 * The role of the message in the conversation.
@@ -374,8 +378,37 @@ public enum Role {
 
 			/**
 			 * Assistant message type. Usually the response from the model.
 			 */
-			@JsonProperty("assistant") ASSISTANT;
+			@JsonProperty("assistant") ASSISTANT,
+
+			/**
+			 * Tool message.
+			 */
+			@JsonProperty("tool") TOOL;
+
 		}
+
+		/**
+		 * The relevant tool call. Currently a tool call carries only the function
+		 * to invoke; there is no separate id or type field on this record.
+		 *
+		 * @param function The function definition.
+ */ + @JsonInclude(Include.NON_NULL) + public record ToolCall( + @JsonProperty("function") ToolCallFunction function) { + } + /** + * The function definition. + * + * @param name The name of the function. + * @param arguments The arguments that the model expects you to pass to the function. + */ + @JsonInclude(Include.NON_NULL) + public record ToolCallFunction( + @JsonProperty("name") String name, + @JsonProperty("arguments") Map arguments) { } public static Builder builder(Role role) { @@ -387,6 +420,7 @@ public static class Builder { private final Role role; private String content; private List images; + private List toolCalls; public Builder(Role role) { this.role = role; @@ -402,8 +436,13 @@ public Builder withImages(List images) { return this; } + public Builder withToolCalls(List toolCalls) { + this.toolCalls = toolCalls; + return this; + } + public Message build() { - return new Message(role, content, images); + return new Message(role, content, images, toolCalls); } } @@ -429,8 +468,68 @@ public record ChatRequest( @JsonProperty("stream") Boolean stream, @JsonProperty("format") String format, @JsonProperty("keep_alive") String keepAlive, - @JsonProperty("options") Map options) { + @JsonProperty("options") Map options, + @JsonProperty("tools") List tools) { + + + /** + * Represents a tool the model may call. Currently, only functions are supported as a tool. + * + * @param type The type of the tool. Currently, only 'function' is supported. + * @param function The function definition. + */ + @JsonInclude(Include.NON_NULL) + public record Tool( + @JsonProperty("type") Type type, + @JsonProperty("function") Function function) { + + /** + * Create a tool of type 'function' and the given function definition. + * @param function function definition. + */ + @ConstructorBinding + public Tool(Function function) { + this(Type.FUNCTION, function); + } + /** + * Create a tool of type 'function' and the given function definition. + */ + public enum Type { + /** + * Function tool type. + */ + @JsonProperty("function") FUNCTION + } + + /** + * Function definition. + * + * @param name The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes. + * @param description A description of what the function does, used by the model to choose when and how to call + * the function. + * @param parameters The parameters the functions accepts, described as a JSON Schema object. To describe a + * function that accepts no parameters, provide the value {"type": "object", "properties": {}}. + */ + public record Function( + @JsonProperty("name") String name, + @JsonProperty("description") String description, + @JsonProperty("parameters") Map parameters) { + + /** + * Create tool function definition. + * + * @param description tool function description. + * @param name tool function name. + * @param jsonSchema tool function schema as json. 
+ */ + @ConstructorBinding + public Function(String description, String name, String jsonSchema) { + this(description, name, ModelOptionsUtils.jsonToMap(jsonSchema)); + } + } + } + public static Builder builder(String model) { return new Builder(model); } @@ -443,6 +542,7 @@ public static class Builder { private String format; private String keepAlive; private Map options = Map.of(); + private List tools = List.of(); public Builder(String model) { Assert.notNull(model, "The model can not be null."); @@ -482,8 +582,13 @@ public Builder withOptions(OllamaOptions options) { return this; } + public Builder withTools(List tools) { + this.tools = tools; + return this; + } + public ChatRequest build() { - return new ChatRequest(model, messages, stream, format, keepAlive, options); + return new ChatRequest(model, messages, stream, format, keepAlive, options, tools); } } } diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java index 5847bc82cb3..49a6f443fb6 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java @@ -15,19 +15,26 @@ */ package org.springframework.ai.ollama.api; +import java.util.ArrayList; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.Set; import java.util.stream.Collectors; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.embedding.EmbeddingOptions; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.boot.context.properties.NestedConfigurationProperty; +import org.springframework.util.Assert; + +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; - -import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.embedding.EmbeddingOptions; /** * Helper class for creating strongly-typed Ollama options. @@ -40,7 +47,7 @@ * @see Ollama Types */ @JsonInclude(Include.NON_NULL) -public class OllamaOptions implements ChatOptions, EmbeddingOptions { +public class OllamaOptions implements FunctionCallingOptions, ChatOptions, EmbeddingOptions { public static final String DEFAULT_MODEL = OllamaModel.MISTRAL.id(); @@ -248,6 +255,38 @@ public class OllamaOptions implements ChatOptions, EmbeddingOptions { */ @JsonProperty("keep_alive") private String keepAlive; + /** + * OpenAI Tool Function Callbacks to register with the ChatModel. + * For Prompt Options the functionCallbacks are automatically enabled for the duration of the prompt execution. + * For Default Options the functionCallbacks are registered but disabled by default. Use the enableFunctions to set the functions + * from the registry to be used by the ChatModel chat completion requests. 
+ */ + @NestedConfigurationProperty + @JsonIgnore + private List functionCallbacks = new ArrayList<>(); + + /** + * List of functions, identified by their names, to configure for function calling in + * the chat completion requests. + * Functions with those names must exist in the functionCallbacks registry. + * The {@link #functionCallbacks} from the PromptOptions are automatically enabled for the duration of the prompt execution. + * + * Note that function enabled with the default options are enabled for all chat completion requests. This could impact the token count and the billing. + * If the functions is set in a prompt options, then the enabled functions are only active for the duration of this prompt execution. + */ + @NestedConfigurationProperty + @JsonIgnore + private Set functions = new HashSet<>(); + + + public static OllamaOptions builder() { + return new OllamaOptions(); + } + + public OllamaOptions build() { + return this; + } + /** * @param model The ollama model names to use. See the {@link OllamaModel} for the common models. */ @@ -424,6 +463,22 @@ public OllamaOptions withStop(List stop) { return this; } + public OllamaOptions withFunctionCallbacks(List functionCallbacks) { + this.functionCallbacks = functionCallbacks; + return this; + } + + public OllamaOptions withFunctions(Set functions) { + this.functions = functions; + return this; + } + + public OllamaOptions withFunction(String functionName) { + Assert.hasText(functionName, "Function name must not be empty"); + this.functions.add(functionName); + return this; + } + public String getFormat() { return this.format; } @@ -680,19 +735,33 @@ public void setStop(List stop) { this.stop = stop; } + @Override + public List getFunctionCallbacks() { + return this.functionCallbacks; + } + + @Override + public void setFunctionCallbacks(List functionCallbacks) { + this.functionCallbacks = functionCallbacks; + + } + + @Override + public Set getFunctions() { + return this.functions; + } + + @Override + public void setFunctions(Set functions) { + this.functions = functions; + } + /** * Convert the {@link OllamaOptions} object to a {@link Map} of key/value pairs. * @return The {@link Map} of key/value pairs. 
*/ public Map toMap() { - try { - var json = new ObjectMapper().writeValueAsString(this); - return new ObjectMapper().readValue(json, new TypeReference>() { - }); - } - catch (JsonProcessingException e) { - throw new RuntimeException(e); - } + return ModelOptionsUtils.objectToMap(this); } /** @@ -753,10 +822,48 @@ public static OllamaOptions fromOptions(OllamaOptions fromOptions) { .withMirostatTau(fromOptions.getMirostatTau()) .withMirostatEta(fromOptions.getMirostatEta()) .withPenalizeNewline(fromOptions.getPenalizeNewline()) - .withStop(fromOptions.getStop()); + .withStop(fromOptions.getStop()) + .withFunctions(fromOptions.getFunctions()) + .withFunctionCallbacks(fromOptions.getFunctionCallbacks()); } + // @formatter:on + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + OllamaOptions that = (OllamaOptions) o; + return Objects.equals(model, that.model) && Objects.equals(format, that.format) + && Objects.equals(keepAlive, that.keepAlive) && Objects.equals(useNUMA, that.useNUMA) + && Objects.equals(numCtx, that.numCtx) && Objects.equals(numBatch, that.numBatch) + && Objects.equals(numGQA, that.numGQA) && Objects.equals(numGPU, that.numGPU) + && Objects.equals(mainGPU, that.mainGPU) && Objects.equals(lowVRAM, that.lowVRAM) + && Objects.equals(f16KV, that.f16KV) && Objects.equals(logitsAll, that.logitsAll) + && Objects.equals(vocabOnly, that.vocabOnly) && Objects.equals(useMMap, that.useMMap) + && Objects.equals(useMLock, that.useMLock) && Objects.equals(numThread, that.numThread) + && Objects.equals(numKeep, that.numKeep) && Objects.equals(seed, that.seed) + && Objects.equals(numPredict, that.numPredict) && Objects.equals(topK, that.topK) + && Objects.equals(topP, that.topP) && Objects.equals(tfsZ, that.tfsZ) + && Objects.equals(typicalP, that.typicalP) && Objects.equals(repeatLastN, that.repeatLastN) + && Objects.equals(temperature, that.temperature) && Objects.equals(repeatPenalty, that.repeatPenalty) + && Objects.equals(presencePenalty, that.presencePenalty) + && Objects.equals(frequencyPenalty, that.frequencyPenalty) && Objects.equals(mirostat, that.mirostat) + && Objects.equals(mirostatTau, that.mirostatTau) && Objects.equals(mirostatEta, that.mirostatEta) + && Objects.equals(penalizeNewline, that.penalizeNewline) && Objects.equals(stop, that.stop) + && Objects.equals(functionCallbacks, that.functionCallbacks) + && Objects.equals(functions, that.functions); + } - // @formatter:on + @Override + public int hashCode() { + return Objects.hash(this.model, this.format, this.keepAlive, this.useNUMA, this.numCtx, this.numBatch, + this.numGQA, numGPU, mainGPU, lowVRAM, this.f16KV, this.logitsAll, this.vocabOnly, this.useMMap, + this.useMLock, this.numThread, this.numKeep, this.seed, this.numPredict, this.topK, this.topP, tfsZ, + this.typicalP, this.repeatLastN, this.temperature, this.repeatPenalty, this.presencePenalty, + this.frequencyPenalty, this.mirostat, this.mirostatTau, this.mirostatEta, this.penalizeNewline, + this.stop, this.functionCallbacks, this.functions); + } } diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java new file mode 100644 index 00000000000..6da2d04a000 --- /dev/null +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java @@ -0,0 +1,149 @@ +/* + * 
Copyright 2024 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.ollama;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.ai.chat.messages.AssistantMessage;
+import org.springframework.ai.chat.messages.Message;
+import org.springframework.ai.chat.messages.UserMessage;
+import org.springframework.ai.chat.model.ChatModel;
+import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.chat.model.Generation;
+import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.model.function.FunctionCallbackWrapper;
+import org.springframework.ai.ollama.api.OllamaApi;
+import org.springframework.ai.ollama.api.OllamaOptions;
+import org.springframework.ai.ollama.api.tool.MockWeatherService;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.boot.SpringBootConfiguration;
+import org.springframework.boot.test.context.SpringBootTest;
+import org.springframework.context.annotation.Bean;
+import org.testcontainers.junit.jupiter.Container;
+import org.testcontainers.junit.jupiter.Testcontainers;
+import org.testcontainers.ollama.OllamaContainer;
+
+import reactor.core.publisher.Flux;
+
+@Disabled("For manual smoke testing only.")
+@Testcontainers
+@SpringBootTest(classes = OllamaChatModelFunctionCallingIT.Config.class)
+class OllamaChatModelFunctionCallingIT {
+
+	private static final Logger logger = LoggerFactory.getLogger(OllamaChatModelFunctionCallingIT.class);
+
+	private static String MODEL = "mistral";
+
+	@Container
+	static OllamaContainer ollamaContainer = new OllamaContainer("ollama/ollama:0.2.8");
+
+	static String baseUrl = "http://localhost:11434";
+
+	@BeforeAll
+	public static void beforeAll() throws IOException, InterruptedException {
+		logger.info("Start pulling the '" + MODEL + "' generative ... this may take several minutes ...");
+		ollamaContainer.execInContainer("ollama", "pull", MODEL);
+		logger.info(MODEL + " pulling completed!");
+
+		baseUrl = "http://" + ollamaContainer.getHost() + ":" + ollamaContainer.getMappedPort(11434);
+	}
+
+	@Autowired
+	ChatModel chatModel;
+
+	@Test
+	void functionCallTest() {
+
+		UserMessage userMessage = new UserMessage(
+				"What's the weather like in San Francisco, Tokyo, and Paris? 
Return temperatures in Celsius."); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = OllamaOptions.builder() + .withModel(MODEL) + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("getCurrentWeather") + .withDescription("Get the weather in location") + .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .build())) + .build(); + + ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + + logger.info("Response: {}", response); + + assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); + } + + @Disabled("Ollama API does not support streaming function calls yet") + @Test + void streamFunctionCallTest() { + + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo, and Paris? Return temperatures in Celsius."); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = OllamaOptions.builder() + .withModel(MODEL) + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("getCurrentWeather") + .withDescription("Get the weather in location") + .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .build())) + .build(); + + Flux response = chatModel.stream(new Prompt(messages, promptOptions)); + + String content = response.collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + logger.info("Response: {}", content); + + assertThat(content).contains("30", "10", "15"); + } + + @SpringBootConfiguration + static class Config { + + @Bean + public OllamaApi ollamaApi() { + return new OllamaApi(baseUrl); + } + + @Bean + public OllamaChatModel ollamaChat(OllamaApi ollamaApi) { + return new OllamaChatModel(ollamaApi, OllamaOptions.create().withModel(MODEL).withTemperature(0.9f)); + } + + } + +} \ No newline at end of file diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java index 698c391d11b..4f737e6b29b 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java @@ -15,6 +15,8 @@ */ package org.springframework.ai.ollama; +import static org.assertj.core.api.Assertions.assertThat; + import java.io.IOException; import java.util.Arrays; import java.util.List; @@ -26,16 +28,12 @@ import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.ollama.OllamaContainer; - -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; import 
org.springframework.ai.chat.prompt.PromptTemplate; @@ -50,8 +48,9 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; import org.springframework.core.convert.support.DefaultConversionService; - -import static org.assertj.core.api.Assertions.assertThat; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.ollama.OllamaContainer; @SpringBootTest @Testcontainers @@ -65,7 +64,7 @@ class OllamaChatModelIT { @Container static OllamaContainer ollamaContainer = new OllamaContainer("ollama/ollama:0.1.32"); - static String baseUrl; + static String baseUrl = "http://localhost:11434"; @BeforeAll public static void beforeAll() throws IOException, InterruptedException { diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java index 4a967dfa3ea..a767927d098 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java @@ -53,7 +53,7 @@ class OllamaChatModelMultimodalIT { @Container static OllamaContainer ollamaContainer = new OllamaContainer("ollama/ollama:0.1.32"); - static String baseUrl; + static String baseUrl = "http://localhost:11434"; @BeforeAll public static void beforeAll() throws IOException, InterruptedException { diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/tool/MockWeatherService.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/tool/MockWeatherService.java new file mode 100644 index 00000000000..36b04d97f67 --- /dev/null +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/tool/MockWeatherService.java @@ -0,0 +1,90 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.ollama.api.tool; + +import java.util.function.Function; + +import com.fasterxml.jackson.annotation.JsonClassDescription; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; + +/** + * @author Christian Tzolov + */ +public class MockWeatherService implements Function { + + /** + * Weather Function request. + */ + @JsonInclude(Include.NON_NULL) + @JsonClassDescription("Weather API request") + public record Request(@JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + } + + /** + * Temperature units. 
+ */ + public enum Unit { + + /** + * Celsius. + */ + C("metric"), + /** + * Fahrenheit. + */ + F("imperial"); + + /** + * Human readable unit name. + */ + public final String unitName; + + private Unit(String text) { + this.unitName = text; + } + + } + + /** + * Weather Function response. + */ + public record Response(double temp, double feels_like, double temp_min, double temp_max, int pressure, int humidity, + Unit unit) { + } + + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); + } + +} \ No newline at end of file diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/tool/OllamaApiToolFunctionCallIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/tool/OllamaApiToolFunctionCallIT.java new file mode 100644 index 00000000000..813cbe26735 --- /dev/null +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/tool/OllamaApiToolFunctionCallIT.java @@ -0,0 +1,155 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.ollama.api.tool; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.ollama.api.OllamaApi; +import org.springframework.ai.ollama.api.OllamaApi.ChatResponse; +import org.springframework.ai.ollama.api.OllamaApi.Message; +import org.springframework.ai.ollama.api.OllamaApi.Message.Role; +import org.springframework.ai.ollama.api.OllamaApi.Message.ToolCall; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.ollama.OllamaContainer; + +/** + * @author Christian Tzolov + */ +@Disabled("For manual smoke testing only.") +@Testcontainers +public class OllamaApiToolFunctionCallIT { + + private static String MODEL = "mistral"; + + private static final Logger logger = LoggerFactory.getLogger(OllamaApiToolFunctionCallIT.class); + + MockWeatherService weatherService = new MockWeatherService(); + + @Container + static OllamaContainer ollamaContainer = new OllamaContainer("ollama/ollama:0.2.8"); + + static String baseUrl = "http://localhost:11434"; + + @BeforeAll + public static void beforeAll() throws IOException, InterruptedException { + logger.info("Start pulling the '" + MODEL + " ' generative ... 
this may take several minutes ...");
+		ollamaContainer.execInContainer("ollama", "pull", MODEL);
+		logger.info(MODEL + " pulling completed!");
+
+		baseUrl = "http://" + ollamaContainer.getHost() + ":" + ollamaContainer.getMappedPort(11434);
+	}
+
+	@SuppressWarnings("null")
+	@Test
+	public void toolFunctionCall() {
+
+		OllamaApi completionApi = new OllamaApi(baseUrl);
+
+		// Step 1: send the conversation and available functions to the model
+		var message = Message.builder(Role.USER)
+			// .withContent("What's the weather like in San Francisco, Tokyo, and Paris?
+			// Perform multiple function calls for each location.")
+			.withContent("What's the weather like in San Francisco, Tokyo, and Paris?")
+			.build();
+
+		var functionTool = new OllamaApi.ChatRequest.Tool(new OllamaApi.ChatRequest.Tool.Function("getCurrentWeather",
+				"Get the weather in location. Return temperature in Celsius.", ModelOptionsUtils.jsonToMap("""
+						{
+							"type": "object",
+							"properties": {
+								"location": {
+									"type": "string",
+									"description": "The city and state e.g. San Francisco, CA"
+								},
+								"unit": {
+									"type": "string",
+									"enum": ["C", "F"]
+								}
+							},
+							"required": ["location", "unit"]
+						}
+						""")));
+
+		List<Message> messages = new ArrayList<>(List.of(message));
+
+		OllamaApi.ChatRequest chatCompletionRequest = OllamaApi.ChatRequest.builder(MODEL)
+			.withMessages(messages)
+			.withTools(List.of(functionTool))
+			.build();
+
+		ChatResponse chatCompletion = completionApi.chat(chatCompletionRequest);
+
+		assertThat(chatCompletion).isNotNull();
+		assertThat(chatCompletion.message()).isNotNull();
+
+		Message responseMessage = chatCompletion.message();
+
+		assertThat(responseMessage.role()).isEqualTo(Role.ASSISTANT);
+		assertThat(responseMessage.toolCalls()).isNotNull();
+
+		// Check if the model wanted to call a function
+		if (responseMessage.toolCalls() != null) {
+
+			// extend conversation with assistant's reply.
+			messages.add(responseMessage);
+
+			// Send the info for each function call and function response to the model.
+			for (ToolCall toolCall : responseMessage.toolCalls()) {
+				var functionName = toolCall.function().name();
+				if ("getCurrentWeather".equals(functionName)) {
+					Map<String, Object> responseMap = toolCall.function().arguments();
+					MockWeatherService.Request weatherRequest = ModelOptionsUtils.mapToClass(responseMap,
+							MockWeatherService.Request.class);
+
+					MockWeatherService.Response weatherResponse = weatherService.apply(weatherRequest);
+
+					// extend conversation with function response.
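+					// Note: Ollama tool messages carry only role and content; unlike
+					// OpenAI there is no tool-call id to echo back, so the result is
+					// sent as plain text on a TOOL-role message.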
+ messages.add(Message.builder(Role.TOOL) + .withContent("" + weatherResponse.temp() + weatherRequest.unit()) + .build()); + } + } + + var functionResponseRequest = OllamaApi.ChatRequest.builder(MODEL).withMessages(messages).build(); + + ChatResponse chatCompletion2 = completionApi.chat(functionResponseRequest); + + logger.info("Final response: " + chatCompletion2); + + assertThat(chatCompletion2).isNotNull(); + + assertThat(chatCompletion2.message().role()).isEqualTo(Role.ASSISTANT); + assertThat(chatCompletion2.message().content()).contains("San Francisco").contains("30"); + assertThat(chatCompletion2.message().content()).contains("Tokyo").contains("10"); + assertThat(chatCompletion2.message().content()).contains("Paris").contains("15"); + } + + } + +} \ No newline at end of file diff --git a/models/spring-ai-openai/pom.xml b/models/spring-ai-openai/pom.xml index c19018cc4d9..60bcffd32ff 100644 --- a/models/spring-ai-openai/pom.xml +++ b/models/spring-ai-openai/pom.xml @@ -106,7 +106,12 @@ test - + + org.testcontainers + ollama + test + + diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index 759f7e1b472..cf6a497417b 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -15,16 +15,24 @@ */ package org.springframework.ai.openai; +import java.util.ArrayList; +import java.util.Base64; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.ToolResponseMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.metadata.RateLimit; +import org.springframework.ai.chat.model.AbstractToolCallSupport; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; @@ -32,7 +40,6 @@ import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.chat.model.AbstractToolCallSupport; import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion; @@ -50,17 +57,10 @@ import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.MimeType; + import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import java.util.ArrayList; -import java.util.Base64; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; - /** * {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal OpenAI} * backed by {@link OpenAiApi}. 
@@ -175,7 +175,7 @@ public ChatResponse call(Prompt prompt) { ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody(), rateLimit)); - if (isToolCall(chatResponse, OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name())) { + if (isToolCall(chatResponse, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), "stop"))) { var toolCallConversation = handleToolCalls(prompt, chatResponse); // Recursively call the call method with the tool call message // conversation that contains the call responses. @@ -245,7 +245,8 @@ public Flux stream(Prompt prompt) { })); return chatResponse.flatMap(response -> { - if (isToolCall(response, OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name())) { + + if (isToolCall(response, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), "stop"))) { var toolCallConversation = handleToolCalls(prompt, response); // Recursively call the stream method with the tool call message // conversation that contains the call responses. @@ -285,20 +286,6 @@ private static Generation buildGeneration(Choice choice, Map met return generation; } - private List handleToolCalls(Prompt prompt, ChatResponse response) { - AssistantMessage assistantMessage = response.getResult().getOutput(); - ToolResponseMessage toolMessageResponse = this.executeFuncitons(assistantMessage); - return this.buildToolCallConversation(prompt.getInstructions(), assistantMessage, toolMessageResponse); - } - - private List buildToolCallConversation(List previousMessages, AssistantMessage assistantMessage, - ToolResponseMessage toolResponseMessage) { - List messages = new ArrayList<>(previousMessages); - messages.add(assistantMessage); - messages.add(toolResponseMessage); - return messages; - } - /** * Convert the ChatCompletionChunk into a ChatCompletion. The Usage is set to null. * @param chunk the ChatCompletionChunk to convert diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OllamaWithOpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OllamaWithOpenAiChatModelIT.java new file mode 100644 index 00000000000..c43b0364faa --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OllamaWithOpenAiChatModelIT.java @@ -0,0 +1,409 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.openai.chat; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.IOException; +import java.net.URL; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.converter.BeanOutputConverter; +import org.springframework.ai.converter.ListOutputConverter; +import org.springframework.ai.converter.MapOutputConverter; +import org.springframework.ai.model.Media; +import org.springframework.ai.model.function.FunctionCallbackWrapper; +import org.springframework.ai.openai.OpenAiChatModel; +import org.springframework.ai.openai.OpenAiChatOptions; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.openai.api.tool.MockWeatherService; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import org.springframework.core.convert.support.DefaultConversionService; +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; +import org.springframework.util.MimeTypeUtils; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.ollama.OllamaContainer; + +import reactor.core.publisher.Flux; + +@Disabled("For manual smoke testing only.") +@Testcontainers +@SpringBootTest(classes = OllamaWithOpenAiChatModelIT.Config.class) +class OllamaWithOpenAiChatModelIT { + + private static final Logger logger = LoggerFactory.getLogger(OpenAiChatModelIT.class); + + private static final String DEFAULT_OLLAMA_MODEL = "mistral"; + + @Container + static OllamaContainer ollamaContainer = new OllamaContainer("ollama/ollama:0.2.8"); + + static String baseUrl = "http://localhost:11434"; + + @BeforeAll + public static void beforeAll() throws IOException, InterruptedException { + logger.info("Start pulling the '" + DEFAULT_OLLAMA_MODEL + " ' generative ... 
this may take several minutes ...");
+		ollamaContainer.execInContainer("ollama", "pull", DEFAULT_OLLAMA_MODEL);
+		ollamaContainer.execInContainer("ollama", "pull", "llava");
+		logger.info(DEFAULT_OLLAMA_MODEL + " pulling completed!");
+
+		baseUrl = "http://" + ollamaContainer.getHost() + ":" + ollamaContainer.getMappedPort(11434);
+	}
+
+	@Value("classpath:/prompts/system-message.st")
+	private Resource systemResource;
+
+	@Autowired
+	private OpenAiChatModel chatModel;
+
+	@Test
+	void roleTest() {
+		UserMessage userMessage = new UserMessage(
+				"Tell me about 3 famous pirates from the Golden Age of Piracy and what they did.");
+		SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource);
+		Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));
+		Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
+		ChatResponse response = chatModel.call(prompt);
+		assertThat(response.getResults()).hasSize(1);
+		assertThat(response.getResults().get(0).getOutput().getContent()).contains("Blackbeard");
+	}
+
+	@Test
+	void streamRoleTest() {
+		UserMessage userMessage = new UserMessage(
+				"Tell me about 3 famous pirates from the Golden Age of Piracy and what they did.");
+		SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource);
+		Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));
+		Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
+		Flux<ChatResponse> flux = chatModel.stream(prompt);
+
+		List<ChatResponse> responses = flux.collectList().block();
+		assertThat(responses.size()).isGreaterThan(1);
+
+		String stitchedResponseContent = responses.stream()
+			.map(ChatResponse::getResults)
+			.flatMap(List::stream)
+			.map(Generation::getOutput)
+			.map(AssistantMessage::getContent)
+			.collect(Collectors.joining());
+
+		assertThat(stitchedResponseContent).contains("Blackbeard");
+	}
+
+	@Test
+	@Disabled("Not supported by the current Ollama API")
+	void streamingWithTokenUsage() {
+		var promptOptions = OpenAiChatOptions.builder().withStreamUsage(true).withSeed(1).build();
+
+		var prompt = new Prompt("List two colors of the Polish flag. 
Be brief.", promptOptions); + + var streamingTokenUsage = this.chatModel.stream(prompt).blockLast().getMetadata().getUsage(); + var referenceTokenUsage = this.chatModel.call(prompt).getMetadata().getUsage(); + + assertThat(streamingTokenUsage.getPromptTokens()).isGreaterThan(0); + assertThat(streamingTokenUsage.getGenerationTokens()).isGreaterThan(0); + assertThat(streamingTokenUsage.getTotalTokens()).isGreaterThan(0); + + assertThat(streamingTokenUsage.getPromptTokens()).isEqualTo(referenceTokenUsage.getPromptTokens()); + assertThat(streamingTokenUsage.getGenerationTokens()).isEqualTo(referenceTokenUsage.getGenerationTokens()); + assertThat(streamingTokenUsage.getTotalTokens()).isEqualTo(referenceTokenUsage.getTotalTokens()); + + } + + @Test + void listOutputConverter() { + DefaultConversionService conversionService = new DefaultConversionService(); + ListOutputConverter outputConverter = new ListOutputConverter(conversionService); + + String format = outputConverter.getFormat(); + String template = """ + List five {subject} + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, + Map.of("subject", "ice cream flavors", "format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = this.chatModel.call(prompt).getResult(); + + List list = outputConverter.convert(generation.getOutput().getContent()); + assertThat(list).hasSize(5); + + } + + @Test + void mapOutputConverter() { + MapOutputConverter outputConverter = new MapOutputConverter(); + + String format = outputConverter.getFormat(); + String template = """ + Provide me a List of {subject} + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, + Map.of("subject", "numbers from 1 to 9 under they key name 'numbers'", "format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = chatModel.call(prompt).getResult(); + + Map result = outputConverter.convert(generation.getOutput().getContent()); + assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); + + } + + @Test + void beanOutputConverter() { + + BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilms.class); + + String format = outputConverter.getFormat(); + String template = """ + Generate the filmography for a random actor. + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = chatModel.call(prompt).getResult(); + + ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getContent()); + assertThat(actorsFilms.getActor()).isNotEmpty(); + } + + record ActorsFilmsRecord(String actor, List movies) { + } + + @Test + void beanOutputConverterRecords() { + + BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class); + + String format = outputConverter.getFormat(); + String template = """ + Generate the filmography of 5 movies for Tom Hanks. 
+ {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = chatModel.call(prompt).getResult(); + + ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); + logger.info("" + actorsFilms); + assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); + assertThat(actorsFilms.movies()).hasSize(5); + } + + @Test + void beanStreamOutputConverterRecords() { + + BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class); + + String format = outputConverter.getFormat(); + String template = """ + Generate the filmography of 5 movies for Tom Hanks. + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + + String generationTextFromStream = chatModel.stream(prompt) + .collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + + ActorsFilmsRecord actorsFilms = outputConverter.convert(generationTextFromStream); + logger.info("" + actorsFilms); + assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); + assertThat(actorsFilms.movies()).hasSize(5); + } + + @Test + void functionCallTest() { + + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius."); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = OpenAiChatOptions.builder() + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("getCurrentWeather") + .withDescription("Get the weather in location") + .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .build())) + .build(); + + ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + + logger.info("Response: {}", response); + + assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); + } + + @Disabled("Ollama API does not support streaming function calls yet") + @Test + void streamFunctionCallTest() { + + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo, and Paris? 
Return the temperature in Celsius."); + + List<Message> messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = OpenAiChatOptions.builder() + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("getCurrentWeather") + .withDescription("Get the weather in location") + .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .build())) + .build(); + + Flux<ChatResponse> response = chatModel.stream(new Prompt(messages, promptOptions)); + + String content = response.collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + logger.info("Response: {}", content); + + assertThat(content).contains("30", "10", "15"); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "llava" }) + void multiModalityEmbeddedImage(String modelName) throws IOException { + + var imageData = new ClassPathResource("/test.png"); + + var userMessage = new UserMessage("Explain what do you see on this picture?", + List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData))); + + var response = chatModel + .call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withModel(modelName).build())); + + logger.info(response.getResult().getOutput().getContent()); + assertThat(response.getResult().getOutput().getContent()).contains("bananas", "apple"); + assertThat(response.getResult().getOutput().getContent()).containsAnyOf("bowl", "basket"); + } + + @Disabled("Not supported by the current Ollama API") + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "llava" }) + void multiModalityImageUrl(String modelName) throws IOException { + + var userMessage = new UserMessage("Explain what do you see on this picture?", List + .of(new Media(MimeTypeUtils.IMAGE_PNG, + new URL("https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/_images/multimodal.test.png")))); + + ChatResponse response = chatModel + .call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withModel(modelName).build())); + + logger.info(response.getResult().getOutput().getContent()); + assertThat(response.getResult().getOutput().getContent()).contains("bananas", "apple"); + assertThat(response.getResult().getOutput().getContent()).containsAnyOf("bowl", "basket"); + } + + @Disabled("Not supported by the current Ollama API") + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "llava" }) + void streamingMultiModalityImageUrl(String modelName) throws IOException { + + var userMessage = new UserMessage("Explain what do you see on this picture?", List + .of(new Media(MimeTypeUtils.IMAGE_PNG, + new URL("https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/_images/multimodal.test.png")))); + + Flux<ChatResponse> response = chatModel + .stream(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withModel(modelName).build())); + + String content = response.collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + logger.info("Response: {}", content); + assertThat(content).contains("bananas", "apple"); + assertThat(content).containsAnyOf("bowl", "basket"); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "mistral" }) + void validateCallResponseMetadata(String model) { + // @formatter:off + ChatResponse response
= ChatClient.create(chatModel).prompt() + .options(OpenAiChatOptions.builder().withModel(model).build()) + .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") + .call() + .chatResponse(); + // @formatter:on + + logger.info(response.toString()); + assertThat(response.getMetadata().getId()).isNotEmpty(); + assertThat(response.getMetadata().getModel()).containsIgnoringCase(model); + assertThat(response.getMetadata().getUsage().getPromptTokens()).isPositive(); + assertThat(response.getMetadata().getUsage().getGenerationTokens()).isPositive(); + assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive(); + } + + @SpringBootConfiguration + static class Config { + + @Bean + public OpenAiApi chatCompletionApi() { + return new OpenAiApi(baseUrl, ""); + } + + @Bean + public OpenAiChatModel openAiClient(OpenAiApi openAiApi) { + return new OpenAiChatModel(openAiApi, OpenAiChatOptions.builder().withModel(DEFAULT_OLLAMA_MODEL).build()); + } + + } + +} \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/AbstractToolCallSupport.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/AbstractToolCallSupport.java index 6cae803d561..46307dfde75 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/AbstractToolCallSupport.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/AbstractToolCallSupport.java @@ -23,7 +23,9 @@ import java.util.concurrent.ConcurrentHashMap; import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.ai.model.function.FunctionCallingOptions; @@ -95,6 +97,20 @@ protected Set<String> handleFunctionCallbackConfigurations(FunctionCallingOption return functionToCall; } + protected List<Message> handleToolCalls(Prompt prompt, ChatResponse response) { + AssistantMessage assistantMessage = response.getResult().getOutput(); + ToolResponseMessage toolMessageResponse = this.executeFuncitons(assistantMessage); + return this.buildToolCallConversation(prompt.getInstructions(), assistantMessage, toolMessageResponse); + } + + protected List<Message> buildToolCallConversation(List<Message> previousMessages, AssistantMessage assistantMessage, + ToolResponseMessage toolResponseMessage) { + List<Message> messages = new ArrayList<>(previousMessages); + messages.add(assistantMessage); + messages.add(toolResponseMessage); + return messages; + } + /** * Resolve the function callbacks by name. Retrieve them from the registry or try to * resolve them from the Application Context.
@@ -152,8 +168,8 @@ protected ToolResponseMessage executeFuncitons(AssistantMessage assistantMessage return new ToolResponseMessage(toolResponses, Map.of()); } - protected boolean isToolCall(ChatResponse chatResponse, String toolCallFinishReason) { - Assert.hasText(toolCallFinishReason, "toolCallFinishReason cannot be null or empty"); + protected boolean isToolCall(ChatResponse chatResponse, Set<String> toolCallFinishReasons) { + Assert.isTrue(!CollectionUtils.isEmpty(toolCallFinishReasons), "Tool call finish reasons cannot be empty!"); if (chatResponse == null) { return false; @@ -165,8 +181,10 @@ protected boolean isToolCall(ChatResponse chatResponse, String toolCallFinishRea } var generation = generations.get(0); - return !CollectionUtils.isEmpty(generation.getOutput().getToolCalls()) - && toolCallFinishReason.equalsIgnoreCase(generation.getMetadata().getFinishReason()); + return !CollectionUtils.isEmpty(generation.getOutput().getToolCalls()) && toolCallFinishReasons.stream() + .map(s -> s.toLowerCase()) + .toList() + .contains(generation.getMetadata().getFinishReason().toLowerCase()); } } diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/images/ollama-chat-completion-api.jpg b/spring-ai-docs/src/main/antora/modules/ROOT/images/ollama-chat-completion-api.jpg index fa5597a4e8e..78855d9cf78 100644 Binary files a/spring-ai-docs/src/main/antora/modules/ROOT/images/ollama-chat-completion-api.jpg and b/spring-ai-docs/src/main/antora/modules/ROOT/images/ollama-chat-completion-api.jpg differ diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/images/ollama-chatmodel-function-call.jpg b/spring-ai-docs/src/main/antora/modules/ROOT/images/ollama-chatmodel-function-call.jpg new file mode 100644 index 00000000000..da79cb0d639 Binary files /dev/null and b/spring-ai-docs/src/main/antora/modules/ROOT/images/ollama-chatmodel-function-call.jpg differ diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/images/ollama-function-calling-flow.jpg b/spring-ai-docs/src/main/antora/modules/ROOT/images/ollama-function-calling-flow.jpg new file mode 100644 index 00000000000..ef57e0174ac Binary files /dev/null and b/spring-ai-docs/src/main/antora/modules/ROOT/images/ollama-function-calling-flow.jpg differ diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc index 644615f74aa..0676c38d74c 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc @@ -28,6 +28,7 @@ *** xref:api/chat/moonshot-chat.adoc[Moonshot AI] //// **** xref:api/chat/functions/moonshot-chat-functions.adoc[Function Calling] *** xref:api/chat/ollama-chat.adoc[Ollama] +**** xref:api/chat/functions/ollama-chat-functions.adoc[Function Calling] *** xref:api/chat/openai-chat.adoc[OpenAI] **** xref:api/chat/functions/openai-chat-functions.adoc[Function Calling] *** xref:api/chat/qianfan-chat.adoc[QianFan] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/ollama-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/ollama-chat-functions.adoc new file mode 100644 index 00000000000..5511562f4cf --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/ollama-chat-functions.adoc @@ -0,0 +1,222 @@ += Function Calling + +TIP: You need Ollama 0.2.8 or newer. + +NOTE: Currently, the Ollama API (0.2.8) does not support function calling in streaming mode.
+ +You can register custom Java functions with the `OllamaChatModel` and have the Ollama-deployed model intelligently choose to output a JSON object containing arguments to call one or many of the registered functions. +This allows you to connect the LLM capabilities with external tools and APIs. +The Ollama models tagged with the `Tools` label are trained to detect when a function should be called and to respond with JSON that adheres to the function signature. + +The Ollama API does not call the function directly; instead, the model generates JSON that you can use to call the function in your code and return the result back to the model to complete the conversation. + +Spring AI provides flexible and user-friendly ways to register and call custom functions. +In general, the custom functions need to provide a function `name`, a `description`, and the function call `signature` (as a JSON schema) to let the model know what arguments the function expects. +The `description` helps the model understand when to call the function. + +As a developer, you need to implement a function that takes the function call arguments sent from the AI model and responds with the result back to the model. +Your function can in turn invoke other 3rd-party services to provide the results. + +Spring AI makes this as easy as defining a `@Bean` that returns a `java.util.Function` and supplying the bean name as an option when invoking the `ChatModel`. + +Under the hood, Spring wraps your POJO (the function) with the appropriate adapter code that enables interaction with the AI Model, saving you from writing tedious boilerplate code. +The basis of the underlying infrastructure is the link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java[FunctionCallback.java] interface and the companion link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java[FunctionCallbackWrapper.java] utility class, which simplify the implementation and registration of Java callback functions. + + +== How it works + +Suppose we want the AI model to respond with information that it does not have, for example the current temperature at a given location. + +We can provide the AI model with metadata about our own functions that it can use to retrieve that information as it processes your prompt. + +For example, if during the processing of a prompt the AI Model determines that it needs additional information about the temperature in a given location, it will start a server-side generated request/response interaction. +Instead of returning a final answer, the AI Model invokes a client-side function, providing the method invocation details as JSON, and it is the responsibility of the client to execute that function and return the response. + +The model-client interaction is illustrated in the <<spring-ai-function-calling-flow>> diagram. + +Spring AI greatly simplifies the code you need to write to support function invocation. +It brokers the function invocation conversation for you. +You can simply provide your function definition as a `@Bean` and then provide the bean name of the function in your prompt options. +You can also reference multiple function bean names in your prompt. + +== Quick Start + +Let's create a chatbot that answers questions by calling our own function.
+To support the response of the chatbot, we will register our own function that takes a location and returns the current weather in that location. + +When the model needs to answer a question such as `"What’s the weather like in Boston?"`, the AI model will invoke the client, providing the location value as an argument to be passed to the function. +This RPC-like data is passed as JSON. + +Our function calls some SaaS-based weather service API and returns the weather response back to the model to complete the conversation. +In this example, we will use a simple implementation named `MockWeatherService` that hard-codes the temperature for various locations. + +The following `MockWeatherService.java` represents the weather service API: + +[source,java] +---- +public class MockWeatherService implements Function<Request, Response> { + + public enum Unit { C, F } + public record Request(String location, Unit unit) {} + public record Response(double temp, Unit unit) {} + + public Response apply(Request request) { + return new Response(30.0, Unit.C); + } +} +---- + +=== Registering Functions as Beans + +With the link:../ollama-chat.html#_auto_configuration[OllamaChatModel Auto-Configuration] you have multiple ways to register custom functions as beans in the Spring context. + +We start by describing the most POJO-friendly options. + +==== Plain Java Functions + +In this approach, you define a `@Bean` in your application context as you would any other Spring-managed object. + +Internally, the Spring AI `ChatModel` will create an instance of a `FunctionCallbackWrapper` that adds the logic for invoking the function via the AI model. +The name of the `@Bean` is passed as a `ChatOption`. + + +[source,java] +---- +@Configuration +static class Config { + + @Bean + @Description("Get the weather in location") // function description + public Function<MockWeatherService.Request, MockWeatherService.Response> weatherFunction1() { + return new MockWeatherService(); + } + ... +} +---- + +The `@Description` annotation is optional and provides a function description that helps the model understand when to call the function. +It is an important property to set to help the AI model determine which client-side function to invoke. + +Another way to provide the function description is to use the `@JsonClassDescription` annotation on the `MockWeatherService.Request` class: + +[source,java] +---- + +@Configuration +static class Config { + + @Bean + public Function<MockWeatherService.Request, MockWeatherService.Response> currentWeather3() { // (1) bean name as function name. + return new MockWeatherService(); + } + ... +} + +@JsonClassDescription("Get the weather in location") // (2) function description +public record Request(String location, Unit unit) {} +---- + +It is a best practice to annotate the request object with information such that the generated JSON schema of that function is as descriptive as possible, to help the AI model pick the correct function to invoke. + +The link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java[FunctionCallbackWithPlainFunctionBeanIT.java] demonstrates this approach.
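+ +Putting the pieces together, below is a minimal sketch (not taken from the referenced test) of calling the model with the plain function bean enabled; the injected `chatModel` variable and the use of the bean name `weatherFunction1` as the function name are assumptions for illustration: + +[source,java] +---- +// Enable the function bean by its name for a single request. +ChatResponse response = chatModel.call(new Prompt( + "What's the weather like in Tokyo?", + OllamaOptions.builder().withFunction("weatherFunction1").build())); +----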
+ + +==== FunctionCallback Wrapper + +Another way to register a function is to create a `FunctionCallbackWrapper` like this: + +[source,java] +---- +@Configuration +static class Config { + + @Bean + public FunctionCallback weatherFunctionInfo() { + + return FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("CurrentWeather") // (1) function name + .withDescription("Get the weather in location") // (2) function description + .build(); + } + ... +} +---- + +It wraps the 3rd-party `MockWeatherService` function and registers it as a `CurrentWeather` function with the `OllamaChatModel`. +It also provides a description (2) and, optionally, a response converter to convert the response into text as expected by the model. + +NOTE: By default, the response converter does a JSON serialization of the Response object. + +NOTE: The `FunctionCallbackWrapper` internally resolves the function call signature based on the `MockWeatherService.Request` class. + +=== Specifying functions in Chat Options + +To let the model know about and call your `CurrentWeather` function, you need to enable it in your prompt requests: + +[source,java] +---- +OllamaChatModel chatModel = ... + +UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); + +ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), + OllamaOptions.builder().withFunction("CurrentWeather").build())); // (1) Enable the function + +logger.info("Response: {}", response); +---- + +// NOTE: You can have multiple functions registered in your `ChatModel` but only those enabled in the prompt request will be considered for the function calling. + +The above user question will trigger three calls to the `CurrentWeather` function (one for each city), and the final response will be something like this: + +---- +Here is the current weather for the requested cities: +- San Francisco, CA: 30.0°C +- Tokyo, Japan: 10.0°C +- Paris, France: 15.0°C +---- + +The link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackWrapperIT.java[FunctionCallbackWrapperIT.java] test demonstrates this approach. + + +=== Register/Call Functions with Prompt Options + +In addition to the auto-configuration, you can register callback functions dynamically with your `Prompt` requests: + +[source,java] +---- +OllamaChatModel chatModel = ... + +UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); + +var promptOptions = OllamaOptions.builder() + .withFunctionCallbacks(List.of(new FunctionCallbackWrapper<>( + "CurrentWeather", // name + "Get the weather in location", // function description + new MockWeatherService()))) // function code + .build(); + +ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); +---- + +NOTE: The in-prompt registered functions are enabled by default for the duration of this request. + +This approach allows you to dynamically choose different functions to be called based on the user input. + +The https://github.com/spring-projects/spring-ai/blob/main/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackInPromptIT.java[FunctionCallbackInPromptIT.java] integration test provides a complete example of how to register a function with the `OllamaChatModel` and use it in a prompt request.
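+ +If more than one callback is registered, you can enable several of them for the same request. The following is a sketch only; it assumes a second function bean named `currentDateTime` is registered and that the plural `withFunctions` builder method accepts a set of function names, mirroring the singular `withFunction` used above: + +[source,java] +---- +ChatResponse response = chatModel.call(new Prompt( + "What's the weather and the local time in Paris?", + OllamaOptions.builder() + .withFunctions(Set.of("CurrentWeather", "currentDateTime")) // enable both for this request + .build())); +----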
+ +== Appendices + +[[spring-ai-function-calling-flow]] +=== Spring AI Function Calling Flow + +The following diagram illustrates the flow of the OllamaChatModel function calling: + +image:ollama-chatmodel-function-call.jpg[width=800, title="OllamaChatModel Function Calling Flow"] + +=== OllamaAPI Function Calling Flow + +The following diagram illustrates the flow of the Ollama API function calling: + +image:ollama-function-calling-flow.jpg[title="Ollama API Function Calling Flow", width=800] + +The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/chat/api/tool/OllamaApiToolFunctionCallIT.java[OllamaApiToolFunctionCallIT.java] provides a complete example of how to use the Ollama API function calling. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc index 14863eadca0..9629f895268 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc @@ -102,6 +102,7 @@ The remaining `options` properties are based on the link:https://github.com/olla | spring.ai.ollama.chat.options.mirostat-eta | Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. | 0.1 | spring.ai.ollama.chat.options.penalize-newline | ??? | true | spring.ai.ollama.chat.options.stop | Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return. Multiple stop patterns may be set by specifying multiple separate stop parameters in a modelfile. | - +| spring.ai.ollama.chat.options.functions | List of functions, identified by their names, to enable for function calling in a single prompt request. Functions with those names must exist in the functionCallbacks registry. | - |==== TIP: All properties prefixed with `spring.ai.ollama.chat.options` can be overridden at runtime by adding a request specific <> to the `Prompt` call. @@ -129,6 +130,16 @@ ChatResponse response = chatModel.call( TIP: In addition to the model specific link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java[OllamaOptions] you can use a portable https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java[ChatOptions] instance, created with the https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptionsBuilder.java[ChatOptionsBuilder#builder()]. +== Function Calling + +You can register custom Java functions with the `OllamaChatModel` and have the Ollama model intelligently choose to output a JSON object containing arguments to call one or many of the registered functions. +This is a powerful technique to connect the LLM capabilities with external tools and APIs. +Read more about xref:api/chat/functions/ollama-chat-functions.adoc[Ollama Function Calling]. + +TIP: You need Ollama 0.2.8 or newer. + +NOTE: Currently, the Ollama API (0.2.8) does not support function calling in streaming mode.
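+ +Under the auto-configuration this wiring is done for you (see `OllamaAutoConfiguration`). For a programmatic setup, here is a hand-rolled sketch that uses the constructor accepting a `FunctionCallbackContext`; the `ollamaApi` and `functionCallbackContext` variables are assumed to be available: + +[source,java] +---- +// The FunctionCallbackContext resolves function @Beans from the application context. +OllamaChatModel chatModel = new OllamaChatModel(ollamaApi, + OllamaOptions.create().withModel("mistral"), functionCallbackContext); +----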
+ == Multimodal Multimodality refers to a model's ability to simultaneously understand and process information from various sources, including text, images, audio, and other data formats. diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java index f74aa257814..87e8da12512 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java @@ -15,6 +15,10 @@ */ package org.springframework.ai.autoconfigure.ollama; +import java.util.List; + +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.ai.ollama.OllamaChatModel; import org.springframework.ai.ollama.OllamaEmbeddingModel; import org.springframework.ai.ollama.api.OllamaApi; @@ -26,7 +30,9 @@ import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.autoconfigure.web.reactive.function.client.WebClientAutoConfiguration; import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; +import org.springframework.util.CollectionUtils; import org.springframework.web.client.RestClient; /** @@ -59,8 +65,14 @@ public OllamaApi ollamaApi(OllamaConnectionDetails connectionDetails, RestClient @ConditionalOnMissingBean @ConditionalOnProperty(prefix = OllamaChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", matchIfMissing = true) - public OllamaChatModel ollamaChatModel(OllamaApi ollamaApi, OllamaChatProperties properties) { - return new OllamaChatModel(ollamaApi, properties.getOptions()); + public OllamaChatModel ollamaChatModel(OllamaApi ollamaApi, OllamaChatProperties properties, + List<FunctionCallback> toolFunctionCallbacks, FunctionCallbackContext functionCallbackContext) { + + if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) { + properties.getOptions().getFunctionCallbacks().addAll(toolFunctionCallbacks); + } + + return new OllamaChatModel(ollamaApi, properties.getOptions(), functionCallbackContext); } @Bean @@ -87,4 +99,12 @@ public String getBaseUrl() { } + @Bean + @ConditionalOnMissingBean + public FunctionCallbackContext springAiFunctionManager(ApplicationContext context) { + FunctionCallbackContext manager = new FunctionCallbackContext(); + manager.setApplicationContext(context); + return manager; + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationIT.java index 97e716eeffb..7ee593325ce 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationIT.java @@ -176,7 +176,7 @@ protected void containerIsStarted(InspectContainerResponse containerInfo) { } - static void createImage(GenericContainer<?> container, String localImageName) { +
public static void createImage(GenericContainer<?> container, String localImageName) { DockerImageName dockerImageName = DockerImageName.parse(container.getDockerImageName()); if (!dockerImageName.equals(DockerImageName.parse(localImageName))) { DockerClient dockerClient = DockerClientFactory.instance().client(); @@ -192,7 +192,7 @@ static void createImage(GenericContainer<?> container, String localImageName) { } } - static class OllamaDockerImageName { + public static class OllamaDockerImageName { private final String baseImage; @@ -203,7 +203,7 @@ static class OllamaDockerImageName { this.localImageName = localImageName; } - static DockerImageName image() { + public static DockerImageName image() { return new OllamaDockerImageName(OllamaImage.IMAGE, OLLAMA_WITH_MODEL).resolve(); } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaImage.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaImage.java index 38f26bec756..16358f69e19 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaImage.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaImage.java @@ -17,6 +17,6 @@ public class OllamaImage { - static final String IMAGE = "ollama/ollama:0.1.32"; + public static final String IMAGE = "ollama/ollama:0.2.8"; } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackInPromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackInPromptIT.java new file mode 100644 index 00000000000..41e2e7dd5d8 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackInPromptIT.java @@ -0,0 +1,139 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.springframework.ai.autoconfigure.ollama.tool; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.IOException; +import java.util.List; +import java.util.stream.Collectors; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration; +import org.springframework.ai.autoconfigure.ollama.OllamaImage; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.function.FunctionCallbackWrapper; +import org.springframework.ai.ollama.OllamaChatModel; +import org.springframework.ai.ollama.api.OllamaOptions; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.ollama.OllamaContainer; + +import reactor.core.publisher.Flux; + +@Disabled("For manual smoke testing only.") +@Testcontainers +public class FunctionCallbackInPromptIT { + + private static final Logger logger = LoggerFactory.getLogger(FunctionCallbackInPromptIT.class); + + private static final String MODEL_NAME = "mistral"; + + @Container + static OllamaContainer ollamaContainer = new OllamaContainer(OllamaImage.IMAGE); + + static String baseUrl = "http://localhost:11434"; + + @BeforeAll + public static void beforeAll() throws IOException, InterruptedException { + logger.info("Start pulling the '" + MODEL_NAME + "' model ... this may take several minutes ..."); + ollamaContainer.execInContainer("ollama", "pull", MODEL_NAME); + logger.info(MODEL_NAME + " pulling completed!"); + + baseUrl = "http://" + ollamaContainer.getHost() + ":" + ollamaContainer.getMappedPort(11434); + } + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.ollama.baseUrl=" + baseUrl, + "spring.ai.ollama.chat.options.model=" + MODEL_NAME, + "spring.ai.ollama.chat.options.temperature=0.5", + "spring.ai.ollama.chat.options.topK=10") + // @formatter:on + .withConfiguration(AutoConfigurations.of(OllamaAutoConfiguration.class)); + + @Test + void functionCallTest() { + contextRunner.run(context -> { + + OllamaChatModel chatModel = context.getBean(OllamaChatModel.class); + + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo, and Paris?
Return the temperature in Celsius."); + + var promptOptions = OllamaOptions.builder() + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("CurrentWeatherService") + .withDescription("Get the weather in location") + .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .build())) + .build(); + + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); + + logger.info("Response: {}", response); + + assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); + }); + } + + @Disabled("Ollama API does not support streaming function calls yet") + @Test + void streamingFunctionCallTest() { + + contextRunner.run(context -> { + + OllamaChatModel chatModel = context.getBean(OllamaChatModel.class); + + UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); + + var promptOptions = OllamaOptions.builder() + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("CurrentWeatherService") + .withDescription("Get the weather in location") + .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .build())) + .build(); + + Flux<ChatResponse> response = chatModel.stream(new Prompt(List.of(userMessage), promptOptions)); + + String content = response.collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + logger.info("Response: {}", content); + + assertThat(content).containsAnyOf("30.0", "30"); + assertThat(content).containsAnyOf("10.0", "10"); + assertThat(content).containsAnyOf("15.0", "15"); + }); + } + +} \ No newline at end of file diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackWrapperIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackWrapperIT.java new file mode 100644 index 00000000000..0d42dfdbb57 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackWrapperIT.java @@ -0,0 +1,143 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.springframework.ai.autoconfigure.ollama.tool; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.IOException; +import java.util.List; +import java.util.stream.Collectors; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration; +import org.springframework.ai.autoconfigure.ollama.OllamaImage; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallbackWrapper; +import org.springframework.ai.ollama.OllamaChatModel; +import org.springframework.ai.ollama.api.OllamaOptions; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.ollama.OllamaContainer; + +import reactor.core.publisher.Flux; + +@Disabled("For manual smoke testing only.") +@Testcontainers +public class FunctionCallbackWrapperIT { + + private static final Log logger = LogFactory.getLog(FunctionCallbackWrapperIT.class); + + private static final String MODEL_NAME = "mistral"; + + @Container + static OllamaContainer ollamaContainer = new OllamaContainer(OllamaImage.IMAGE); + + static String baseUrl = "http://localhost:11434"; + + @BeforeAll + public static void beforeAll() throws IOException, InterruptedException { + logger.info("Start pulling the '" + MODEL_NAME + "' model ... this may take several minutes ..."); + ollamaContainer.execInContainer("ollama", "pull", MODEL_NAME); + logger.info(MODEL_NAME + " pulling completed!"); + + baseUrl = "http://" + ollamaContainer.getHost() + ":" + ollamaContainer.getMappedPort(11434); + } + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.ollama.baseUrl=" + baseUrl, + "spring.ai.ollama.chat.options.model=" + MODEL_NAME, + "spring.ai.ollama.chat.options.temperature=0.5", + "spring.ai.ollama.chat.options.topK=10") + // @formatter:on + .withConfiguration(AutoConfigurations.of(OllamaAutoConfiguration.class)) + .withUserConfiguration(Config.class); + + @Test + void functionCallTest() { + contextRunner.run(context -> { + + OllamaChatModel chatModel = context.getBean(OllamaChatModel.class); + + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo, and Paris?
Return the temperature in Celsius."); + + ChatResponse response = chatModel + .call(new Prompt(List.of(userMessage), OllamaOptions.builder().withFunction("WeatherInfo").build())); + + logger.info("Response: " + response); + + assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); + + }); + } + + @Disabled("Ollama API does not support streaming function calls yet") + @Test + void streamFunctionCallTest() { + contextRunner.run(context -> { + + OllamaChatModel chatModel = context.getBean(OllamaChatModel.class); + + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo, and Paris? You can call the following functions 'WeatherInfo'"); + + Flux<ChatResponse> response = chatModel + .stream(new Prompt(List.of(userMessage), OllamaOptions.builder().withFunction("WeatherInfo").build())); + + String content = response.collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + logger.info("Response: " + content); + + assertThat(content).contains("30", "10", "15"); + }); + } + + @Configuration + static class Config { + + @Bean + public FunctionCallback weatherFunctionInfo() { + + return FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("WeatherInfo") + .withDescription("Get the weather in location") + .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .build(); + } + + } + +} \ No newline at end of file diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/MockWeatherService.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/MockWeatherService.java new file mode 100644 index 00000000000..dc780891bf7 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/MockWeatherService.java @@ -0,0 +1,94 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.ollama.tool; + +import java.util.function.Function; + +import com.fasterxml.jackson.annotation.JsonClassDescription; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; + +/** + * Mock 3rd party weather service. + * + * @author Christian Tzolov + */ +public class MockWeatherService implements Function<Request, Response> { + + /** + * Weather Function request. + */ + @JsonInclude(Include.NON_NULL) + @JsonClassDescription("Weather API request") + public record Request(@JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city and state e.g.
San Francisco, CA") String location, + @JsonProperty(required = true, value = "lat") @JsonPropertyDescription("The city latitude") double lat, + @JsonProperty(required = true, value = "lon") @JsonPropertyDescription("The city longitude") double lon, + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + } + + /** + * Temperature units. + */ + public enum Unit { + + /** + * Celsius. + */ + C("metric"), + /** + * Fahrenheit. + */ + F("imperial"); + + /** + * Human readable unit name. + */ + public final String unitName; + + private Unit(String text) { + this.unitName = text; + } + + } + + /** + * Weather Function response. + */ + public record Response(double temp, double feels_like, double temp_min, double temp_max, int pressure, int humidity, + Unit unit) { + } + + @Override + public Response apply(Request request) { + + double temperature = 10; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); + } + +} \ No newline at end of file