diff --git a/README.md b/README.md
index eff3dcf5fc8..467415dcaf2 100644
--- a/README.md
+++ b/README.md
@@ -13,6 +13,15 @@ Let's make your `@Beans` intelligent!
 :warning: ### Breaking Changes
+
+December 20, 2023 Update
+
+Refactored the Ollama client and related classes and package names:
+
+ - Replace org.springframework.ai.ollama.client.OllamaClient with org.springframework.ai.ollama.OllamaChatClient.
+ - The OllamaChatClient method signatures have changed.
+ - Rename org.springframework.ai.autoconfigure.ollama.OllamaProperties to org.springframework.ai.autoconfigure.ollama.OllamaChatProperties and change the property prefix to `spring.ai.ollama.chat`. Some of the properties have changed as well.
+
 December 19, 2023 Update
 
 Renaming of AiClient and related classes and package names
@@ -22,7 +31,7 @@ Renaming of AiClient and related classes and package names
 * Rename AiStreamClient to StreamingChatClient
 * Rename package org.sf.ai.client to org.sf.ai.chat
 
-Rename artifact ID of
+Rename artifact ID of
 
 * `transformers-embedding` to `spring-ai-transformers`
@@ -144,7 +153,7 @@ Following vector stores are supported:
 ```xml
 <dependency>
-	<groupId>org.springframework.ai</groupId>
+	<groupId>org.springframework.ai</groupId>
 	<artifactId>spring-ai-azure-vector-store-spring-boot-starter</artifactId>
 	<version>0.8.0-SNAPSHOT</version>
 </dependency>
@@ -154,7 +163,7 @@ Following vector stores are supported:
 ```xml
 <dependency>
-	<groupId>org.springframework.ai</groupId>
+	<groupId>org.springframework.ai</groupId>
 	<artifactId>spring-ai-chroma-store-spring-boot-starter</artifactId>
 	<version>0.8.0-SNAPSHOT</version>
 </dependency>
@@ -163,7 +172,7 @@ Following vector stores are supported:
 ```xml
 <dependency>
-	<groupId>org.springframework.ai</groupId>
+	<groupId>org.springframework.ai</groupId>
 	<artifactId>spring-ai-milvus-store-spring-boot-starter</artifactId>
 	<version>0.8.0-SNAPSHOT</version>
 </dependency>
@@ -173,7 +182,7 @@ Following vector stores are supported:
 ```xml
 <dependency>
-	<groupId>org.springframework.ai</groupId>
+	<groupId>org.springframework.ai</groupId>
 	<artifactId>spring-ai-pgvector-store-spring-boot-starter</artifactId>
 	<version>0.8.0-SNAPSHOT</version>
 </dependency>
@@ -182,7 +191,7 @@ Following vector stores are supported:
 ```xml
 <dependency>
-	<groupId>org.springframework.ai</groupId>
+	<groupId>org.springframework.ai</groupId>
 	<artifactId>spring-ai-pinecone-store-spring-boot-starter</artifactId>
 	<version>0.8.0-SNAPSHOT</version>
 </dependency>
@@ -191,7 +200,7 @@ Following vector stores are supported:
 ```xml
 <dependency>
-	<groupId>org.springframework.ai</groupId>
+	<groupId>org.springframework.ai</groupId>
 	<artifactId>spring-ai-weaviate-store-spring-boot-starter</artifactId>
 	<version>0.8.0-SNAPSHOT</version>
 </dependency>
@@ -200,7 +209,7 @@ Following vector stores are supported:
 ```xml
 <dependency>
-	<groupId>org.springframework.ai</groupId>
+	<groupId>org.springframework.ai</groupId>
 	<artifactId>spring-ai-neo4j-store-spring-boot-starter</artifactId>
 	<version>0.8.0-SNAPSHOT</version>
 </dependency>
diff --git a/models/spring-ai-ollama/README.md b/models/spring-ai-ollama/README.md
new file mode 100644
index 00000000000..d87417ed6b0
--- /dev/null
+++ b/models/spring-ai-ollama/README.md
@@ -0,0 +1,120 @@
+# 1. Ollama Chat and Embedding
+
+## 1.1 OllamaApi
+
+[OllamaApi](./src/main/java/org/springframework/ai/ollama/api/OllamaApi.java) is a lightweight Java client for [Ollama models](https://ollama.ai/).
+
+The OllamaApi provides both the chat completion and the embedding endpoints.
+
+The following class diagram illustrates the OllamaApi interface and the building blocks for chat completion:
+
+![OllamaApi Class Diagram](./src/test/resources/doc/Ollama%20Chat%20API.jpg)
+
+The OllamaApi supports all [Ollama Models](https://ollama.ai/library), providing synchronous chat completion, streaming chat completion and embedding:
+
+```java
+ChatResponse chat(ChatRequest chatRequest)
+
+Flux<ChatResponse> streamingChat(ChatRequest chatRequest)
+
+EmbeddingResponse embeddings(EmbeddingRequest embeddingRequest)
+```
+
+> NOTE: OllamaApi also exposes the Ollama `generate` endpoint, but the latter is inferior compared to the Ollama `chat` endpoint.
+
+The `OllamaApiOptions` is a helper class used as a type-safe options builder. It provides `toMap` to convert the options into a `Map<String, Object>`.
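+
+Here is a minimal sketch of building options and converting them to a map (the option values are arbitrary, and the same `Options` import as in the snippets below is assumed):
+
+```java
+// Build type-safe options, then convert them into the Map<String, Object>
+// shape accepted by the request builders.
+Map<String, Object> options = Options.builder()
+	.withTemperature(0.9f)
+	.withTopK(40)
+	.build()
+	.toMap();
+```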
+
+Here is a simple snippet showing how to use the OllamaApi programmatically:
+
+```java
+var request = ChatRequest.builder("orca-mini")
+	.withStream(false)
+	.withMessages(List.of(Message.builder(Role.user)
+		.withContent("What is the capital of Bulgaria and what is the size? " + "What is the national anthem?")
+		.build()))
+	.withOptions(Options.builder().withTemperature(0.9f).build())
+	.build();
+
+ChatResponse response = ollamaApi.chat(request);
+```
+
+```java
+var request = ChatRequest.builder("orca-mini")
+	.withStream(true)
+	.withMessages(List.of(Message.builder(Role.user)
+		.withContent("What is the capital of Bulgaria and what is the size? " + "What is the national anthem?")
+		.build()))
+	.withOptions(Options.builder().withTemperature(0.9f).build().toMap())
+	.build();
+
+Flux<ChatResponse> response = ollamaApi.streamingChat(request);
+
+List<ChatResponse> responses = response.collectList().block();
+```
+
+```java
+EmbeddingRequest request = new EmbeddingRequest("orca-mini", "I like to eat apples");
+
+EmbeddingResponse response = ollamaApi.embeddings(request);
+```
+
+## 1.2 OllamaChatClient and OllamaEmbeddingClient
+
+The [OllamaChatClient](./src/main/java/org/springframework/ai/ollama/OllamaChatClient.java) implements the Spring AI `ChatClient` and `StreamingChatClient` interfaces.
+
+The [OllamaEmbeddingClient](./src/main/java/org/springframework/ai/ollama/OllamaEmbeddingClient.java) implements the Spring AI `EmbeddingClient` interface.
+
+Both the OllamaChatClient and the OllamaEmbeddingClient leverage the `OllamaApi`.
+
+You can configure the clients like this:
+
+```java
+@Bean
+public OllamaApi ollamaApi() {
+	return new OllamaApi(baseUrl);
+}
+
+@Bean
+public OllamaChatClient ollamaChat(OllamaApi ollamaApi) {
+	return new OllamaChatClient(ollamaApi).withModel(MODEL)
+		.withOptions(OllamaApiOptions.Options.builder().withTemperature(0.9f).build());
+}
+
+@Bean
+public OllamaEmbeddingClient ollamaEmbedding(OllamaApi ollamaApi) {
+	return new OllamaEmbeddingClient(ollamaApi).withModel("orca-mini");
+}
+```
+
+Alternatively, you can leverage the `spring-ai-ollama-spring-boot-starter` Spring Boot starter.
+To do so, add the following dependency:
+
+```xml
+<dependency>
+	<artifactId>spring-ai-ollama-spring-boot-starter</artifactId>
+	<groupId>org.springframework.ai</groupId>
+	<version>0.8.0-SNAPSHOT</version>
+</dependency>
+```
+
+Use the `OllamaChatProperties` to configure the Ollama chat client:
+
+| Property | Description | Default |
+| ------------- | ------------- | ------------- |
+| spring.ai.ollama.chat.model | Model to use. | llama2 |
+| spring.ai.ollama.chat.base-url | The base URL of the Ollama server. | http://localhost:11434 |
+| spring.ai.ollama.chat.enabled | Allows you to disable the Ollama chat autoconfiguration. | true |
+| spring.ai.ollama.chat.temperature | Controls the randomness of the output. Values can range over [0.0, 1.0]. | 0.8 |
+| spring.ai.ollama.chat.topP | The maximum cumulative probability of tokens to consider when sampling. | - |
+| spring.ai.ollama.chat.topK | Limits sampling to the K most probable tokens. A higher value (e.g., 100) gives more diverse answers; a lower value (e.g., 10) is more conservative. | - |
+| spring.ai.ollama.chat.options (WIP) | A Map used to configure the chat client. | - |
+
+and `OllamaEmbeddingProperties` to configure the Ollama embedding client:
+
+| Property | Description | Default |
+| ------------- | ------------- | ------------- |
+| spring.ai.ollama.embedding.model | Model to use. | llama2 |
+| spring.ai.ollama.embedding.base-url | The base URL of the Ollama server. | http://localhost:11434 |
+| spring.ai.ollama.embedding.enabled | Allows you to disable the Ollama embedding autoconfiguration. | true |
+| spring.ai.ollama.embedding.options (WIP) | A Map used to configure the embedding client. | - |
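+
+For example, a minimal `application.properties` sketch covering both clients might look like this (the property values shown are illustrative, not requirements):
+
+```properties
+# Ollama chat client (OllamaChatProperties)
+spring.ai.ollama.chat.base-url=http://localhost:11434
+spring.ai.ollama.chat.model=llama2
+spring.ai.ollama.chat.temperature=0.7
+
+# Ollama embedding client (OllamaEmbeddingProperties)
+spring.ai.ollama.embedding.base-url=http://localhost:11434
+spring.ai.ollama.embedding.model=llama2
+```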
diff --git a/models/spring-ai-ollama/pom.xml b/models/spring-ai-ollama/pom.xml
index fead80fc019..0da9e9a5353 100644
--- a/models/spring-ai-ollama/pom.xml
+++ b/models/spring-ai-ollama/pom.xml
@@ -1,7 +1,6 @@
+	xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
 	<modelVersion>4.0.0</modelVersion>
 	<parent>
 		<groupId>org.springframework.ai</groupId>
@@ -28,6 +27,23 @@
 			<version>${project.parent.version}</version>
 		</dependency>
 
+		<dependency>
+			<groupId>org.springframework</groupId>
+			<artifactId>spring-webflux</artifactId>
+		</dependency>
+
+		<dependency>
+			<groupId>com.fasterxml.jackson.core</groupId>
+			<artifactId>jackson-databind</artifactId>
+			<version>${jackson.version}</version>
+		</dependency>
+
+		<dependency>
+			<groupId>com.fasterxml.jackson.datatype</groupId>
+			<artifactId>jackson-datatype-jsr310</artifactId>
+			<version>${jackson.version}</version>
+		</dependency>
+
 		<dependency>
 			<groupId>org.springframework.boot</groupId>
 			<artifactId>spring-boot-starter-logging</artifactId>
@@ -39,5 +55,17 @@
 			<artifactId>spring-boot-starter-test</artifactId>
 			<scope>test</scope>
 		</dependency>
+
+		<dependency>
+			<groupId>org.springframework.boot</groupId>
+			<artifactId>spring-boot-testcontainers</artifactId>
+			<scope>test</scope>
+		</dependency>
+
+		<dependency>
+			<groupId>org.testcontainers</groupId>
+			<artifactId>junit-jupiter</artifactId>
+			<scope>test</scope>
+		</dependency>
 	</dependencies>
</project>
\ No newline at end of file
diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatClient.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatClient.java
new file mode 100644
index 00000000000..517734205f5
--- /dev/null
+++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatClient.java
@@ -0,0 +1,140 @@
+/*
+ * Copyright 2023-2023 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.ollama;
+
+import java.util.List;
+import java.util.Map;
+
+import reactor.core.publisher.Flux;
+
+import org.springframework.ai.chat.ChatClient;
+import org.springframework.ai.chat.ChatResponse;
+import org.springframework.ai.chat.Generation;
+import org.springframework.ai.chat.StreamingChatClient;
+import org.springframework.ai.metadata.ChoiceMetadata;
+import org.springframework.ai.metadata.Usage;
+import org.springframework.ai.ollama.api.OllamaApi;
+import org.springframework.ai.ollama.api.OllamaApi.ChatRequest;
+import org.springframework.ai.ollama.api.OllamaApi.Message.Role;
+import org.springframework.ai.ollama.api.OllamaApiOptions;
+import org.springframework.ai.prompt.Prompt;
+import org.springframework.ai.prompt.messages.Message;
+import org.springframework.ai.prompt.messages.MessageType;
+
+/**
+ * {@link ChatClient} and {@link StreamingChatClient} implementation backed by the
+ * {@link OllamaApi}.
+ *
+ * @author Christian Tzolov
+ * @since 0.8.0
+ */
+public class OllamaChatClient implements ChatClient, StreamingChatClient {
+
+	private final OllamaApi chatApi;
+
+	private String model = "orca-mini";
+
+	private Map<String, Object> clientOptions;
+
+	public OllamaChatClient(OllamaApi chatApi) {
+		this.chatApi = chatApi;
+	}
+
+	public OllamaChatClient withModel(String model) {
+		this.model = model;
+		return this;
+	}
+
+	public OllamaChatClient withOptions(Map<String, Object> options) {
+		this.clientOptions = options;
+		return this;
+	}
+
+	public OllamaChatClient withOptions(OllamaApiOptions.Options options) {
+		this.clientOptions = options.toMap();
+		return this;
+	}
+
+	@Override
+	public ChatResponse generate(Prompt prompt) {
+
+		OllamaApi.ChatResponse response = this.chatApi.chat(request(prompt, this.model, false));
+		// Attach the token usage metadata to the generation when it is available.
+		var generation = new Generation(response.message().content());
+		if (response.promptEvalCount() != null && response.evalCount() != null) {
+			generation = generation.withChoiceMetadata(ChoiceMetadata.from("unknown", extractUsage(response)));
+		}
+		return new ChatResponse(List.of(generation));
+	}
+
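+	// Illustrative usage sketch (assumes a running Ollama server reachable through a
+	// configured OllamaApi; UserMessage is org.springframework.ai.prompt.messages.UserMessage):
+	//
+	//   ChatResponse response = new OllamaChatClient(ollamaApi)
+	//       .withModel("orca-mini")
+	//       .generate(new Prompt(new UserMessage("Hello!")));
+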
+	@Override
+	public Flux<ChatResponse> generateStream(Prompt prompt) {
+
+		Flux<OllamaApi.ChatResponse> response = this.chatApi.streamingChat(request(prompt, this.model, true));
+
+		return response.map(chunk -> {
+			Generation generation = (chunk.message() != null) ? new Generation(chunk.message().content())
+					: new Generation("");
+			if (chunk.done()) {
+				generation = generation.withChoiceMetadata(ChoiceMetadata.from("unknown", extractUsage(chunk)));
+			}
+			return new ChatResponse(List.of(generation));
+		});
+	}
+
+	private Usage extractUsage(OllamaApi.ChatResponse response) {
+		return new Usage() {
+
+			@Override
+			public Long getPromptTokens() {
+				return response.promptEvalCount().longValue();
+			}
+
+			@Override
+			public Long getGenerationTokens() {
+				return response.evalCount().longValue();
+			}
+		};
+	}
+
+	private OllamaApi.ChatRequest request(Prompt prompt, String model, boolean stream) {
+
+		List<OllamaApi.Message> ollamaMessages = prompt.getMessages()
+			.stream()
+			.filter(message -> message.getMessageType() == MessageType.USER
+					|| message.getMessageType() == MessageType.ASSISTANT)
+			.map(m -> OllamaApi.Message.builder(toRole(m)).withContent(m.getContent()).build())
+			.toList();
+
+		return ChatRequest.builder(model)
+			.withStream(stream)
+			.withMessages(ollamaMessages)
+			.withOptions(this.clientOptions)
+			.build();
+	}
+
+	private OllamaApi.Message.Role toRole(Message message) {
+
+		switch (message.getMessageType()) {
+			case USER:
+				return Role.user;
+			case ASSISTANT:
+				return Role.assistant;
+			case SYSTEM:
+				return Role.system;
+			default:
+				throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
+		}
+	}
+
+}
\ No newline at end of file
diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingClient.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingClient.java
new file mode 100644
index 00000000000..c9385ebb203
--- /dev/null
+++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingClient.java
@@ -0,0 +1,95 @@
+/*
+ * Copyright 2023-2023 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.ollama;
+
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.springframework.ai.document.Document;
+import org.springframework.ai.embedding.AbstractEmbeddingClient;
+import org.springframework.ai.embedding.Embedding;
+import org.springframework.ai.embedding.EmbeddingResponse;
+import org.springframework.ai.ollama.api.OllamaApi;
+import org.springframework.ai.ollama.api.OllamaApi.EmbeddingRequest;
+import org.springframework.ai.ollama.api.OllamaApiOptions;
+import org.springframework.util.Assert;
+
+/**
+ * EmbeddingClient implementation backed by the {@link OllamaApi}.
+ *
+ * @author Christian Tzolov
+ */
+public class OllamaEmbeddingClient extends AbstractEmbeddingClient {
+
+	private final OllamaApi ollamaApi;
+
+	private String model = "orca-mini";
+
+	private Map<String, Object> clientOptions;
+
+	public OllamaEmbeddingClient(OllamaApi ollamaApi) {
+		this.ollamaApi = ollamaApi;
+	}
+
+	public OllamaEmbeddingClient withModel(String model) {
+		this.model = model;
+		return this;
+	}
+
+	public OllamaEmbeddingClient withOptions(Map<String, Object> options) {
+		this.clientOptions = options;
+		return this;
+	}
+
+	public OllamaEmbeddingClient withOptions(OllamaApiOptions.Options options) {
+		this.clientOptions = options.toMap();
+		return this;
+	}
+
+	@Override
+	public List<Double> embed(String text) {
+		return this.embed(List.of(text)).iterator().next();
+	}
+
+	@Override
+	public List<Double> embed(Document document) {
+		return embed(document.getContent());
+	}
+
+	@Override
+	public List<List<Double>> embed(List<String> texts) {
+		Assert.notEmpty(texts, "At least one text is required!");
+		Assert.isTrue(texts.size() == 1, "Ollama Embedding does not support batch embedding!");
+
+		String inputContent = texts.iterator().next();
+
+		OllamaApi.EmbeddingResponse response = this.ollamaApi
+			.embeddings(new EmbeddingRequest(this.model, inputContent, this.clientOptions));
+
+		return List.of(response.embedding());
+	}
+
+	@Override
+	public EmbeddingResponse embedForResponse(List<String> texts) {
+		var indexCounter = new AtomicInteger(0);
+		List<Embedding> embeddings = this.embed(texts)
+			.stream()
+			.map(e -> new Embedding(e, indexCounter.getAndIncrement()))
+			.toList();
+		return new EmbeddingResponse(embeddings);
+	}
+
+}
\ No newline at end of file
diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java
new file mode 100644
index 00000000000..c102284a268
--- /dev/null
+++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java
@@ -0,0 +1,592 @@
+/*
+ * Copyright 2023-2023 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.ollama.api;
+
+import java.io.IOException;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Consumer;
+
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonInclude.Include;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
+
+import org.springframework.ai.ollama.api.OllamaApiOptions.Options;
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.MediaType;
+import org.springframework.http.client.ClientHttpResponse;
+import org.springframework.util.Assert;
+import org.springframework.util.StreamUtils;
+import org.springframework.web.client.ResponseErrorHandler;
+import org.springframework.web.client.RestClient;
+import org.springframework.web.reactive.function.client.WebClient;
+
+/**
+ * Java client for the Ollama API. https://ollama.ai/
+ *
+ * @author Christian Tzolov
+ * @since 0.8.0
+ */
+// @formatter:off
+public class OllamaApi {
+
+	private static final Log logger = LogFactory.getLog(OllamaApi.class);
+
+	private static final String DEFAULT_BASE_URL = "http://localhost:11434";
+
+	private final ResponseErrorHandler responseErrorHandler;
+
+	private final RestClient restClient;
+
+	private final WebClient webClient;
+
+	private static class OllamaResponseErrorHandler implements ResponseErrorHandler {
+
+		@Override
+		public boolean hasError(ClientHttpResponse response) throws IOException {
+			return response.getStatusCode().isError();
+		}
+
+		@Override
+		public void handleError(ClientHttpResponse response) throws IOException {
+			if (response.getStatusCode().isError()) {
+				int statusCode = response.getStatusCode().value();
+				String statusText = response.getStatusText();
+				String message = StreamUtils.copyToString(response.getBody(), java.nio.charset.StandardCharsets.UTF_8);
+				logger.warn(String.format("[%s] %s - %s", statusCode, statusText, message));
+				throw new RuntimeException(String.format("[%s] %s - %s", statusCode, statusText, message));
+			}
+		}
+
+	}
+
+	/**
+	 * Default constructor that uses the default localhost url.
+	 */
+	public OllamaApi() {
+		this(DEFAULT_BASE_URL);
+	}
+
+	/**
+	 * Create a new OllamaApi instance with the given base url.
+	 * @param baseUrl The base url of the Ollama server.
+	 */
+	public OllamaApi(String baseUrl) {
+		this(baseUrl, RestClient.builder());
+	}
+
+	/**
+	 * Create a new OllamaApi instance with the given base url and
+	 * {@link RestClient.Builder}.
+	 * @param baseUrl The base url of the Ollama server.
+	 * @param restClientBuilder The {@link RestClient.Builder} to use.
+	 */
+	public OllamaApi(String baseUrl, RestClient.Builder restClientBuilder) {
+
+		this.responseErrorHandler = new OllamaResponseErrorHandler();
+
+		Consumer<HttpHeaders> defaultHeaders = headers -> {
+			headers.setContentType(MediaType.APPLICATION_JSON);
+			headers.setAccept(List.of(MediaType.APPLICATION_JSON));
+		};
+
+		this.restClient = restClientBuilder.baseUrl(baseUrl).defaultHeaders(defaultHeaders).build();
+
+		this.webClient = WebClient.builder().baseUrl(baseUrl).defaultHeaders(defaultHeaders).build();
+	}
+
+	// --------------------------------------------------------------------------
+	// Generate & Streaming Generate
+	// --------------------------------------------------------------------------
+	/**
+	 * The request object sent to the /generate endpoint.
+	 *
+	 * @param model (required) The model to use for completion.
+	 * @param prompt (required) The prompt(s) to generate completions for.
+	 * @param format (optional) The format to return the response in. Currently the only
+	 * accepted value is "json".
+	 * @param options (optional) Additional model parameters listed in the documentation
+	 * for the Modelfile, such as temperature.
+	 * @param system (optional) System prompt (overrides what is defined in the Modelfile).
+	 * @param template (optional) The full prompt or prompt template (overrides what is
+	 * defined in the Modelfile).
+	 * @param context The context parameter returned from a previous request to /generate;
+	 * this can be used to keep a short conversational memory.
+	 * @param stream (optional) If false the response will be returned as a single
+	 * response object, rather than a stream of objects.
+	 * @param raw (optional) If true no formatting will be applied to the prompt and no
+	 * context will be returned. You may choose to use the raw parameter if you are
+	 * specifying a full templated prompt in your request to the API, and are managing
+	 * history yourself.
+	 */
+	@JsonInclude(Include.NON_NULL)
+	public record GenerateRequest(
+			@JsonProperty("model") String model,
+			@JsonProperty("prompt") String prompt,
+			@JsonProperty("format") String format,
+			@JsonProperty("options") Map<String, Object> options,
+			@JsonProperty("system") String system,
+			@JsonProperty("template") String template,
+			@JsonProperty("context") List<Integer> context,
+			@JsonProperty("stream") Boolean stream,
+			@JsonProperty("raw") Boolean raw) {
+
+		/**
+		 * Shortcut constructor to create a GenerateRequest without options.
+		 * @param model The model used for completion.
+		 * @param prompt The prompt(s) to generate completions for.
+		 * @param stream Whether to stream the response.
+		 */
+		public GenerateRequest(String model, String prompt, Boolean stream) {
+			this(model, prompt, null, null, null, null, null, stream, null);
+		}
+
+		/**
+		 * Shortcut constructor to create a GenerateRequest without options.
+		 * @param model The model used for completion.
+		 * @param prompt The prompt(s) to generate completions for.
+		 * @param enableJsonFormat Whether to return the response in json format.
+		 * @param stream Whether to stream the response.
+		 */
+		public GenerateRequest(String model, String prompt, boolean enableJsonFormat, Boolean stream) {
+			this(model, prompt, (enableJsonFormat) ? "json" : null, null, null, null, null, stream, null);
+		}
+
+		/**
+		 * Create a GenerateRequest builder.
+		 * @param prompt The prompt(s) to generate completions for.
+		 */
+		public static Builder builder(String prompt) {
+			return new Builder(prompt);
+		}
+
+		public static class Builder {
+
+			private String model;
+			private final String prompt;
+			private String format;
+			private Map<String, Object> options;
+			private String system;
+			private String template;
+			private List<Integer> context;
+			private Boolean stream;
+			private Boolean raw;
+
+			public Builder(String prompt) {
+				this.prompt = prompt;
+			}
+
+			public Builder withModel(String model) {
+				this.model = model;
+				return this;
+			}
+
+			public Builder withFormat(String format) {
+				this.format = format;
+				return this;
+			}
+
+			public Builder withOptions(Map<String, Object> options) {
+				this.options = options;
+				return this;
+			}
+
+			public Builder withOptions(Options options) {
+				this.options = options.toMap();
+				return this;
+			}
+
+			public Builder withSystem(String system) {
+				this.system = system;
+				return this;
+			}
+
+			public Builder withTemplate(String template) {
+				this.template = template;
+				return this;
+			}
+
+			public Builder withContext(List<Integer> context) {
+				this.context = context;
+				return this;
+			}
+
+			public Builder withStream(Boolean stream) {
+				this.stream = stream;
+				return this;
+			}
+
+			public Builder withRaw(Boolean raw) {
+				this.raw = raw;
+				return this;
+			}
+
+			public GenerateRequest build() {
+				return new GenerateRequest(model, prompt, format, options, system, template, context, stream, raw);
+			}
+
+		}
+	}
+
+	/**
+	 * The response object returned from the /generate endpoint. To calculate how fast the
+	 * response is generated in tokens per second (token/s), divide eval_count by
+	 * eval_duration.
+	 *
+	 * @param model The model used for completion.
+	 * @param createdAt When the request was made.
+	 * @param response The completion response. Empty if the response was streamed; if not
+	 * streamed, this will contain the full response.
+	 * @param done Whether this is the final response. If true, this response may be
+	 * followed by another response with the following, additional fields: context,
+	 * prompt_eval_count, prompt_eval_duration, eval_count, eval_duration.
+	 * @param context Encoding of the conversation used in this response; this can be sent
+	 * in the next request to keep a conversational memory.
+	 * @param totalDuration Time spent generating the response.
+	 * @param loadDuration Time spent loading the model.
+	 * @param promptEvalCount Number of tokens in the prompt.
+	 * @param promptEvalDuration Time spent evaluating the prompt.
+	 * @param evalCount Number of tokens in the response.
+	 * @param evalDuration Time spent generating the response.
+	 */
+	@JsonInclude(Include.NON_NULL)
+	public record GenerateResponse(
+			@JsonProperty("model") String model,
+			@JsonProperty("created_at") Instant createdAt,
+			@JsonProperty("response") String response,
+			@JsonProperty("done") Boolean done,
+			@JsonProperty("context") List<Integer> context,
+			@JsonProperty("total_duration") Duration totalDuration,
+			@JsonProperty("load_duration") Duration loadDuration,
+			@JsonProperty("prompt_eval_count") Integer promptEvalCount,
+			@JsonProperty("prompt_eval_duration") Duration promptEvalDuration,
+			@JsonProperty("eval_count") Integer evalCount,
+			@JsonProperty("eval_duration") Duration evalDuration) {
+	}
+
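+	// Illustrative sketch: per the GenerateResponse javadoc above, the generation
+	// speed in tokens per second can be derived from a completed (done == true)
+	// response as
+	//
+	//   double tokensPerSecond = response.evalCount() * 1_000_000_000.0
+	//       / response.evalDuration().toNanos();
+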
+	/**
+	 * Generate a completion for the given prompt.
+	 * @param completionRequest Completion request.
+	 * @return Completion response.
+	 */
+	public GenerateResponse generate(GenerateRequest completionRequest) {
+		Assert.notNull(completionRequest, "The request body can not be null.");
+		Assert.isTrue(!completionRequest.stream(), "Stream mode must be disabled.");
+
+		return this.restClient.post()
+			.uri("/api/generate")
+			.body(completionRequest)
+			.retrieve()
+			.onStatus(this.responseErrorHandler)
+			.body(GenerateResponse.class);
+	}
+
+	/**
+	 * Generate a streaming completion for the given prompt.
+	 * @param completionRequest Completion request. The request must set the stream
+	 * property to true.
+	 * @return Completion response as a {@link Flux} stream.
+	 */
+	public Flux<GenerateResponse> generateStreaming(GenerateRequest completionRequest) {
+		Assert.notNull(completionRequest, "The request body can not be null.");
+		Assert.isTrue(completionRequest.stream(), "Request must set the stream property to true.");
+
+		return webClient.post()
+			.uri("/api/generate")
+			.body(Mono.just(completionRequest), GenerateRequest.class)
+			.retrieve()
+			.bodyToFlux(GenerateResponse.class);
+	}
+
+	// --------------------------------------------------------------------------
+	// Chat & Streaming Chat
+	// --------------------------------------------------------------------------
+	/**
+	 * Chat message object.
+	 *
+	 * @param role The role of the message of type {@link Role}.
+	 * @param content The content of the message.
+	 * @param images The list of images to send with the message.
+	 */
+	@JsonInclude(Include.NON_NULL)
+	public record Message(
+			@JsonProperty("role") Role role,
+			@JsonProperty("content") String content,
+			@JsonProperty("images") List<String> images) {
+
+		/**
+		 * The role of the message in the conversation.
+		 */
+		public enum Role {
+
+			/**
+			 * System message type used as instructions to the model.
+			 */
+			system,
+			/**
+			 * User message type.
+			 */
+			user,
+			/**
+			 * Assistant message type. Usually the response from the model.
+			 */
+			assistant;
+
+		}
+
+		public static Builder builder(Role role) {
+			return new Builder(role);
+		}
+
+		public static class Builder {
+
+			private final Role role;
+			private String content;
+			private List<String> images;
+
+			public Builder(Role role) {
+				this.role = role;
+			}
+
+			public Builder withContent(String content) {
+				this.content = content;
+				return this;
+			}
+
+			public Builder withImages(List<String> images) {
+				this.images = images;
+				return this;
+			}
+
+			public Message build() {
+				return new Message(role, content, images);
+			}
+
+		}
+	}
+
+	/**
+	 * Chat request object.
+	 *
+	 * @param model The model to use for completion.
+	 * @param messages The list of messages to chat with.
+	 * @param stream Whether to stream the response.
+	 * @param format The format to return the response in. Currently the only accepted
+	 * value is "json".
+	 * @param options Additional model parameters. You can use the {@link Options} builder
+	 * to create the options and then {@link Options#toMap()} to convert them into a
+	 * map.
+	 */
+	@JsonInclude(Include.NON_NULL)
+	public record ChatRequest(
+			@JsonProperty("model") String model,
+			@JsonProperty("messages") List<Message> messages,
+			@JsonProperty("stream") Boolean stream,
+			@JsonProperty("format") String format,
+			@JsonProperty("options") Map<String, Object> options) {
+
+		public static Builder builder(String model) {
+			return new Builder(model);
+		}
+
+		public static class Builder {
+
+			private final String model;
+			private List<Message> messages = List.of();
+			private boolean stream = false;
+			private String format;
+			private Map<String, Object> options = Map.of();
+
+			public Builder(String model) {
+				Assert.notNull(model, "The model can not be null.");
+				this.model = model;
+			}
+
+			public Builder withMessages(List<Message> messages) {
+				this.messages = messages;
+				return this;
+			}
+
+			public Builder withStream(boolean stream) {
+				this.stream = stream;
+				return this;
+			}
+
+			public Builder withFormat(String format) {
+				this.format = format;
+				return this;
+			}
+
+			public Builder withOptions(Map<String, Object> options) {
+				this.options = options;
+				return this;
+			}
+
+			public Builder withOptions(Options options) {
+				this.options = options.toMap();
+				return this;
+			}
+
+			public ChatRequest build() {
+				return new ChatRequest(model, messages, stream, format, options);
+			}
+
+		}
+	}
+
+	/**
+	 * Ollama chat response object.
+	 *
+	 * @param model The model name used for completion.
+	 * @param createdAt When the request was made.
+	 * @param message The response {@link Message} with {@link Message.Role#assistant}.
+	 * @param done Whether this is the final response. For streaming responses only the
+	 * last message is marked as done. If true, this response may be followed by another
+	 * response with the following, additional fields: context, prompt_eval_count,
+	 * prompt_eval_duration, eval_count, eval_duration.
+	 * @param totalDuration Time spent generating the response.
+	 * @param loadDuration Time spent loading the model.
+	 * @param promptEvalCount Number of tokens in the prompt.
+	 * @param promptEvalDuration Time spent evaluating the prompt.
+	 * @param evalCount Number of tokens in the response.
+	 * @param evalDuration Time spent generating the response.
+	 * @see Chat
+	 * Completion API
+	 * @see Ollama
+	 * Types
+	 */
+	@JsonInclude(Include.NON_NULL)
+	public record ChatResponse(
+			@JsonProperty("model") String model,
+			@JsonProperty("created_at") Instant createdAt,
+			@JsonProperty("message") Message message,
+			@JsonProperty("done") Boolean done,
+			@JsonProperty("total_duration") Duration totalDuration,
+			@JsonProperty("load_duration") Duration loadDuration,
+			@JsonProperty("prompt_eval_count") Integer promptEvalCount,
+			@JsonProperty("prompt_eval_duration") Duration promptEvalDuration,
+			@JsonProperty("eval_count") Integer evalCount,
+			@JsonProperty("eval_duration") Duration evalDuration) {
+	}
+
+	/**
+	 * Generate the next message in a chat with a provided model.
+	 *
+	 * This is a streaming endpoint (controlled by the 'stream' request property), so
+	 * there may be a series of responses. The final response object will include
+	 * statistics and additional data from the request.
+	 * @param chatRequest Chat request.
+	 * @return Chat response.
+	 */
+	public ChatResponse chat(ChatRequest chatRequest) {
+		Assert.notNull(chatRequest, "The request body can not be null.");
+		Assert.isTrue(!chatRequest.stream(), "Stream mode must be disabled.");
+
+		return this.restClient.post()
+			.uri("/api/chat")
+			.body(chatRequest)
+			.retrieve()
+			.onStatus(this.responseErrorHandler)
+			.body(ChatResponse.class);
+	}
+
+	/**
+	 * Streaming response for the chat completion request.
+	 * @param chatRequest Chat request. The request must set the stream property to true.
+	 * @return Chat response as a {@link Flux} stream.
+	 */
+	public Flux<ChatResponse> streamingChat(ChatRequest chatRequest) {
+		Assert.notNull(chatRequest, "The request body can not be null.");
+		Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true.");
+
+		return webClient.post()
+			.uri("/api/chat")
+			.body(Mono.just(chatRequest), ChatRequest.class)
+			.retrieve()
+			.bodyToFlux(ChatResponse.class);
+	}
+
+	// --------------------------------------------------------------------------
+	// Embeddings
+	// --------------------------------------------------------------------------
+	/**
+	 * Generate embeddings from a model.
+	 *
+	 * @param model The name of the model to generate embeddings from.
+	 * @param prompt The text to generate embeddings for.
+	 * @param options Additional model parameters listed in the documentation for the
+	 * Modelfile, such as temperature.
+	 */
+	@JsonInclude(Include.NON_NULL)
+	public record EmbeddingRequest(
+			@JsonProperty("model") String model,
+			@JsonProperty("prompt") String prompt,
+			@JsonProperty("options") Map<String, Object> options) {
+
+		/**
+		 * Shortcut constructor to create an EmbeddingRequest without options.
+		 * @param model The name of the model to generate embeddings from.
+		 * @param prompt The text to generate embeddings for.
+		 */
+		public EmbeddingRequest(String model, String prompt) {
+			this(model, prompt, null);
+		}
+	}
+
+	/**
+	 * The response object returned from the /embedding endpoint.
+	 *
+	 * @param embedding The embedding generated from the model.
+	 */
+	@JsonInclude(Include.NON_NULL)
+	public record EmbeddingResponse(
+			@JsonProperty("embedding") List<Double> embedding) {
+	}
+
+	/**
+	 * Generate embeddings from a model.
+	 * @param embeddingRequest Embedding request.
+	 * @return Embedding response.
+	 */
+	public EmbeddingResponse embeddings(EmbeddingRequest embeddingRequest) {
+		Assert.notNull(embeddingRequest, "The request body can not be null.");
+
+		return this.restClient.post()
+			.uri("/api/embeddings")
+			.body(embeddingRequest)
+			.retrieve()
+			.onStatus(this.responseErrorHandler)
+			.body(EmbeddingResponse.class);
+	}
+
+}
+// @formatter:on
\ No newline at end of file
diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApiOptions.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApiOptions.java
new file mode 100644
index 00000000000..2f0beeba49d
--- /dev/null
+++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApiOptions.java
@@ -0,0 +1,395 @@
+/*
+ * Copyright 2023-2023 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.ollama.api;
+
+import java.util.Map;
+
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonInclude.Include;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.ObjectMapper;
+
+/**
+ * Helper class for building strongly typed Ollama request options.
+ *
+ * @author Christian Tzolov
+ * @since 0.8.0
+ */
+// @formatter:off
+public class OllamaApiOptions {
+
+	/**
+	 * Runner options which must be set when the model is loaded into memory.
+	 *
+	 * @param useNUMA Whether to use NUMA.
+	 * @param numCtx Sets the size of the context window used to generate the next token.
+	 * (Default: 2048)
+	 * @param numBatch ???
+	 * @param numGQA The number of GQA groups in the transformer layer. Required for some
+	 * models, for example it is 8 for llama2:70b.
+	 * @param numGPU The number of layers to send to the GPU(s). On macOS it defaults to 1
+	 * to enable metal support, 0 to disable.
+	 * @param mainGPU ???
+	 * @param lowVRAM ???
+	 * @param f16KV ???
+	 * @param logitsAll ???
+	 * @param vocabOnly ???
+	 * @param useMMap ???
+	 * @param useMLock ???
+	 * @param embeddingOnly ???
+	 * @param ropeFrequencyBase ???
+	 * @param ropeFrequencyScale ???
+	 * @param numThread Sets the number of threads to use during computation. By default,
+	 * Ollama will detect this for optimal performance. It is recommended to set this
+	 * value to the number of physical CPU cores your system has (as opposed to the
+	 * logical number of cores).
+	 *
+	 * Options specified in GenerateRequest.
+	 * @param mirostat Enable Mirostat sampling for controlling perplexity. (Default: 0,
+	 * 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)
+	 * @param mirostatTau Controls the balance between coherence and diversity of the
+	 * output. A lower value will result in more focused and coherent text. (Default:
+	 * 5.0)
+	 * @param mirostatEta Influences how quickly the algorithm responds to feedback from
+	 * the generated text. A lower learning rate will result in slower adjustments, while
+	 * a higher learning rate will make the algorithm more responsive. (Default: 0.1)
+	 * @param numKeep Unknown.
+	 * @param seed Sets the random number seed to use for generation. Setting this to a
+	 * specific number will make the model generate the same text for the same prompt.
+	 * (Default: 0)
+	 * @param numPredict Maximum number of tokens to predict when generating text.
+	 * (Default: 128, -1 = infinite generation, -2 = fill context)
+	 * @param topK Reduces the probability of generating nonsense. A higher value (e.g.
+	 * 100) will give more diverse answers, while a lower value (e.g. 10) will be more
+	 * conservative. (Default: 40)
+	 * @param topP Works together with top-k. A higher value (e.g., 0.95) will lead to
+	 * more diverse text, while a lower value (e.g., 0.5) will generate more focused and
+	 * conservative text. (Default: 0.9)
+	 * @param tfsZ Tail free sampling is used to reduce the impact of less probable tokens
+	 * from the output. A higher value (e.g., 2.0) will reduce the impact more, while a
+	 * value of 1.0 disables this setting. (Default: 1)
+	 * @param typicalP Unknown.
+	 * @param repeatLastN Sets how far back the model looks to prevent
+	 * repetition. (Default: 64, 0 = disabled, -1 = num_ctx)
+	 * @param temperature The temperature of the model. Increasing the temperature will
+	 * make the model answer more creatively. (Default: 0.8)
+	 * @param repeatPenalty Sets how strongly to penalize repetitions. A higher value
+	 * (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g.,
+	 * 0.9) will be more lenient. (Default: 1.1)
+	 * @param presencePenalty Unknown.
+	 * @param frequencyPenalty Unknown.
+	 * @param penalizeNewline Unknown.
+	 * @param stop Sets the stop sequences to use. When this pattern is encountered the
+	 * LLM will stop generating text and return. Multiple stop patterns may be set by
+	 * specifying multiple separate stop parameters in a modelfile.
+	 * @see Ollama
+	 * Valid Parameters and Values
+	 * @see Ollama Go
+	 * Types
+	 */
+	@JsonInclude(Include.NON_NULL)
+	public record Options(
+			// Runner options which must be set when the model is loaded into memory.
+			@JsonProperty("numa") Boolean useNUMA, @JsonProperty("num_ctx") Integer numCtx,
+			@JsonProperty("num_batch") Integer numBatch, @JsonProperty("num_gqa") Integer numGQA,
+			@JsonProperty("num_gpu") Integer numGPU, @JsonProperty("main_gpu") Integer mainGPU,
+			@JsonProperty("low_vram") Boolean lowVRAM, @JsonProperty("f16_kv") Boolean f16KV,
+			@JsonProperty("logits_all") Boolean logitsAll, @JsonProperty("vocab_only") Boolean vocabOnly,
+			@JsonProperty("use_mmap") Boolean useMMap, @JsonProperty("use_mlock") Boolean useMLock,
+			@JsonProperty("embedding_only") Boolean embeddingOnly,
+			@JsonProperty("rope_frequency_base") Float ropeFrequencyBase,
+			@JsonProperty("rope_frequency_scale") Float ropeFrequencyScale,
+			@JsonProperty("num_thread") Integer numThread,
+
+			// Options specified in GenerateRequest.
+			@JsonProperty("num_keep") Integer numKeep, @JsonProperty("seed") Integer seed,
+			@JsonProperty("num_predict") Integer numPredict, @JsonProperty("top_k") Integer topK,
+			@JsonProperty("top_p") Float topP, @JsonProperty("tfs_z") Float tfsZ,
+			@JsonProperty("typical_p") Float typicalP, @JsonProperty("repeat_last_n") Integer repeatLastN,
+			@JsonProperty("temperature") Float temperature, @JsonProperty("repeat_penalty") Float repeatPenalty,
+			@JsonProperty("presence_penalty") Float presencePenalty,
+			@JsonProperty("frequency_penalty") Float frequencyPenalty, @JsonProperty("mirostat") Integer mirostat,
+			@JsonProperty("mirostat_tau") Float mirostatTau, @JsonProperty("mirostat_eta") Float mirostatEta,
+			@JsonProperty("penalize_newline") Boolean penalizeNewline, @JsonProperty("stop") String[] stop) {
+
+		/**
+		 * Convert the {@link Options} object to a {@link Map} of key/value pairs.
+		 * @return The {@link Map} of key/value pairs.
+		 */
+		public Map<String, Object> toMap() {
+			try {
+				var mapper = new ObjectMapper();
+				var json = mapper.writeValueAsString(this);
+				return mapper.readValue(json, new TypeReference<Map<String, Object>>() {
+				});
+			}
+			catch (JsonProcessingException e) {
+				throw new RuntimeException(e);
+			}
+		}
+
+		public static Builder builder() {
+			return new Builder();
+		}
+
+		public static class Builder {
+
+			private Boolean useNUMA;
+
+			private Integer numCtx;
+
+			private Integer numBatch;
+
+			private Integer numGQA;
+
+			private Integer numGPU;
+
+			private Integer mainGPU;
+
+			private Boolean lowVRAM;
+
+			private Boolean f16KV;
+
+			private Boolean logitsAll;
+
+			private Boolean vocabOnly;
+
+			private Boolean useMMap;
+
+			private Boolean useMLock;
+
+			private Boolean embeddingOnly;
+
+			private Float ropeFrequencyBase;
+
+			private Float ropeFrequencyScale;
+
+			private Integer numThread;
+
+			private Integer numKeep;
+
+			private Integer seed;
+
+			private Integer numPredict;
+
+			private Integer topK;
+
+			private Float topP;
+
+			private Float tfsZ;
+
+			private Float typicalP;
+
+			private Integer repeatLastN;
+
+			private Float temperature;
+
+			private Float repeatPenalty;
+
+			private Float presencePenalty;
+
+			private Float frequencyPenalty;
+
+			private Integer mirostat;
+
+			private Float mirostatTau;
+
+			private Float mirostatEta;
+
+			private Boolean penalizeNewline;
+
+			private String[] stop;
+
+			public Builder withUseNUMA(Boolean useNUMA) {
+				this.useNUMA = useNUMA;
+				return this;
+			}
+
+			public Builder withNumCtx(Integer numCtx) {
+				this.numCtx = numCtx;
+				return this;
+			}
+
+			public Builder withNumBatch(Integer numBatch) {
+				this.numBatch = numBatch;
+				return this;
+			}
+
+			public Builder withNumGQA(Integer numGQA) {
+				this.numGQA = numGQA;
+				return this;
+			}
+
+			public Builder withNumGPU(Integer numGPU) {
+				this.numGPU = numGPU;
+				return this;
+			}
+
+			public Builder withMainGPU(Integer mainGPU) {
+				this.mainGPU = mainGPU;
+				return this;
+			}
+
+			public Builder withLowVRAM(Boolean lowVRAM) {
+				this.lowVRAM = lowVRAM;
+				return this;
+			}
+
+			public Builder withF16KV(Boolean f16KV) {
+				this.f16KV = f16KV;
+				return this;
+			}
+
+			public Builder withLogitsAll(Boolean logitsAll) {
+				this.logitsAll = logitsAll;
+				return this;
+			}
+
+			public Builder withVocabOnly(Boolean vocabOnly) {
+				this.vocabOnly = vocabOnly;
+				return this;
+			}
+
+			public Builder withUseMMap(Boolean useMMap) {
+				this.useMMap = useMMap;
+				return this;
+			}
+
+			public Builder withUseMLock(Boolean useMLock) {
+				this.useMLock = useMLock;
+				return this;
+			}
+
+			public Builder withEmbeddingOnly(Boolean embeddingOnly) {
+				this.embeddingOnly = embeddingOnly;
+				return this;
+			}
+
+			public Builder withRopeFrequencyBase(Float ropeFrequencyBase) {
+				this.ropeFrequencyBase = ropeFrequencyBase;
+				return this;
+			}
+
+			public Builder withRopeFrequencyScale(Float ropeFrequencyScale) {
+				this.ropeFrequencyScale = ropeFrequencyScale;
+				return this;
+			}
+
+			public Builder withNumThread(Integer numThread) {
+				this.numThread = numThread;
+				return this;
+			}
+
+			public Builder withNumKeep(Integer numKeep) {
+				this.numKeep = numKeep;
+				return this;
+			}
+
+			public Builder withSeed(Integer seed) {
+				this.seed = seed;
+				return this;
+			}
+
+			public Builder withNumPredict(Integer numPredict) {
+				this.numPredict = numPredict;
+				return this;
+			}
+
+			public Builder withTopK(Integer topK) {
+				this.topK = topK;
+				return this;
+			}
+
+			public Builder withTopP(Float topP) {
+				this.topP = topP;
+				return this;
+			}
+
+			public Builder withTfsZ(Float tfsZ) {
+				this.tfsZ = tfsZ;
+				return this;
+			}
+
+			public Builder
withTypicalP(Float typicalP) { + this.typicalP = typicalP; + return this; + } + + public Builder withRepeatLastN(Integer repeatLastN) { + this.repeatLastN = repeatLastN; + return this; + } + + public Builder withTemperature(Float temperature) { + this.temperature = temperature; + return this; + } + + public Builder withRepeatPenalty(Float repeatPenalty) { + this.repeatPenalty = repeatPenalty; + return this; + } + + public Builder withPresencePenalty(Float presencePenalty) { + this.presencePenalty = presencePenalty; + return this; + } + + public Builder withFrequencyPenalty(Float frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder withMirostat(Integer mirostat) { + this.mirostat = mirostat; + return this; + } + + public Builder withMirostatTau(Float mirostatTau) { + this.mirostatTau = mirostatTau; + return this; + } + + public Builder withMirostatEta(Float mirostatEta) { + this.mirostatEta = mirostatEta; + return this; + } + + public Builder withPenalizeNewline(Boolean penalizeNewline) { + this.penalizeNewline = penalizeNewline; + return this; + } + + public Builder withStop(String[] stop) { + this.stop = stop; + return this; + } + + public Options build() { + return new Options(useNUMA, numCtx, numBatch, numGQA, numGPU, mainGPU, lowVRAM, f16KV, logitsAll, + vocabOnly, useMMap, useMLock, embeddingOnly, ropeFrequencyBase, ropeFrequencyScale, numThread, + numKeep, seed, numPredict, topK, topP, tfsZ, typicalP, repeatLastN, temperature, repeatPenalty, + presencePenalty, frequencyPenalty, mirostat, mirostatTau, mirostatEta, penalizeNewline, stop); + } + + } + } + +} +// @formatter:on \ No newline at end of file diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/client/OllamaClient.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/client/OllamaClient.java deleted file mode 100644 index 11b10282ee6..00000000000 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/client/OllamaClient.java +++ /dev/null @@ -1,246 +0,0 @@ -package org.springframework.ai.ollama.client; - -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.ai.chat.ChatClient; -import org.springframework.ai.chat.ChatResponse; -import org.springframework.ai.chat.Generation; -import org.springframework.ai.prompt.Prompt; -import org.springframework.util.CollectionUtils; - -import java.io.BufferedReader; -import java.io.IOException; -import java.io.InputStream; -import java.io.InputStreamReader; -import java.net.URI; -import java.net.http.HttpClient; -import java.net.http.HttpRequest; -import java.net.http.HttpResponse; -import java.time.Duration; -import java.util.*; -import java.util.function.Consumer; -import java.util.stream.Collectors; - -/** - * A client implementation for interacting with Ollama Service. This class acts as an - * interface between the application and the Ollama AI Service, handling request creation, - * communication, and response processing. - * - * @author nullptr - */ -public class OllamaClient implements ChatClient { - - /** Logger for logging the events and messages. */ - private static final Logger log = LoggerFactory.getLogger(OllamaClient.class); - - /** Mapper for JSON serialization and deserialization. 
*/ - private static final ObjectMapper jsonMapper = new ObjectMapper(); - - /** HTTP client for making asynchronous calls to the Ollama Service. */ - private static final HttpClient httpClient = HttpClient.newBuilder().build(); - - /** Base URL of the Ollama Service. */ - private final String baseUrl; - - /** Name of the model to be used for the AI service. */ - private final String model; - - /** Optional callback to handle individual generation results. */ - private Consumer simpleCallback; - - /** - * Constructs an OllamaClient with the specified base URL and model. - * @param baseUrl Base URL of the Ollama Service. - * @param model Model specification for the AI service. - */ - public OllamaClient(String baseUrl, String model) { - this.baseUrl = baseUrl; - this.model = model; - } - - /** - * Constructs an OllamaClient with the specified base URL, model, and a callback. - * @param baseUrl Base URL of the Ollama Service. - * @param model Model specification for the AI service. - * @param simpleCallback Callback to handle individual generation results. - */ - public OllamaClient(String baseUrl, String model, Consumer simpleCallback) { - this(baseUrl, model); - this.simpleCallback = simpleCallback; - } - - @Override - public ChatResponse generate(Prompt prompt) { - validatePrompt(prompt); - - HttpRequest request = buildHttpRequest(prompt); - var response = sendRequest(request); - - List results = readGenerateResults(response.body()); - return getAiResponse(results); - } - - /** - * Validates the provided prompt. - * @param prompt The prompt to validate. - */ - protected void validatePrompt(Prompt prompt) { - if (CollectionUtils.isEmpty(prompt.getMessages())) { - throw new RuntimeException("The prompt message cannot be empty."); - } - - if (prompt.getMessages().size() > 1) { - log.warn("Only the first prompt message will be used; subsequent messages will be ignored."); - } - } - - /** - * Constructs an HTTP request for the provided prompt. - * @param prompt The prompt for which the request needs to be built. - * @return The constructed HttpRequest. - */ - protected HttpRequest buildHttpRequest(Prompt prompt) { - String requestBody = getGenerateRequestBody(prompt.getMessages().get(0).getContent()); - - // remove the suffix '/' if necessary - String url = !this.baseUrl.endsWith("/") ? this.baseUrl : this.baseUrl.substring(0, this.baseUrl.length() - 1); - - return HttpRequest.newBuilder() - .uri(URI.create("%s/api/generate".formatted(url))) - .POST(HttpRequest.BodyPublishers.ofString(requestBody)) - .timeout(Duration.ofMinutes(5L)) - .build(); - } - - /** - * Sends the constructed HttpRequest and retrieves the HttpResponse. - * @param request The HttpRequest to be sent. - * @return HttpResponse containing the response data. - */ - protected HttpResponse sendRequest(HttpRequest request) { - var response = httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofInputStream()).join(); - if (response.statusCode() != 200) { - throw new RuntimeException("Ollama call returned an unexpected status: " + response.statusCode()); - } - return response; - } - - /** - * Serializes the prompt into a request body for the Ollama API call. - * @param prompt The prompt to be serialized. - * @return Serialized request body as a String. 
- */ - private String getGenerateRequestBody(String prompt) { - var data = Map.of("model", model, "prompt", prompt); - try { - return jsonMapper.writeValueAsString(data); - } - catch (JsonProcessingException ex) { - throw new RuntimeException("Failed to serialize the prompt to JSON", ex); - } - - } - - /** - * Reads and processes the results from the InputStream provided by the Ollama - * Service. - * @param inputStream InputStream containing the results from the Ollama Service. - * @return List of OllamaGenerateResult. - */ - protected List readGenerateResults(InputStream inputStream) { - try (BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(inputStream))) { - var results = new ArrayList(); - String line; - while ((line = bufferedReader.readLine()) != null) { - processResponseLine(line, results); - } - return results; - } - catch (IOException e) { - throw new RuntimeException("Error parsing Ollama generation response.", e); - } - } - - /** - * Processes a single line from the Ollama response. - * @param line The line to be processed. - * @param results List to which parsed results will be added. - */ - protected void processResponseLine(String line, List results) { - if (line.isBlank()) - return; - - log.debug("Received ollama generate response: {}", line); - - OllamaGenerateResult result; - try { - result = jsonMapper.readValue(line, OllamaGenerateResult.class); - } - catch (IOException e) { - throw new RuntimeException("Error parsing response line from Ollama.", e); - } - - if (result.getModel() == null || result.getDone() == null) { - throw new IllegalStateException("Received invalid data from Ollama. Model = " + result.getModel() - + " , Done = " + result.getDone()); - - } - - if (simpleCallback != null) { - simpleCallback.accept(result); - } - - results.add(result); - } - - /** - * Converts the list of OllamaGenerateResult into a structured ChatResponse. - * @param results List of OllamaGenerateResult. - * @return Formulated ChatResponse. - */ - protected ChatResponse getAiResponse(List results) { - var ollamaResponse = results.stream() - .filter(Objects::nonNull) - .filter(it -> it.getResponse() != null && !it.getResponse().isBlank()) - .filter(it -> it.getDone() != null) - .map(OllamaGenerateResult::getResponse) - .collect(Collectors.joining("")); - - var generation = new Generation(ollamaResponse); - - // TODO investigate mapping of additional metadata/runtime info to the response. - return new ChatResponse(Collections.singletonList(generation)); - } - - /** - * @return Model name for the AI service. - */ - public String getModel() { - return model; - } - - /** - * @return Base URL of the Ollama Service. - */ - public String getBaseUrl() { - return baseUrl; - } - - /** - * @return Callback that handles individual generation results. - */ - public Consumer getSimpleCallback() { - return simpleCallback; - } - - /** - * Sets the callback that handles individual generation results. - * @param simpleCallback The callback to be set. 
- */ - public void setSimpleCallback(Consumer simpleCallback) { - this.simpleCallback = simpleCallback; - } - -} diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/client/OllamaGenerateResult.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/client/OllamaGenerateResult.java deleted file mode 100644 index f6f997d64d2..00000000000 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/client/OllamaGenerateResult.java +++ /dev/null @@ -1,146 +0,0 @@ -package org.springframework.ai.ollama.client; - -import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.fasterxml.jackson.annotation.JsonProperty; - -import java.util.List; - -/** - * Ollama generate a completion api response model - * - * @author nullptr - */ -@JsonIgnoreProperties(ignoreUnknown = true) -public class OllamaGenerateResult { - - @JsonProperty("model") - private String model; - - @JsonProperty("created_at") - private String createdAt; - - @JsonProperty("response") - private String response; - - @JsonProperty("done") - private Boolean done; - - @JsonProperty("context") - private List context; - - @JsonProperty("total_duration") - private Long totalDuration; - - @JsonProperty("load_duration") - private Long loadDuration; - - @JsonProperty("prompt_eval_count") - private Long promptEvalCount; - - @JsonProperty("prompt_eval_duration") - private Long promptEvalDuration; - - @JsonProperty("eval_count") - private Long evalCount; - - @JsonProperty("eval_duration") - private Long evalDuration; - - public String getModel() { - return model; - } - - public void setModel(String model) { - this.model = model; - } - - public String getCreatedAt() { - return createdAt; - } - - public void setCreatedAt(String createdAt) { - this.createdAt = createdAt; - } - - public String getResponse() { - return response; - } - - public void setResponse(String response) { - this.response = response; - } - - public Boolean getDone() { - return done; - } - - public void setDone(Boolean done) { - this.done = done; - } - - public List getContext() { - return context; - } - - public void setContext(List context) { - this.context = context; - } - - public Long getTotalDuration() { - return totalDuration; - } - - public void setTotalDuration(Long totalDuration) { - this.totalDuration = totalDuration; - } - - public Long getLoadDuration() { - return loadDuration; - } - - public void setLoadDuration(Long loadDuration) { - this.loadDuration = loadDuration; - } - - public Long getPromptEvalCount() { - return promptEvalCount; - } - - public void setPromptEvalCount(Long promptEvalCount) { - this.promptEvalCount = promptEvalCount; - } - - public Long getPromptEvalDuration() { - return promptEvalDuration; - } - - public void setPromptEvalDuration(Long promptEvalDuration) { - this.promptEvalDuration = promptEvalDuration; - } - - public Long getEvalCount() { - return evalCount; - } - - public void setEvalCount(Long evalCount) { - this.evalCount = evalCount; - } - - public Long getEvalDuration() { - return evalDuration; - } - - public void setEvalDuration(Long evalDuration) { - this.evalDuration = evalDuration; - } - - @Override - public String toString() { - return "OllamaGenerateResult{" + "model='" + model + '\'' + ", createdAt='" + createdAt + '\'' + ", response='" - + response + '\'' + ", done='" + done + '\'' + ", context=" + context + ", totalDuration=" - + totalDuration + ", loadDuration=" + loadDuration + ", promptEvalCount=" + promptEvalCount - + ", promptEvalDuration=" + 
promptEvalDuration + ", evalCount=" + evalCount + ", evalDuration=" - + evalDuration + '}'; - } - -} diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatClientIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatClientIT.java new file mode 100644 index 00000000000..c0120220735 --- /dev/null +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatClientIT.java @@ -0,0 +1,192 @@ +package org.springframework.ai.ollama; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.Generation; +import org.springframework.ai.ollama.api.OllamaApi; +import org.springframework.ai.ollama.api.OllamaApiOptions; +import org.springframework.ai.parser.BeanOutputParser; +import org.springframework.ai.parser.ListOutputParser; +import org.springframework.ai.parser.MapOutputParser; +import org.springframework.ai.prompt.Prompt; +import org.springframework.ai.prompt.PromptTemplate; +import org.springframework.ai.prompt.SystemPromptTemplate; +import org.springframework.ai.prompt.messages.Message; +import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import org.springframework.core.convert.support.DefaultConversionService; + +import static org.assertj.core.api.Assertions.assertThat; + +@SpringBootTest +@Testcontainers +@Disabled("For manual smoke testing only.") +class OllamaChatClientIT { + + private static String MODEL = "mistral"; + + private static final Log logger = LogFactory.getLog(OllamaChatClientIT.class); + + @Container + static GenericContainer ollamaContainer = new GenericContainer<>("ollama/ollama:0.1.15").withExposedPorts(11434); + + static String baseUrl; + + @BeforeAll + public static void beforeAll() throws IOException, InterruptedException { + logger.info("Start pulling the '" + MODEL + " ' model ... would take several minutes ..."); + ollamaContainer.execInContainer("ollama", "pull", MODEL); + logger.info(MODEL + " pulling competed!"); + + baseUrl = "http://" + ollamaContainer.getHost() + ":" + ollamaContainer.getMappedPort(11434); + } + + @Autowired + private OllamaChatClient client; + + @Test + void roleTest() { + Message systemMessage = new SystemPromptTemplate(""" + You are a helpful AI assistant. Your name is {name}. + You are an AI assistant that helps people find information. + Your name is {name} + You should reply to the user's request with your name and also in the style of a {voice}. 
+ """).createMessage(Map.of("name", "Bob", "voice", "pirate")); + + UserMessage userMessage = new UserMessage("Tell me about 5 famous pirates from the Golden Age of Piracy."); + + Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); + ChatResponse response = client.generate(prompt); + assertThat(response.getGeneration().getContent()).contains("Blackbeard"); + } + + @Disabled("TODO: Fix the parser instructions to return the correct format") + @Test + void outputParser() { + DefaultConversionService conversionService = new DefaultConversionService(); + ListOutputParser outputParser = new ListOutputParser(conversionService); + + String format = outputParser.getFormat(); + String template = """ + List five {subject} + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, + Map.of("subject", "ice cream flavors.", "format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = this.client.generate(prompt).getGeneration(); + + List list = outputParser.parse(generation.getContent()); + assertThat(list).hasSize(5); + } + + @Disabled("TODO: Fix the parser instructions to return the correct format") + @Test + void mapOutputParser() { + MapOutputParser outputParser = new MapOutputParser(); + + String format = outputParser.getFormat(); + String template = """ + Remove Markdown code blocks from the output. + Provide me a List of {subject} + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, + Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + + Generation generation = client.generate(prompt).getGeneration(); + + Map result = outputParser.parse(generation.getContent()); + assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); + + } + + record ActorsFilmsRecord(String actor, List movies) { + } + + @Disabled("TODO: Fix the parser instructions to return the correct format") + @Test + void beanOutputParserRecords() { + + BeanOutputParser outputParser = new BeanOutputParser<>(ActorsFilmsRecord.class); + + String format = outputParser.getFormat(); + String template = """ + Generate the filmography of 5 movies for Tom Hanks. + {format} + Remove Markdown code blocks from the output. + """; + PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = client.generate(prompt).getGeneration(); + + ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getContent()); + assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); + assertThat(actorsFilms.movies()).hasSize(5); + } + + @Disabled("TODO: Fix the parser instructions to return the correct format") + @Test + void beanStreamOutputParserRecords() { + + BeanOutputParser outputParser = new BeanOutputParser<>(ActorsFilmsRecord.class); + + String format = outputParser.getFormat(); + String template = """ + Generate the filmography of 5 movies for Tom Hanks. + {format} + Remove Markdown code blocks from the output. 
+ """; + PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + + String generationTextFromStream = client.generateStream(prompt) + .collectList() + .block() + .stream() + .map(ChatResponse::getGenerations) + .flatMap(List::stream) + .map(Generation::getContent) + .collect(Collectors.joining()); + + ActorsFilmsRecord actorsFilms = outputParser.parse(generationTextFromStream); + + assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); + assertThat(actorsFilms.movies()).hasSize(5); + } + + @SpringBootConfiguration + public static class TestConfiguration { + + @Bean + public OllamaApi ollamaApi() { + return new OllamaApi(baseUrl); + } + + @Bean + public OllamaChatClient ollamaChat(OllamaApi ollamaApi) { + return new OllamaChatClient(ollamaApi).withModel(MODEL) + .withOptions(OllamaApiOptions.Options.builder().withTemperature(0.9f).build()); + } + + } + +} \ No newline at end of file diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingClientIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingClientIT.java new file mode 100644 index 00000000000..095a8f6d433 --- /dev/null +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingClientIT.java @@ -0,0 +1,82 @@ +package org.springframework.ai.ollama; + +import java.io.IOException; +import java.util.List; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.ollama.api.OllamaApi; +import org.springframework.ai.ollama.api.OllamaApiIT; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +@SpringBootTest +@Disabled("For manual smoke testing only.") +@Testcontainers +class OllamaEmbeddingClientIT { + + private static final Log logger = LogFactory.getLog(OllamaApiIT.class); + + @Container + static GenericContainer ollamaContainer = new GenericContainer<>("ollama/ollama:0.1.15").withExposedPorts(11434); + + static String baseUrl; + + @BeforeAll + public static void beforeAll() throws IOException, InterruptedException { + logger.info("Start pulling the 'orca-mini' model (3GB) ... 
would take several minutes ..."); + ollamaContainer.execInContainer("ollama", "pull", "orca-mini"); + logger.info("orca-mini pulling competed!"); + + baseUrl = "http://" + ollamaContainer.getHost() + ":" + ollamaContainer.getMappedPort(11434); + } + + @Autowired + private OllamaEmbeddingClient embeddingClient; + + @Test + void singleEmbedding() { + assertThat(embeddingClient).isNotNull(); + EmbeddingResponse embeddingResponse = embeddingClient.embedForResponse(List.of("Hello World")); + assertThat(embeddingResponse.getData()).hasSize(1); + assertThat(embeddingResponse.getData().get(0).getEmbedding()).isNotEmpty(); + assertThat(embeddingClient.dimensions()).isEqualTo(3200); + } + + @Test + void batchEmbedding() { + assertThatThrownBy( + () -> embeddingClient.embedForResponse(List.of("Hello World", "World is big and salvation is near"))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Ollama Embedding does not support batch embedding!"); + } + + @SpringBootConfiguration + public static class TestConfiguration { + + @Bean + public OllamaApi ollamaApi() { + return new OllamaApi(baseUrl); + } + + @Bean + public OllamaEmbeddingClient ollamaEmbedding(OllamaApi ollamaApi) { + return new OllamaEmbeddingClient(ollamaApi).withModel("orca-mini"); + } + + } + +} \ No newline at end of file diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java new file mode 100644 index 00000000000..ec60e8a0f82 --- /dev/null +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java @@ -0,0 +1,146 @@ +/* + * Copyright 2023-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.springframework.ai.ollama.api;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+import org.testcontainers.containers.GenericContainer;
+import org.testcontainers.junit.jupiter.Container;
+import org.testcontainers.junit.jupiter.Testcontainers;
+import reactor.core.publisher.Flux;
+
+import org.springframework.ai.ollama.api.OllamaApi.ChatRequest;
+import org.springframework.ai.ollama.api.OllamaApi.ChatResponse;
+import org.springframework.ai.ollama.api.OllamaApi.EmbeddingRequest;
+import org.springframework.ai.ollama.api.OllamaApi.EmbeddingResponse;
+import org.springframework.ai.ollama.api.OllamaApi.GenerateRequest;
+import org.springframework.ai.ollama.api.OllamaApi.GenerateResponse;
+import org.springframework.ai.ollama.api.OllamaApi.Message;
+import org.springframework.ai.ollama.api.OllamaApi.Message.Role;
+import org.springframework.ai.ollama.api.OllamaApiOptions.Options;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * @author Christian Tzolov
+ */
+@Disabled("For manual smoke testing only.")
+@Testcontainers
+public class OllamaApiIT {
+
+	private static final Log logger = LogFactory.getLog(OllamaApiIT.class);
+
+	@Container
+	static GenericContainer ollamaContainer = new GenericContainer<>("ollama/ollama:0.1.15").withExposedPorts(11434);
+
+	static OllamaApi ollamaApi;
+
+	@BeforeAll
+	public static void beforeAll() throws IOException, InterruptedException {
+		logger.info("Start pulling the 'orca-mini' model (3GB) ... this may take several minutes ...");
+		ollamaContainer.execInContainer("ollama", "pull", "orca-mini");
+		logger.info("orca-mini pulling completed!");
+
+		ollamaApi = new OllamaApi("http://" + ollamaContainer.getHost() + ":" + ollamaContainer.getMappedPort(11434));
+	}
+
+	@Test
+	public void generation() {
+
+		var request = GenerateRequest
+			.builder("What is the capital of Bulgaria and what is the size? What is the national anthem?")
+			.withModel("orca-mini")
+			.withStream(false)
+			.build();
+
+		GenerateResponse response = ollamaApi.generate(request);
+
+		System.out.println(response);
+
+		assertThat(response).isNotNull();
+		assertThat(response.model()).isEqualTo("orca-mini");
+		assertThat(response.response()).contains("Sofia");
+	}
+
+	@Test
+	public void chat() {
+
+		var request = ChatRequest.builder("orca-mini")
+			.withStream(false)
+			.withMessages(List.of(Message.builder(Role.user)
+				.withContent("What is the capital of Bulgaria and what is the size? " + "What is the national anthem?")
+				.build()))
+			.withOptions(Options.builder().withTemperature(0.9f).build())
+			.build();
+
+		ChatResponse response = ollamaApi.chat(request);
+
+		System.out.println(response);
+
+		assertThat(response).isNotNull();
+		assertThat(response.model()).isEqualTo("orca-mini");
+		assertThat(response.done()).isTrue();
+		assertThat(response.message().role()).isEqualTo(Role.assistant);
+		assertThat(response.message().content()).contains("Sofia");
+	}
+
+	@Test
+	public void streamingChat() {
+
+		var request = ChatRequest.builder("orca-mini")
+			.withStream(true)
+			.withMessages(List.of(Message.builder(Role.user)
+				.withContent("What is the capital of Bulgaria and what is the size? " + "What is the national anthem?")
+				.build()))
+			.withOptions(Options.builder().withTemperature(0.9f).build().toMap())
+			.build();
+
+		Flux<ChatResponse> response = ollamaApi.streamingChat(request);
+
+		List<ChatResponse> responses = response.collectList().block();
+		System.out.println(responses);
+
+		assertThat(response).isNotNull();
+		assertThat(responses.stream()
+			.filter(r -> r.message() != null)
+			.map(r -> r.message().content())
+			.collect(Collectors.joining("\n"))).contains("Sofia");
+
+		ChatResponse lastResponse = responses.get(responses.size() - 1);
+		assertThat(lastResponse.message()).isNull();
+		assertThat(lastResponse.done()).isTrue();
+	}
+
+	@Test
+	public void embedText() {
+
+		EmbeddingRequest request = new EmbeddingRequest("orca-mini", "I like to eat apples");
+
+		EmbeddingResponse response = ollamaApi.embeddings(request);
+
+		assertThat(response).isNotNull();
+		assertThat(response.embedding()).hasSize(3200);
+	}
+
+}
\ No newline at end of file
diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/client/OllamaClientTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/client/OllamaClientTests.java
deleted file mode 100644
index 839c33cc83c..00000000000
--- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/client/OllamaClientTests.java
+++ /dev/null
@@ -1,43 +0,0 @@
-package org.springframework.ai.ollama.client;
-
-import org.junit.jupiter.api.Assertions;
-import org.junit.jupiter.api.Disabled;
-import org.junit.jupiter.api.Test;
-import org.springframework.ai.chat.ChatResponse;
-import org.springframework.ai.prompt.Prompt;
-import org.springframework.util.CollectionUtils;
-
-import java.util.function.Consumer;
-
-public class OllamaClientTests {
-
-	@Test
-	@Disabled("For manual smoke testing only.")
-	public void smokeTest() {
-		OllamaClient ollama2 = getOllamaClient();
-
-		Prompt prompt = new Prompt("Hello");
-		ChatResponse chatResponse = ollama2.generate(prompt);
-		Assertions.assertNotNull(chatResponse);
-		Assertions.assertFalse(CollectionUtils.isEmpty(chatResponse.getGenerations()));
-		Assertions.assertNotNull(chatResponse.getGeneration());
-		Assertions.assertNotNull(chatResponse.getGeneration().getContent());
-	}
-
-	private static OllamaClient getOllamaClient() {
-		Consumer<OllamaGenerateResult> ollamaGenerateResultConsumer = it -> {
-			if (it.getDone()) {
-				System.out.println();
-				System.out.printf("Total duration: %dms%n", it.getTotalDuration() / 1000 / 1000);
-				System.out.printf("Prompt tokens: %d%n", it.getPromptEvalCount());
-				System.out.printf("Completion tokens: %d%n", it.getEvalCount());
-			}
-			else {
-				System.out.print(it.getResponse());
-			}
-		};
-
-		return new OllamaClient("http://127.0.0.1:11434", "llama2", ollamaGenerateResultConsumer);
-	}
-
-}
diff --git a/models/spring-ai-ollama/src/test/resources/doc/Ollama Chat API.jpg b/models/spring-ai-ollama/src/test/resources/doc/Ollama Chat API.jpg
new file mode 100644
index 00000000000..b4fcdeecf87
Binary files /dev/null and b/models/spring-ai-ollama/src/test/resources/doc/Ollama Chat API.jpg differ
diff --git a/pom.xml b/pom.xml
index 510a29ef5b0..bf12bfe1c31 100644
--- a/pom.xml
+++ b/pom.xml
@@ -101,6 +101,7 @@
 		0.6.1
 		4.31.1
 		2.22.0
+		2.16.0
 		3.0.0
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/NativeHints.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/NativeHints.java
index 0dcbb431145..c76f043daa9 100644
---
a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/NativeHints.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/NativeHints.java @@ -11,6 +11,8 @@ import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi; import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi; import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi; +import org.springframework.ai.ollama.api.OllamaApi; +import org.springframework.ai.ollama.api.OllamaApiOptions; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.vertex.api.VertexAiApi; import org.springframework.aot.hint.MemberCategory; @@ -40,7 +42,7 @@ public class NativeHints implements RuntimeHintsRegistrar { public void registerHints(RuntimeHints hints, ClassLoader classLoader) { for (var h : Set.of(new BedrockAiHints(), new VertexAiHints(), new OpenAiHints(), new PdfReaderHints(), - new KnuddelsHints())) + new KnuddelsHints(), new OllamaHints())) h.registerHints(hints, classLoader); hints.resources().registerResource(new ClassPathResource("embedding/embedding-model-dimensions.properties")); @@ -71,6 +73,19 @@ public void registerHints(RuntimeHints hints, ClassLoader classLoader) { } + static class OllamaHints implements RuntimeHintsRegistrar { + + @Override + public void registerHints(RuntimeHints hints, ClassLoader classLoader) { + var mcs = MemberCategory.values(); + for (var tr : findJsonAnnotatedClasses(OllamaApi.class)) + hints.reflection().registerType(tr, mcs); + for (var tr : findJsonAnnotatedClasses(OllamaApiOptions.class)) + hints.reflection().registerType(tr, mcs); + } + + } + static class BedrockAiHints implements RuntimeHintsRegistrar { @Override diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfiguration.java index 2d349ee8597..1c2bbcbbcd8 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfiguration.java @@ -48,7 +48,7 @@ public class BedrockTitanChatAutoConfiguration { @Bean @ConditionalOnMissingBean - public TitanChatBedrockApi cohereApi(AwsCredentialsProvider credentialsProvider, + public TitanChatBedrockApi titanApi(AwsCredentialsProvider credentialsProvider, BedrockTitanChatProperties properties, BedrockAwsConnectionProperties awsProperties) { return new TitanChatBedrockApi(properties.getModel(), credentialsProvider, awsProperties.getRegion(), @@ -56,7 +56,7 @@ public TitanChatBedrockApi cohereApi(AwsCredentialsProvider credentialsProvider, } @Bean - public BedrockTitanChatClient cohereChatClient(TitanChatBedrockApi titanChatApi, + public BedrockTitanChatClient titanChatClient(TitanChatBedrockApi titanChatApi, BedrockTitanChatProperties properties) { return new BedrockTitanChatClient(titanChatApi).withTemperature(properties.getTemperature()) diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfiguration.java 
b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfiguration.java index c7f290cb167..fe37ef15bdf 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfiguration.java @@ -57,7 +57,7 @@ public TitanEmbeddingBedrockApi titanApi(AwsCredentialsProvider credentialsProvi @Bean @ConditionalOnMissingBean - public BedrockTitanEmbeddingClient cohereEmbeddingClient(TitanEmbeddingBedrockApi titanEmbeddingApi, + public BedrockTitanEmbeddingClient titanEmbeddingClient(TitanEmbeddingBedrockApi titanEmbeddingApi, BedrockTitanEmbeddingProperties properties) { return new BedrockTitanEmbeddingClient(titanEmbeddingApi).withInputType(properties.getInputType()); diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java deleted file mode 100644 index 6bbd53635da..00000000000 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright 2023-2023 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.autoconfigure.ollama; - -import org.springframework.ai.ollama.client.OllamaClient; -import org.springframework.boot.autoconfigure.AutoConfiguration; -import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; -import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; -import org.springframework.boot.context.properties.EnableConfigurationProperties; -import org.springframework.context.annotation.Bean; - -@AutoConfiguration -@ConditionalOnClass(OllamaClient.class) -@EnableConfigurationProperties(OllamaProperties.class) -public class OllamaAutoConfiguration { - - @Bean - @ConditionalOnMissingBean - public OllamaClient ollamaClient(OllamaProperties properties) { - return new OllamaClient(properties.getBaseUrl(), properties.getModel()); - } - -} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfiguration.java new file mode 100644 index 00000000000..b2c19c7afcf --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfiguration.java @@ -0,0 +1,74 @@ +/* + * Copyright 2023 the original author or authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.ollama; + +import org.springframework.ai.autoconfigure.NativeHints; +import org.springframework.ai.ollama.OllamaChatClient; +import org.springframework.ai.ollama.api.OllamaApi; +import org.springframework.ai.ollama.api.OllamaApiOptions; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.ImportRuntimeHints; + +/** + * {@link AutoConfiguration Auto-configuration} for Ollama Chat Client. + * + * @author Christian Tzolov + * @since 0.8.0 + */ +@AutoConfiguration +@ConditionalOnClass(OllamaApi.class) +@EnableConfigurationProperties({ OllamaChatProperties.class }) +@ConditionalOnProperty(prefix = OllamaChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", + matchIfMissing = true) +@ImportRuntimeHints(NativeHints.class) +public class OllamaChatAutoConfiguration { + + @Bean + @ConditionalOnMissingBean + public OllamaApi ollamaApi(OllamaChatProperties properties) { + return new OllamaApi(properties.getBaseUrl()); + } + + @Bean + public OllamaChatClient ollamaChatClient(OllamaApi ollamaApi, OllamaChatProperties properties) { + + var optionsBuilder = OllamaApiOptions.Options.builder(); + + if (properties.getTemperature() != null) { + optionsBuilder.withTemperature(properties.getTemperature()); + } + if (properties.getTopK() != null) { + optionsBuilder.withTopK(properties.getTopK()); + } + if (properties.getTopP() != null) { + optionsBuilder.withTopP(properties.getTopP()); + } + + var options = optionsBuilder.build().toMap(); + + if (properties.getOptions() != null) { + options.putAll(properties.getOptions()); + } + + return new OllamaChatClient(ollamaApi).withModel(properties.getModel()).withOptions(options); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaChatProperties.java new file mode 100644 index 00000000000..9be5f861eb8 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaChatProperties.java @@ -0,0 +1,125 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.autoconfigure.ollama;
+
+import java.util.Map;
+
+import org.springframework.boot.context.properties.ConfigurationProperties;
+
+/**
+ * Ollama Chat autoconfiguration properties.
+ *
+ * @author Christian Tzolov
+ * @since 0.8.0
+ */
+@ConfigurationProperties(OllamaChatProperties.CONFIG_PREFIX)
+public class OllamaChatProperties {
+
+	public static final String CONFIG_PREFIX = "spring.ai.ollama.chat";
+
+	/**
+	 * Base URL where Ollama API server is running.
+	 */
+	private String baseUrl = "http://localhost:11434";
+
+	/**
+	 * Enable Ollama Chat Client. True by default.
+	 */
+	private boolean enabled = true;
+
+	/**
+	 * Ollama Chat model name. Defaults to 'llama2'.
+	 */
+	private String model = "llama2";
+
+	/**
+	 * (optional) Use a lower value to decrease randomness in the response. Defaults to
+	 * 0.8.
+	 */
+	private Float temperature = 0.8f;
+
+	/**
+	 * (optional) The maximum cumulative probability of tokens to consider when sampling.
+	 * The model uses combined Top-k and nucleus sampling. Nucleus sampling considers the
+	 * smallest set of tokens whose probability sum is at least topP.
+	 */
+	private Float topP;
+
+	/**
+	 * (optional) Top-k sampling: only the k most likely next tokens are considered.
+	 */
+	private Integer topK;
+
+	private Map<String, String> options;
+
+	public String getBaseUrl() {
+		return baseUrl;
+	}
+
+	public void setBaseUrl(String baseUrl) {
+		this.baseUrl = baseUrl;
+	}
+
+	public boolean isEnabled() {
+		return enabled;
+	}
+
+	public void setEnabled(boolean enabled) {
+		this.enabled = enabled;
+	}
+
+	public String getModel() {
+		return model;
+	}
+
+	public void setModel(String model) {
+		this.model = model;
+	}
+
+	public Float getTemperature() {
+		return temperature;
+	}
+
+	public void setTemperature(Float temperature) {
+		this.temperature = temperature;
+	}
+
+	public Float getTopP() {
+		return topP;
+	}
+
+	public void setTopP(Float topP) {
+		this.topP = topP;
+	}
+
+	public Integer getTopK() {
+		return topK;
+	}
+
+	public void setTopK(Integer topK) {
+		this.topK = topK;
+	}
+
+	public void setOptions(Map<String, String> options) {
+		this.options = options;
+	}
+
+	public Map<String, String> getOptions() {
+		return options;
+	}
+
+}
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfiguration.java
new file mode 100644
index 00000000000..a311a230015
--- /dev/null
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfiguration.java
@@ -0,0 +1,58 @@
+/*
+ * Copyright 2023-2023 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.ollama; + +import org.springframework.ai.autoconfigure.NativeHints; +import org.springframework.ai.ollama.OllamaEmbeddingClient; +import org.springframework.ai.ollama.api.OllamaApi; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.ImportRuntimeHints; + +/** + * {@link AutoConfiguration Auto-configuration} for Ollama Embedding Client. + * + * @author Christian Tzolov + * @since 0.8.0 + */ +@AutoConfiguration +@ConditionalOnClass(OllamaApi.class) +@EnableConfigurationProperties({ OllamaEmbeddingProperties.class }) +@ConditionalOnProperty(prefix = OllamaEmbeddingProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", + matchIfMissing = true) +@ImportRuntimeHints(NativeHints.class) +public class OllamaEmbeddingAutoConfiguration { + + @Bean + @ConditionalOnMissingBean + public OllamaApi ollamaApi(OllamaEmbeddingProperties properties) { + return new OllamaApi(properties.getBaseUrl()); + } + + @Bean + @ConditionalOnMissingBean + public OllamaEmbeddingClient ollamaEmbeddingClient(OllamaApi ollamaApi, OllamaEmbeddingProperties properties) { + + return new OllamaEmbeddingClient(ollamaApi).withModel(properties.getModel()) + .withOptions(properties.getOptions()); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingProperties.java similarity index 59% rename from spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaProperties.java rename to spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingProperties.java index c72fdb5b8ed..ea64f41bdd9 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingProperties.java @@ -16,12 +16,25 @@ package org.springframework.ai.autoconfigure.ollama; +import java.util.Map; + import org.springframework.boot.context.properties.ConfigurationProperties; -@ConfigurationProperties(OllamaProperties.CONFIG_PREFIX) -public class OllamaProperties { +/** + * Ollama Embedding autoconfiguration properties. 
+ *
+ * @author Christian Tzolov
+ * @since 0.8.0
+ */
+@ConfigurationProperties(OllamaEmbeddingProperties.CONFIG_PREFIX)
+public class OllamaEmbeddingProperties {
+
+	public static final String CONFIG_PREFIX = "spring.ai.ollama.embedding";
 
-	public static final String CONFIG_PREFIX = "spring.ai.ollama";
+	/**
+	 * Enable Ollama Embedding Client. True by default.
+	 */
+	private boolean enabled = true;
 
 	/**
 	 * Base URL where Ollama API server is running.
@@ -29,16 +42,18 @@ public class OllamaProperties {
 	private String baseUrl = "http://localhost:11434";
 
 	/**
-	 * Language model to use.
+	 * Ollama Embedding model name. Defaults to 'llama2'.
 	 */
 	private String model = "llama2";
 
-	public String getBaseUrl() {
-		return baseUrl;
+	private Map<String, String> options;
+
+	public boolean isEnabled() {
+		return enabled;
 	}
 
-	public void setBaseUrl(String baseUrl) {
-		this.baseUrl = baseUrl;
+	public void setEnabled(boolean enabled) {
+		this.enabled = enabled;
 	}
 
 	public String getModel() {
@@ -49,4 +64,20 @@ public void setModel(String model) {
 		this.model = model;
 	}
 
+	public String getBaseUrl() {
+		return baseUrl;
+	}
+
+	public void setBaseUrl(String baseUrl) {
+		this.baseUrl = baseUrl;
+	}
+
+	public void setOptions(Map<String, String> options) {
+		this.options = options;
+	}
+
+	public Map<String, String> getOptions() {
+		return options;
+	}
+
 }
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports
index a67fc127a84..b19a1365ad7 100644
--- a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports
+++ b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports
@@ -1,4 +1,3 @@
-org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration
 org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration
 org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration
 org.springframework.ai.autoconfigure.vectorstore.pgvector.PgVectorStoreAutoConfiguration
@@ -17,6 +16,9 @@ org.springframework.ai.autoconfigure.bedrock.cohere.BedrockCohereEmbeddingAutoConfiguration
 org.springframework.ai.autoconfigure.bedrock.anthropic.BedrockAnthropicChatAutoConfiguration
 org.springframework.ai.autoconfigure.bedrock.titan.BedrockTitanChatAutoConfiguration
 org.springframework.ai.autoconfigure.bedrock.titan.BedrockTitanEmbeddingAutoConfiguration
+org.springframework.ai.autoconfigure.ollama.OllamaChatAutoConfiguration
+org.springframework.ai.autoconfigure.ollama.OllamaEmbeddingAutoConfiguration
+
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfigurationTests.java
deleted file mode 100644
index 73f42221eae..00000000000
--- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfigurationTests.java
+++ /dev/null
@@ -1,81 +0,0 @@
-/*
- * Copyright 2023-2023 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.autoconfigure.ollama; - -import org.junit.jupiter.api.Test; - -import org.springframework.ai.ollama.client.OllamaClient; -import org.springframework.boot.autoconfigure.AutoConfigurations; -import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Configuration; - -import static org.assertj.core.api.Assertions.assertThat; - -public class OllamaAutoConfigurationTests { - - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(OllamaAutoConfiguration.class)); - - @Test - void defaults() { - contextRunner.run(context -> { - OllamaProperties properties = context.getBean(OllamaProperties.class); - assertThat(properties.getBaseUrl()).isEqualTo("http://localhost:11434"); - assertThat(properties.getModel()).isEqualTo("llama2"); - - OllamaClient client = context.getBean(OllamaClient.class); - assertThat(client.getBaseUrl()).isEqualTo("http://localhost:11434"); - assertThat(client.getModel()).isEqualTo("llama2"); - }); - } - - @Test - void overrideProperties() { - contextRunner - .withPropertyValues("spring.ai.ollama.base-url=http://localhost:8080", "spring.ai.ollama.model=myModel") - .run(context -> { - OllamaProperties properties = context.getBean(OllamaProperties.class); - assertThat(properties.getBaseUrl()).isEqualTo("http://localhost:8080"); - assertThat(properties.getModel()).isEqualTo("myModel"); - - OllamaClient client = context.getBean(OllamaClient.class); - assertThat(client.getBaseUrl()).isEqualTo("http://localhost:8080"); - assertThat(client.getModel()).isEqualTo("myModel"); - }); - } - - @Test - void customConfig() { - contextRunner.withUserConfiguration(CustomConfig.class).run(context -> { - OllamaClient ollamaClient = context.getBean(OllamaClient.class); - assertThat(ollamaClient.getBaseUrl()).isEqualTo("http://localhost:8080"); - assertThat(ollamaClient.getModel()).isEqualTo("myModel"); - }); - } - - @Configuration(proxyBeanMethods = false) - static class CustomConfig { - - @Bean - OllamaClient myClient() { - return new OllamaClient("http://localhost:8080", "myModel"); - } - - } - -} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationIT.java new file mode 100644 index 00000000000..f4387f1854f --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationIT.java @@ -0,0 +1,118 @@ +/* + * Copyright 2023-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.autoconfigure.ollama;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+import org.testcontainers.containers.GenericContainer;
+import org.testcontainers.junit.jupiter.Container;
+import org.testcontainers.junit.jupiter.Testcontainers;
+import reactor.core.publisher.Flux;
+
+import org.springframework.ai.chat.ChatResponse;
+import org.springframework.ai.chat.Generation;
+import org.springframework.ai.ollama.OllamaChatClient;
+import org.springframework.ai.prompt.Prompt;
+import org.springframework.ai.prompt.SystemPromptTemplate;
+import org.springframework.ai.prompt.messages.Message;
+import org.springframework.ai.prompt.messages.UserMessage;
+import org.springframework.boot.autoconfigure.AutoConfigurations;
+import org.springframework.boot.test.context.runner.ApplicationContextRunner;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * @author Christian Tzolov
+ * @since 0.8.0
+ */
+@Disabled("For manual smoke testing only.")
+@Testcontainers
+public class OllamaChatAutoConfigurationIT {
+
+	private static final Log logger = LogFactory.getLog(OllamaChatAutoConfigurationIT.class);
+
+	private static String MODEL_NAME = "orca-mini";
+
+	@Container
+	static GenericContainer ollamaContainer = new GenericContainer<>("ollama/ollama:0.1.16").withExposedPorts(11434);
+
+	static String baseUrl;
+
+	@BeforeAll
+	public static void beforeAll() throws IOException, InterruptedException {
+		logger.info("Start pulling the '" + MODEL_NAME + "' model ... this may take several minutes ...");
+		ollamaContainer.execInContainer("ollama", "pull", MODEL_NAME);
+		logger.info(MODEL_NAME + " pulling completed!");
+
+		baseUrl = "http://" + ollamaContainer.getHost() + ":" + ollamaContainer.getMappedPort(11434);
+	}
+
+	private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
+		.withPropertyValues("spring.ai.ollama.chat.enabled=true", "spring.ai.ollama.chat.model=" + MODEL_NAME,
+				"spring.ai.ollama.chat.baseUrl=" + baseUrl, "spring.ai.ollama.chat.temperature=0.5",
+				"spring.ai.ollama.chat.topK=500")
+		.withConfiguration(AutoConfigurations.of(OllamaChatAutoConfiguration.class));
+
+	private final Message systemMessage = new SystemPromptTemplate("""
+			You are a helpful AI assistant. Your name is {name}.
+			You are an AI assistant that helps people find information.
+			Your name is {name}
+			You should reply to the user's request with your name and also in the style of a {voice}.
+			""").createMessage(Map.of("name", "Bob", "voice", "pirate"));
+
+	private final UserMessage userMessage = new UserMessage(
+			"Tell me about 3 famous pirates from the Golden Age of Piracy and what they did.");
+
+	@Test
+	public void chatCompletion() {
+		contextRunner.run(context -> {
+			OllamaChatClient chatClient = context.getBean(OllamaChatClient.class);
+			ChatResponse response = chatClient.generate(new Prompt(List.of(userMessage, systemMessage)));
+			assertThat(response.getGeneration().getContent()).contains("Blackbeard");
+		});
+	}
+
+	@Test
+	public void chatCompletionStreaming() {
+		contextRunner.run(context -> {
+
+			OllamaChatClient chatClient = context.getBean(OllamaChatClient.class);
+
+			Flux<ChatResponse> response = chatClient.generateStream(new Prompt(List.of(userMessage, systemMessage)));
+
+			List<ChatResponse> responses = response.collectList().block();
+			assertThat(responses.size()).isGreaterThan(1);
+
+			String stitchedResponseContent = responses.stream()
+				.map(ChatResponse::getGenerations)
+				.flatMap(List::stream)
+				.map(Generation::getContent)
+				.collect(Collectors.joining());
+
+			assertThat(stitchedResponseContent).contains("Blackbeard");
+		});
+	}
+
+}
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationTests.java
new file mode 100644
index 00000000000..0dc856d7542
--- /dev/null
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationTests.java
@@ -0,0 +1,82 @@
+/*
+ * Copyright 2023-2023 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.springframework.ai.autoconfigure.ollama; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.ollama.OllamaChatClient; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Christian Tzolov + * @since 0.8.0 + */ +public class OllamaChatAutoConfigurationTests { + + @Test + public void propertiesTest() { + + new ApplicationContextRunner() + .withPropertyValues("spring.ai.ollama.chat.enabled=true", "spring.ai.ollama.chat.model=MODEL_XYZ", + "spring.ai.ollama.chat.temperature=0.55", "spring.ai.ollama.chat.topP=0.55", + "spring.ai.ollama.chat.topK=123") + .withConfiguration(AutoConfigurations.of(OllamaChatAutoConfiguration.class)) + .run(context -> { + var chatProperties = context.getBean(OllamaChatProperties.class); + + assertThat(chatProperties.isEnabled()).isTrue(); + assertThat(chatProperties.getModel()).isEqualTo("MODEL_XYZ"); + + assertThat(chatProperties.getTemperature()).isEqualTo(0.55f); + assertThat(chatProperties.getTopP()).isEqualTo(0.55f); + + assertThat(chatProperties.getTopK()).isEqualTo(123); + }); + } + + @Test + public void enablingDisablingTest() { + + // It is enabled by default + new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(OllamaChatAutoConfiguration.class)) + .run(context -> { + assertThat(context.getBeansOfType(OllamaChatProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(OllamaChatClient.class)).isNotEmpty(); + }); + + // Explicitly enable the chat auto-configuration. + new ApplicationContextRunner().withPropertyValues("spring.ai.ollama.chat.enabled=true") + .withConfiguration(AutoConfigurations.of(OllamaChatAutoConfiguration.class)) + .run(context -> { + assertThat(context.getBeansOfType(OllamaChatProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(OllamaChatClient.class)).isNotEmpty(); + }); + + // Explicitly disable the chat auto-configuration. 
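+		// When disabled, the auto-configuration backs off entirely: neither the properties bean nor the OllamaChatClient bean is registered.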
+ new ApplicationContextRunner().withPropertyValues("spring.ai.ollama.chat.enabled=false") + .withConfiguration(AutoConfigurations.of(OllamaChatAutoConfiguration.class)) + .run(context -> { + assertThat(context.getBeansOfType(OllamaChatProperties.class)).isEmpty(); + assertThat(context.getBeansOfType(OllamaChatClient.class)).isEmpty(); + }); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationIT.java similarity index 50% rename from spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfigurationIT.java rename to spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationIT.java index 75d7fa9f93c..3731ac2dc8e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationIT.java @@ -17,6 +17,7 @@ package org.springframework.ai.autoconfigure.ollama; import java.io.IOException; +import java.util.List; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -27,45 +28,53 @@ import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; -import org.springframework.ai.ollama.client.OllamaClient; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.ollama.OllamaEmbeddingClient; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import static org.assertj.core.api.Assertions.assertThat; -@Disabled("As it downloads the 3GB 'orca-mini' it can take couple of minutes to initialize.") +/** + * @author Christian Tzolov + * @since 0.8.0 + */ +@Disabled("For manual smoke testing only.") @Testcontainers -public class OllamaAutoConfigurationIT { +public class OllamaEmbeddingAutoConfigurationIT { + + private static final Log logger = LogFactory.getLog(OllamaEmbeddingAutoConfigurationIT.class); - private static final Log logger = LogFactory.getLog(OllamaAutoConfigurationIT.class); + private static String MODEL_NAME = "orca-mini"; @Container - static GenericContainer ollamaContainer = new GenericContainer<>("ollama/ollama:0.1.10").withExposedPorts(11434); + static GenericContainer ollamaContainer = new GenericContainer<>("ollama/ollama:0.1.16").withExposedPorts(11434); - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withPropertyValues("spring.ai.ollama.baseUrl=http://" + ollamaContainer.getHost() + ":" - + ollamaContainer.getMappedPort(11434), "spring.ai.ollama.model=orca-mini") - .withConfiguration(AutoConfigurations.of(OllamaAutoConfiguration.class)); + static String baseUrl; @BeforeAll public static void beforeAll() throws IOException, InterruptedException { - logger.info("Start pulling the 'orca-mini' model (3GB) ... would take several minutes ..."); - ollamaContainer.execInContainer("ollama", "pull", "orca-mini"); - logger.info("orca-mini pulling competed!"); + logger.info("Start pulling the '" + MODEL_NAME + " ' model ... 
would take several minutes ..."); + ollamaContainer.execInContainer("ollama", "pull", MODEL_NAME); + logger.info(MODEL_NAME + " pulling competed!"); + + baseUrl = "http://" + ollamaContainer.getHost() + ":" + ollamaContainer.getMappedPort(11434); } + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.ollama.embedding.enabled=true", "spring.ai.ollama.embedding.model=" + MODEL_NAME, + "spring.ai.ollama.embedding.base-url=" + baseUrl) + .withConfiguration(AutoConfigurations.of(OllamaEmbeddingAutoConfiguration.class)); + @Test - void generate() { + public void singleTextEmbedding() { contextRunner.run(context -> { - OllamaClient client = context.getBean(OllamaClient.class); - assertThat(client.getBaseUrl()) - .isEqualTo("http://" + ollamaContainer.getHost() + ":" + ollamaContainer.getMappedPort(11434)); - assertThat(client.getModel()).isEqualTo("orca-mini"); - - String response = client.generate("Hello"); - - assertThat(response).isNotEmpty(); - logger.info("Response: " + response); + OllamaEmbeddingClient embeddingClient = context.getBean(OllamaEmbeddingClient.class); + assertThat(embeddingClient).isNotNull(); + EmbeddingResponse embeddingResponse = embeddingClient.embedForResponse(List.of("Hello World")); + assertThat(embeddingResponse.getData()).hasSize(1); + assertThat(embeddingResponse.getData().get(0).getEmbedding()).isNotEmpty(); + assertThat(embeddingClient.dimensions()).isEqualTo(3200); }); } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationTests.java new file mode 100644 index 00000000000..e4c3d0295a1 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationTests.java @@ -0,0 +1,80 @@ +/* + * Copyright 2023-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.springframework.ai.autoconfigure.ollama;
+
+import org.junit.jupiter.api.Test;
+
+import org.springframework.ai.ollama.OllamaEmbeddingClient;
+import org.springframework.boot.autoconfigure.AutoConfigurations;
+import org.springframework.boot.test.context.runner.ApplicationContextRunner;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * @author Christian Tzolov
+ * @since 0.8.0
+ */
+public class OllamaEmbeddingAutoConfigurationTests {
+
+	@Test
+	public void propertiesTest() {
+
+		new ApplicationContextRunner().withPropertyValues("spring.ai.ollama.embedding.enabled=true",
+				"spring.ai.ollama.embedding.base-url=TEST_BASE_URL", "spring.ai.ollama.embedding.model=MODEL_XYZ",
+				"spring.ai.ollama.embedding.options.temperature=0.13" // TODO: fix the float parsing
+		).withConfiguration(AutoConfigurations.of(OllamaEmbeddingAutoConfiguration.class)).run(context -> {
+			var properties = context.getBean(OllamaEmbeddingProperties.class);
+
+			assertThat(properties.isEnabled()).isTrue();
+			assertThat(properties.getModel()).isEqualTo("MODEL_XYZ");
+			assertThat(properties.getBaseUrl()).isEqualTo("TEST_BASE_URL");
+			assertThat(properties.getOptions()).containsKeys("temperature");
+			assertThat(properties.getOptions().get("temperature")).isEqualTo("0.13");
+
+		});
+	}
+
+	@Test
+	public void enablingDisablingTest() {
+
+		// It is enabled by default
+		new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(OllamaEmbeddingAutoConfiguration.class))
+			.run(context -> {
+				assertThat(context.getBeansOfType(OllamaEmbeddingProperties.class)).isNotEmpty();
+				assertThat(context.getBeansOfType(OllamaEmbeddingClient.class)).isNotEmpty();
+			});
+
+		// Explicitly enable the embedding auto-configuration.
+		new ApplicationContextRunner().withPropertyValues("spring.ai.ollama.embedding.enabled=true")
+			.withConfiguration(AutoConfigurations.of(OllamaEmbeddingAutoConfiguration.class))
+			.run(context -> {
+				assertThat(context.getBeansOfType(OllamaEmbeddingProperties.class)).isNotEmpty();
+				assertThat(context.getBeansOfType(OllamaEmbeddingClient.class)).isNotEmpty();
+			});
+
+		// Explicitly disable the embedding auto-configuration.
+		new ApplicationContextRunner().withPropertyValues("spring.ai.ollama.embedding.enabled=false")
+			.withConfiguration(AutoConfigurations.of(OllamaEmbeddingAutoConfiguration.class))
+			.run(context -> {
+				assertThat(context.getBeansOfType(OllamaEmbeddingProperties.class)).isEmpty();
+				assertThat(context.getBeansOfType(OllamaEmbeddingClient.class)).isEmpty();
+			});
+	}
+
+}