diff --git a/README.md b/README.md
index eff3dcf5fc8..467415dcaf2 100644
--- a/README.md
+++ b/README.md
@@ -13,6 +13,15 @@ Let's make your `@Beans` intelligent!
:warning:
### Breaking Changes
+
+December 20, 2023 Update
+
+Refactor the Ollama client and related classes and package names
+
+ - Replace the org.springframework.ai.ollama.client.OllamaClient with the org.springframework.ai.ollama.OllamaChatClient.
+ - The OllamaChatClient method signatures have changed.
+ - Rename the org.springframework.ai.autoconfigure.ollama.OllamaProperties to org.springframework.ai.autoconfigure.ollama.OllamaChatProperties and change the property prefix to `spring.ai.ollama.chat`. Some of the properties have changed as well.
+
December 19, 2023 Update
Renaming of AiClient and related classes and packagenames
@@ -22,7 +31,7 @@ Renaming of AiClient and related classes and packagenames
* Rename AiStreamClient to StreamingChatClient
* Rename package org.sf.ai.client to org.sf.ai.chat
-Rename artifact ID of
+Rename artifact ID of
* `transformers-embedding` to `spring-ai-transformers`
@@ -144,7 +153,7 @@ Following vector stores are supported:
```xml
-	<groupId>org.springframework.ai</groupId>
+	<groupId>org.springframework.ai</groupId>
 	<artifactId>spring-ai-azure-vector-store-spring-boot-starter</artifactId>
 	<version>0.8.0-SNAPSHOT</version>
@@ -154,7 +163,7 @@ Following vector stores are supported:
```xml
-	<groupId>org.springframework.ai</groupId>
+	<groupId>org.springframework.ai</groupId>
 	<artifactId>spring-ai-chroma-store-spring-boot-starter</artifactId>
 	<version>0.8.0-SNAPSHOT</version>
@@ -163,7 +172,7 @@ Following vector stores are supported:
```xml
-	<groupId>org.springframework.ai</groupId>
+	<groupId>org.springframework.ai</groupId>
 	<artifactId>spring-ai-milvus-store-spring-boot-starter</artifactId>
 	<version>0.8.0-SNAPSHOT</version>
@@ -173,7 +182,7 @@ Following vector stores are supported:
```xml
-	<groupId>org.springframework.ai</groupId>
+	<groupId>org.springframework.ai</groupId>
 	<artifactId>spring-ai-pgvector-store-spring-boot-starter</artifactId>
 	<version>0.8.0-SNAPSHOT</version>
@@ -182,7 +191,7 @@ Following vector stores are supported:
```xml
-	<groupId>org.springframework.ai</groupId>
+	<groupId>org.springframework.ai</groupId>
 	<artifactId>spring-ai-pinecone-store-spring-boot-starter</artifactId>
 	<version>0.8.0-SNAPSHOT</version>
@@ -191,7 +200,7 @@ Following vector stores are supported:
```xml
-	<groupId>org.springframework.ai</groupId>
+	<groupId>org.springframework.ai</groupId>
 	<artifactId>spring-ai-weaviate-store-spring-boot-starter</artifactId>
 	<version>0.8.0-SNAPSHOT</version>
@@ -200,7 +209,7 @@ Following vector stores are supported:
```xml
-	<groupId>org.springframework.ai</groupId>
+	<groupId>org.springframework.ai</groupId>
 	<artifactId>spring-ai-neo4j-store-spring-boot-starter</artifactId>
 	<version>0.8.0-SNAPSHOT</version>
diff --git a/models/spring-ai-ollama/README.md b/models/spring-ai-ollama/README.md
new file mode 100644
index 00000000000..d87417ed6b0
--- /dev/null
+++ b/models/spring-ai-ollama/README.md
@@ -0,0 +1,120 @@
+# 1. Ollama Chat and Embedding
+
+## 1.1 OllamaApi
+
+[OllamaApi](./src/main/java/org/springframework/ai/ollama/api/OllamaApi.java) is a lightweight Java client for [Ollama models](https://ollama.ai/).
+
+The OllamaApi provides both the chat completion and the embedding endpoints.
+
+The following class diagram illustrates the OllamaApi interface and the building blocks for chat completion:
+
+
+
+The OllamaApi supports all [Ollama Models](https://ollama.ai/library), providing synchronous chat completion, streaming chat completion, and embeddings:
+
+```java
+ChatResponse chat(ChatRequest chatRequest)
+
+Flux<ChatResponse> streamingChat(ChatRequest chatRequest)
+
+EmbeddingResponse embeddings(EmbeddingRequest embeddingRequest)
+```
+
+> NOTE: OllamaApi also exposes the Ollama `generate` endpoint, but it is inferior compared to the Ollama `chat` endpoint.
+
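+If you nevertheless need the `generate` endpoint, here is a minimal sketch using the `GenerateRequest` builder and the `generate` method from this module (the model name and prompt are illustrative):
+
+```java
+// Synchronous, non-streaming call to the /api/generate endpoint.
+GenerateRequest generateRequest = GenerateRequest.builder("Why is the sky blue?")
+    .withModel("orca-mini")
+    .withStream(false)
+    .build();
+
+GenerateResponse generateResponse = ollamaApi.generate(generateRequest);
+```
+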
+The `OllamaApiOptions` is a helper class used as a type-safe options builder. It provides a `toMap()` method to convert the options into a `Map<String, Object>`.
+
+Here are a few simple snippets showing how to use the OllamaApi programmatically:
+
+```java
+var request = ChatRequest.builder("orca-mini")
+ .withStream(false)
+ .withMessages(List.of(Message.builder(Role.user)
+ .withContent("What is the capital of Bulgaria and what is the size? " + "What is the national anthem?")
+ .build()))
+ .withOptions(Options.builder().withTemperature(0.9f).build())
+ .build();
+
+ChatResponse response = ollamaApi.chat(request);
+```
+
+```java
+var request = ChatRequest.builder("orca-mini")
+ .withStream(true)
+ .withMessages(List.of(Message.builder(Role.user)
+ .withContent("What is the capital of Bulgaria and what is the size? " + "What is the national anthem?")
+ .build()))
+ .withOptions(Options.builder().withTemperature(0.9f).build().toMap())
+ .build();
+
+Flux<ChatResponse> response = ollamaApi.streamingChat(request);
+
+List<ChatResponse> responses = response.collectList().block();
+```
+
+```java
+EmbeddingRequest request = new EmbeddingRequest("orca-mini", "I like to eat apples");
+
+EmbeddingResponse response = ollamaApi.embeddings(request);
+```
+
+## 1.2 OllamaChatClient and OllamaEmbeddingClient
+
+The [OllamaChatClient](./src/main/java/org/springframework/ai/ollama/OllamaChatClient.java) implements the Spring AI `ChatClient` and `StreamingChatClient` interfaces.
+
+The [OllamaEmbeddingClient](./src/main/java/org/springframework/ai/ollama/OllamaEmbeddingClient.java) implements the Spring AI `EmbeddingClient` interface.
+
+Both the OllamaChatClient and the OllamaEmbeddingClient leverage the `OllamaApi`.
+
+You can configure the clients like this:
+
+```java
+@Bean
+public OllamaApi ollamaApi() {
+ return new OllamaApi(baseUrl);
+}
+
+@Bean
+public OllamaChatClient ollamaChat(OllamaApi ollamaApi) {
+ return new OllamaChatClient(ollamaApi).withModel(MODEL)
+ .withOptions(OllamaApiOptions.Options.builder().withTemperature(0.9f).build());
+}
+
+@Bean
+public OllamaEmbeddingClient ollamaEmbedding(OllamaApi ollamaApi) {
+ return new OllamaEmbeddingClient(ollamaApi).withModel("orca-mini");
+}
+
+```
+
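+With these beans in place, a minimal usage sketch could look as follows (assuming the beans above are injected as `ollamaChat` and `ollamaEmbedding`; the prompt text is illustrative and the single-argument `Prompt` constructor from spring-ai-core is assumed):
+
+```java
+// Chat: blocking call that returns a single ChatResponse.
+ChatResponse chatResponse = ollamaChat.generate(new Prompt("Tell me a joke"));
+
+// Embedding: returns the embedding vector for the given text.
+List<Double> embedding = ollamaEmbedding.embed("I like to eat apples");
+```
+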
+Alternatively, you can leverage the `spring-ai-ollama-spring-boot-starter` Spring Boot starter.
+For this, add the following dependency:
+
+```xml
+<dependency>
+    <artifactId>spring-ai-ollama-spring-boot-starter</artifactId>
+    <groupId>org.springframework.ai</groupId>
+    <version>0.8.0-SNAPSHOT</version>
+</dependency>
+```
+
+Use the `OllamaChatProperties` to configure the Ollama Chat client:
+
+| Property | Description | Default |
+| ------------- | ------------- | ------------- |
+| spring.ai.ollama.chat.model | Model to use. | llama2 |
+| spring.ai.ollama.chat.base-url | The base url of the Ollama server. | http://localhost:11434 |
+| spring.ai.ollama.chat.enabled | Allows you to disable the Ollama Chat autoconfiguration. | true |
+| spring.ai.ollama.chat.temperature | Controls the randomness of the output. Values can range over [0.0,1.0] | 0.8 |
+| spring.ai.ollama.chat.topP | The maximum cumulative probability of tokens to consider when sampling. | - |
+| spring.ai.ollama.chat.topK | Limits token sampling to the k highest-probability tokens. A higher value (e.g. 100) gives more diverse answers, while a lower value (e.g. 10) is more conservative. | - |
+| spring.ai.ollama.chat.options (WIP) | A Map used to configure the Chat client. | - |
+
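+For example, a minimal `application.properties` sketch based on the table above (the values are illustrative):
+
+```properties
+spring.ai.ollama.chat.base-url=http://localhost:11434
+spring.ai.ollama.chat.model=orca-mini
+spring.ai.ollama.chat.temperature=0.9
+```
+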
+Use the `OllamaEmbeddingProperties` to configure the Ollama Embedding client:
+
+| Property | Description | Default |
+| ------------- | ------------- | ------------- |
+| spring.ai.ollama.embedding.model | Model to use. | llama2 |
+| spring.ai.ollama.embedding.base-url | The base url of the Ollama server. | http://localhost:11434 |
+| spring.ai.ollama.embedding.enabled | Allows you to disable the Ollama embedding autoconfiguration. | true |
+| spring.ai.ollama.embedding.options (WIP) | A Map used to configure the embedding client. | - |
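+
+Similarly, an illustrative `application.properties` sketch for the embedding client:
+
+```properties
+spring.ai.ollama.embedding.base-url=http://localhost:11434
+spring.ai.ollama.embedding.model=orca-mini
+```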
diff --git a/models/spring-ai-ollama/pom.xml b/models/spring-ai-ollama/pom.xml
index fead80fc019..0da9e9a5353 100644
--- a/models/spring-ai-ollama/pom.xml
+++ b/models/spring-ai-ollama/pom.xml
@@ -1,7 +1,6 @@
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
4.0.0
org.springframework.ai
@@ -28,6 +27,23 @@
${project.parent.version}
+		<dependency>
+			<groupId>org.springframework</groupId>
+			<artifactId>spring-webflux</artifactId>
+		</dependency>
+
+		<dependency>
+			<groupId>com.fasterxml.jackson.core</groupId>
+			<artifactId>jackson-databind</artifactId>
+			<version>${jackson.version}</version>
+		</dependency>
+
+		<dependency>
+			<groupId>com.fasterxml.jackson.datatype</groupId>
+			<artifactId>jackson-datatype-jsr310</artifactId>
+			<version>${jackson.version}</version>
+		</dependency>
+
 			<groupId>org.springframework.boot</groupId>
 			<artifactId>spring-boot-starter-logging</artifactId>
@@ -39,5 +55,17 @@
 			<artifactId>spring-boot-starter-test</artifactId>
 			<scope>test</scope>
+
+		<dependency>
+			<groupId>org.springframework.boot</groupId>
+			<artifactId>spring-boot-testcontainers</artifactId>
+			<scope>test</scope>
+		</dependency>
+
+		<dependency>
+			<groupId>org.testcontainers</groupId>
+			<artifactId>junit-jupiter</artifactId>
+			<scope>test</scope>
+		</dependency>
\ No newline at end of file
diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatClient.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatClient.java
new file mode 100644
index 00000000000..517734205f5
--- /dev/null
+++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatClient.java
@@ -0,0 +1,140 @@
+/*
+ * Copyright 2023-2023 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.ollama;
+
+import java.util.List;
+import java.util.Map;
+
+import reactor.core.publisher.Flux;
+
+import org.springframework.ai.chat.ChatClient;
+import org.springframework.ai.chat.ChatResponse;
+import org.springframework.ai.chat.Generation;
+import org.springframework.ai.chat.StreamingChatClient;
+import org.springframework.ai.metadata.ChoiceMetadata;
+import org.springframework.ai.metadata.Usage;
+import org.springframework.ai.ollama.api.OllamaApi;
+import org.springframework.ai.ollama.api.OllamaApi.ChatRequest;
+import org.springframework.ai.ollama.api.OllamaApi.Message.Role;
+import org.springframework.ai.ollama.api.OllamaApiOptions;
+import org.springframework.ai.prompt.Prompt;
+import org.springframework.ai.prompt.messages.Message;
+import org.springframework.ai.prompt.messages.MessageType;
+
+/**
+ * @author Christian Tzolov
+ * @since 0.8.0
+ */
+public class OllamaChatClient implements ChatClient, StreamingChatClient {
+
+ private final OllamaApi chatApi;
+
+ private String model = "orca-mini";
+
+ private Map<String, Object> clientOptions;
+
+ public OllamaChatClient(OllamaApi chatApi) {
+ this.chatApi = chatApi;
+ }
+
+ public OllamaChatClient withModel(String model) {
+ this.model = model;
+ return this;
+ }
+
+ public OllamaChatClient withOptions(Map<String, Object> options) {
+ this.clientOptions = options;
+ return this;
+ }
+
+ public OllamaChatClient withOptions(OllamaApiOptions.Options options) {
+ this.clientOptions = options.toMap();
+ return this;
+ }
+
+ @Override
+ public ChatResponse generate(Prompt prompt) {
+
+ OllamaApi.ChatResponse response = this.chatApi.chat(request(prompt, this.model, false));
+ var generator = new Generation(response.message().content());
+ if (response.promptEvalCount() != null && response.evalCount() != null) {
+ generator = generator.withChoiceMetadata(ChoiceMetadata.from("unknown", extractUsage(response)));
+ }
+ return new ChatResponse(List.of(generator));
+ }
+
+ @Override
+ public Flux<ChatResponse> generateStream(Prompt prompt) {
+
+ Flux<OllamaApi.ChatResponse> response = this.chatApi.streamingChat(request(prompt, this.model, true));
+
+ return response.map(chunk -> {
+ Generation generation = (chunk.message() != null) ? new Generation(chunk.message().content())
+ : new Generation("");
+ if (chunk.done()) {
+ generation = generation.withChoiceMetadata(ChoiceMetadata.from("unknown", extractUsage(chunk)));
+ }
+ return new ChatResponse(List.of(generation));
+ });
+ }
+
+ private Usage extractUsage(OllamaApi.ChatResponse response) {
+ return new Usage() {
+
+ @Override
+ public Long getPromptTokens() {
+ return response.promptEvalCount().longValue();
+ }
+
+ @Override
+ public Long getGenerationTokens() {
+ return response.evalCount().longValue();
+ }
+ };
+ }
+
+ private OllamaApi.ChatRequest request(Prompt prompt, String model, boolean stream) {
+
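+ // Note: only user and assistant messages are forwarded to Ollama; system messages are currently filtered out.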
+ List<OllamaApi.Message> ollamaMessages = prompt.getMessages()
+ .stream()
+ .filter(message -> message.getMessageType() == MessageType.USER
+ || message.getMessageType() == MessageType.ASSISTANT)
+ .map(m -> OllamaApi.Message.builder(toRole(m)).withContent(m.getContent()).build())
+ .toList();
+
+ return ChatRequest.builder(model)
+ .withStream(stream)
+ .withMessages(ollamaMessages)
+ .withOptions(this.clientOptions)
+ .build();
+ }
+
+ private OllamaApi.Message.Role toRole(Message message) {
+
+ switch (message.getMessageType()) {
+ case USER:
+ return Role.user;
+ case ASSISTANT:
+ return Role.assistant;
+ case SYSTEM:
+ return Role.system;
+ default:
+ throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
+ }
+ }
+
+}
\ No newline at end of file
diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingClient.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingClient.java
new file mode 100644
index 00000000000..c9385ebb203
--- /dev/null
+++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingClient.java
@@ -0,0 +1,95 @@
+/*
+ * Copyright 2023-2023 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.ollama;
+
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.springframework.ai.document.Document;
+import org.springframework.ai.embedding.AbstractEmbeddingClient;
+import org.springframework.ai.embedding.Embedding;
+import org.springframework.ai.embedding.EmbeddingResponse;
+import org.springframework.ai.ollama.api.OllamaApi;
+import org.springframework.ai.ollama.api.OllamaApi.EmbeddingRequest;
+import org.springframework.ai.ollama.api.OllamaApiOptions;
+import org.springframework.util.Assert;
+
+/**
+ * @author Christian Tzolov
+ */
+public class OllamaEmbeddingClient extends AbstractEmbeddingClient {
+
+ private final OllamaApi ollamaApi;
+
+ private String model = "orca-mini";
+
+ private Map<String, Object> clientOptions;
+
+ public OllamaEmbeddingClient(OllamaApi ollamaApi) {
+ this.ollamaApi = ollamaApi;
+ }
+
+ public OllamaEmbeddingClient withModel(String model) {
+ this.model = model;
+ return this;
+ }
+
+ public OllamaEmbeddingClient withOptions(Map<String, Object> options) {
+ this.clientOptions = options;
+ return this;
+ }
+
+ public OllamaEmbeddingClient withOptions(OllamaApiOptions.Options options) {
+ this.clientOptions = options.toMap();
+ return this;
+ }
+
+ @Override
+ public List<Double> embed(String text) {
+ return this.embed(List.of(text)).iterator().next();
+ }
+
+ @Override
+ public List<Double> embed(Document document) {
+ return embed(document.getContent());
+ }
+
+ @Override
+ public List<List<Double>> embed(List<String> texts) {
+ Assert.notEmpty(texts, "At least one text is required!");
+ Assert.isTrue(texts.size() == 1, "Ollama Embedding does not support batch embedding!");
+
+ String inputContent = texts.iterator().next();
+
+ OllamaApi.EmbeddingResponse response = this.ollamaApi
+ .embeddings(new EmbeddingRequest(this.model, inputContent, this.clientOptions));
+
+ return List.of(response.embedding());
+ }
+
+ @Override
+ public EmbeddingResponse embedForResponse(List<String> texts) {
+ var indexCounter = new AtomicInteger(0);
+ List<Embedding> embeddings = this.embed(texts)
+ .stream()
+ .map(e -> new Embedding(e, indexCounter.getAndIncrement()))
+ .toList();
+ return new EmbeddingResponse(embeddings);
+ }
+
+}
\ No newline at end of file
diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java
new file mode 100644
index 00000000000..c102284a268
--- /dev/null
+++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java
@@ -0,0 +1,592 @@
+/*
+ * Copyright 2023-2023 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.ollama.api;
+
+import java.io.IOException;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Consumer;
+
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonInclude.Include;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
+
+import org.springframework.ai.ollama.api.OllamaApiOptions.Options;
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.MediaType;
+import org.springframework.http.client.ClientHttpResponse;
+import org.springframework.util.Assert;
+import org.springframework.util.StreamUtils;
+import org.springframework.web.client.ResponseErrorHandler;
+import org.springframework.web.client.RestClient;
+import org.springframework.web.reactive.function.client.WebClient;
+
+/**
+ * Java Client for the Ollama API. https://ollama.ai/
+ *
+ * @author Christian Tzolov
+ * @since 0.8.0
+ */
+// @formatter:off
+public class OllamaApi {
+
+ private static final Log logger = LogFactory.getLog(OllamaApi.class);
+
+ private final static String DEFAULT_BASE_URL = "http://localhost:11434";
+
+ private final ResponseErrorHandler responseErrorHandler;
+
+ private final RestClient restClient;
+
+ private final WebClient webClient;
+
+ private static class OllamaResponseErrorHandler implements ResponseErrorHandler {
+
+ @Override
+ public boolean hasError(ClientHttpResponse response) throws IOException {
+ return response.getStatusCode().isError();
+ }
+
+ @Override
+ public void handleError(ClientHttpResponse response) throws IOException {
+ if (response.getStatusCode().isError()) {
+ int statusCode = response.getStatusCode().value();
+ String statusText = response.getStatusText();
+ String message = StreamUtils.copyToString(response.getBody(), java.nio.charset.StandardCharsets.UTF_8);
+ logger.warn(String.format("[%s] %s - %s", statusCode, statusText, message));
+ throw new RuntimeException(String.format("[%s] %s - %s", statusCode, statusText, message));
+ }
+ }
+
+ }
+
+ /**
+ * Default constructor that uses the default localhost url.
+ */
+ public OllamaApi() {
+ this(DEFAULT_BASE_URL);
+ }
+
+ /**
+ * Create a new OllamaApi instance with the given base url.
+ * @param baseUrl The base url of the Ollama server.
+ */
+ public OllamaApi(String baseUrl) {
+ this(baseUrl, RestClient.builder());
+ }
+
+ /**
+ * Create a new OllamaApi instance with the given base url and
+ * {@link RestClient.Builder}.
+ * @param baseUrl The base url of the Ollama server.
+ * @param restClientBuilder The {@link RestClient.Builder} to use.
+ */
+ public OllamaApi(String baseUrl, RestClient.Builder restClientBuilder) {
+
+ this.responseErrorHandler = new OllamaResponseErrorHandler();
+
+ Consumer<HttpHeaders> defaultHeaders = headers -> {
+ headers.setContentType(MediaType.APPLICATION_JSON);
+ headers.setAccept(List.of(MediaType.APPLICATION_JSON));
+ };
+
+ this.restClient = restClientBuilder.baseUrl(baseUrl).defaultHeaders(defaultHeaders).build();
+
+ this.webClient = WebClient.builder().baseUrl(baseUrl).defaultHeaders(defaultHeaders).build();
+ }
+
+ // --------------------------------------------------------------------------
+ // Generate & Streaming Generate
+ // --------------------------------------------------------------------------
+ /**
+ * The request object sent to the /generate endpoint.
+ *
+ * @param model (required) The model to use for completion.
+ * @param prompt (required) The prompt(s) to generate completions for.
+ * @param format (optional) The format to return the response in. Currently the only
+ * accepted value is "json".
+ * @param options (optional) additional model parameters listed in the documentation
+ * for the Modelfile such as temperature.
+ * @param system (optional) system prompt (overrides what is defined in the Modelfile).
+ * @param template (optional) the full prompt or prompt template (overrides what is
+ * defined in the Modelfile).
+ * @param context the context parameter returned from a previous request to /generate,
+ * this can be used to keep a short conversational memory.
+ * @param stream (optional) if false the response will be returned as a single
+ * response object, rather than a stream of objects.
+ * @param raw (optional) if true no formatting will be applied to the prompt and no
+ * context will be returned. You may choose to use the raw parameter if you are
+ * specifying a full templated prompt in your request to the API, and are managing
+ * history yourself.
+ */
+ @JsonInclude(Include.NON_NULL)
+ public record GenerateRequest(
+ @JsonProperty("model") String model,
+ @JsonProperty("prompt") String prompt,
+ @JsonProperty("format") String format,
+ @JsonProperty("options") Map<String, Object> options,
+ @JsonProperty("system") String system,
+ @JsonProperty("template") String template,
+ @JsonProperty("context") List<Integer> context,
+ @JsonProperty("stream") Boolean stream,
+ @JsonProperty("raw") Boolean raw) {
+
+ /**
+ * Shortcut constructor to create a GenerateRequest without options.
+ * @param model The model used for completion.
+ * @param prompt The prompt(s) to generate completions for.
+ * @param stream Whether to stream the response.
+ */
+ public GenerateRequest(String model, String prompt, Boolean stream) {
+ this(model, prompt, null, null, null, null, null, stream, null);
+ }
+
+ /**
+ * Shortcut constructor to create a GenerateRequest without options.
+ * @param model The model used for completion.
+ * @param prompt The prompt(s) to generate completions for.
+ * @param enableJsonFormat Whether to return the response in json format.
+ * @param stream Whether to stream the response.
+ */
+ public GenerateRequest(String model, String prompt, boolean enableJsonFormat, Boolean stream) {
+ this(model, prompt, (enableJsonFormat) ? "json" : null, null, null, null, null, stream, null);
+ }
+
+ /**
+ * Create a GenerateRequest builder.
+ * @param prompt The prompt(s) to generate completions for.
+ */
+ public static Builder builder(String prompt) {
+ return new Builder(prompt);
+ }
+
+ public static class Builder {
+
+ private String model;
+ private final String prompt;
+ private String format;
+ private Map<String, Object> options;
+ private String system;
+ private String template;
+ private List<Integer> context;
+ private Boolean stream;
+ private Boolean raw;
+
+ public Builder(String prompt) {
+ this.prompt = prompt;
+ }
+
+ public Builder withModel(String model) {
+ this.model = model;
+ return this;
+ }
+
+ public Builder withFormat(String format) {
+ this.format = format;
+ return this;
+ }
+
+ public Builder withOptions(Map<String, Object> options) {
+ this.options = options;
+ return this;
+ }
+
+ public Builder withOptions(Options options) {
+ this.options = options.toMap();
+ return this;
+ }
+
+ public Builder withSystem(String system) {
+ this.system = system;
+ return this;
+ }
+
+ public Builder withTemplate(String template) {
+ this.template = template;
+ return this;
+ }
+
+ public Builder withContext(List<Integer> context) {
+ this.context = context;
+ return this;
+ }
+
+ public Builder withStream(Boolean stream) {
+ this.stream = stream;
+ return this;
+ }
+
+ public Builder withRaw(Boolean raw) {
+ this.raw = raw;
+ return this;
+ }
+
+ public GenerateRequest build() {
+ return new GenerateRequest(model, prompt, format, options, system, template, context, stream, raw);
+ }
+
+ }
+ }
+
+ /**
+ * The response object returned from the /generate endpoint. To calculate how fast the
+ * response is generated in tokens per second (token/s), divide eval_count /
+ * eval_duration.
+ *
+ * @param model The model used for completion.
+ * @param createdAt When the request was made.
+ * @param response The completion response. Empty if the response was streamed, if not
+ * streamed, this will contain the full response
+ * @param done Whether this is the final response. If true, this response may be
+ * followed by another response with the following, additional fields: context,
+ * prompt_eval_count, prompt_eval_duration, eval_count, eval_duration.
+ * @param context Encoding of the conversation used in this response, this can be sent
+ * in the next request to keep a conversational memory.
+ * @param totalDuration Time spent generating the response.
+ * @param loadDuration Time spent loading the model.
+ * @param promptEvalCount Number of tokens in the prompt.
+ * @param promptEvalDuration Time spent evaluating the prompt.
+ * @param evalCount Number of tokens in the response.
+ * @param evalDuration Time spent generating the response.
+ */
+ @JsonInclude(Include.NON_NULL)
+ public record GenerateResponse(
+ @JsonProperty("model") String model,
+ @JsonProperty("created_at") Instant createdAt,
+ @JsonProperty("response") String response,
+ @JsonProperty("done") Boolean done,
+ @JsonProperty("context") List<Integer> context,
+ @JsonProperty("total_duration") Duration totalDuration,
+ @JsonProperty("load_duration") Duration loadDuration,
+ @JsonProperty("prompt_eval_count") Integer promptEvalCount,
+ @JsonProperty("prompt_eval_duration") Duration promptEvalDuration,
+ @JsonProperty("eval_count") Integer evalCount,
+ @JsonProperty("eval_duration") Duration evalDuration) {
+ }
+
+ /**
+ * Generate a completion for the given prompt.
+ * @param completionRequest Completion request.
+ * @return Completion response.
+ */
+ public GenerateResponse generate(GenerateRequest completionRequest) {
+ Assert.notNull(completionRequest, "The request body can not be null.");
+ Assert.isTrue(completionRequest.stream() == false, "Stream mode must be disabled.");
+
+ return this.restClient.post()
+ .uri("/api/generate")
+ .body(completionRequest)
+ .retrieve()
+ .onStatus(this.responseErrorHandler)
+ .body(GenerateResponse.class);
+ }
+
+ /**
+ * Generate a streaming completion for the given prompt.
+ * @param completionRequest Completion request. The request must set the stream
+ * property to true.
+ * @return Completion response as a {@link Flux} stream.
+ */
+ public Flux<GenerateResponse> generateStreaming(GenerateRequest completionRequest) {
+ Assert.notNull(completionRequest, "The request body can not be null.");
+ Assert.isTrue(completionRequest.stream(), "Request must set the stream property to true.");
+
+ return webClient.post()
+ .uri("/api/generate")
+ .body(Mono.just(completionRequest), GenerateRequest.class)
+ .retrieve()
+ .bodyToFlux(GenerateResponse.class)
+ .handle((data, sink) -> {
+ System.out.println(data);
+ sink.next(data);
+ });
+ }
+
+ // --------------------------------------------------------------------------
+ // Chat & Streaming Chat
+ // --------------------------------------------------------------------------
+ /**
+ * Chat message object.
+ *
+ * @param role The role of the message of type {@link Role}.
+ * @param content The content of the message.
+ * @param images The list of images to send with the message.
+ */
+ @JsonInclude(Include.NON_NULL)
+ public record Message(
+ @JsonProperty("role") Role role,
+ @JsonProperty("content") String content,
+ @JsonProperty("images") List images) {
+
+ /**
+ * The role of the message in the conversation.
+ */
+ public enum Role {
+
+ /**
+ * System message type used as instructions to the model.
+ */
+ system,
+ /**
+ * User message type.
+ */
+ user,
+ /**
+ * Assistant message type. Usually the response from the model.
+ */
+ assistant;
+
+ }
+
+ public static Builder builder(Role role) {
+ return new Builder(role);
+ }
+
+ public static class Builder {
+
+ private final Role role;
+ private String content;
+ private List images;
+
+ public Builder(Role role) {
+ this.role = role;
+ }
+
+ public Builder withContent(String content) {
+ this.content = content;
+ return this;
+ }
+
+ public Builder withImages(List images) {
+ this.images = images;
+ return this;
+ }
+
+ public Message build() {
+ return new Message(role, content, images);
+ }
+
+ }
+ }
+
+ /**
+ * Chat request object.
+ *
+ * @param model The model to use for completion.
+ * @param messages The list of messages to chat with.
+ * @param stream Whether to stream the response.
+ * @param format The format to return the response in. Currently the only accepted
+ * value is "json".
+ * @param options Additional model parameters. You can use the {@link Options} builder
+ * to create the options then {@link Options#toMap()} to convert the options into a
+ * map.
+ */
+ @JsonInclude(Include.NON_NULL)
+ public record ChatRequest(
+ @JsonProperty("model") String model,
+ @JsonProperty("messages") List<Message> messages,
+ @JsonProperty("stream") Boolean stream,
+ @JsonProperty("format") String format,
+ @JsonProperty("options") Map<String, Object> options) {
+
+ public static Builder builder(String model) {
+ return new Builder(model);
+ }
+
+ public static class Builder {
+
+ private final String model;
+ private List<Message> messages = List.of();
+ private boolean stream = false;
+ private String format;
+ private Map<String, Object> options = Map.of();
+
+ public Builder(String model) {
+ Assert.notNull(model, "The model can not be null.");
+ this.model = model;
+ }
+
+ public Builder withMessages(List<Message> messages) {
+ this.messages = messages;
+ return this;
+ }
+
+ public Builder withStream(boolean stream) {
+ this.stream = stream;
+ return this;
+ }
+
+ public Builder withFormat(String format) {
+ this.format = format;
+ return this;
+ }
+
+ public Builder withOptions(Map<String, Object> options) {
+ this.options = options;
+ return this;
+ }
+
+ public Builder withOptions(Options options) {
+ this.options = options.toMap();
+ return this;
+ }
+
+ public ChatRequest build() {
+ return new ChatRequest(model, messages, stream, format, options);
+ }
+
+ }
+ }
+
+ /**
+ * Ollama chat response object.
+ *
+ * @param model The model name used for completion.
+ * @param createdAt When the request was made.
+ * @param message The response {@link Message} with {@link Message.Role#assistant}.
+ * @param done Whether this is the final response. For streaming response only the
+ * last message is marked as done. If true, this response may be followed by another
+ * response with the following, additional fields: context, prompt_eval_count,
+ * prompt_eval_duration, eval_count, eval_duration.
+ * @param totalDuration Time spent generating the response.
+ * @param loadDuration Time spent loading the model.
+ * @param promptEvalCount Number of tokens in the prompt.
+ * @param promptEvalDuration Time spent evaluating the prompt.
+ * @param evalCount Number of tokens in the response.
+ * @param evalDuration Time spent generating the response.
+ * @see Chat Completion API
+ * @see Ollama Types
+ */
+ @JsonInclude(Include.NON_NULL)
+ public record ChatResponse(
+ @JsonProperty("model") String model,
+ @JsonProperty("created_at") Instant createdAt,
+ @JsonProperty("message") Message message,
+ @JsonProperty("done") Boolean done,
+ @JsonProperty("total_duration") Duration totalDuration,
+ @JsonProperty("load_duration") Duration loadDuration,
+ @JsonProperty("prompt_eval_count") Integer promptEvalCount,
+ @JsonProperty("prompt_eval_duration") Duration promptEvalDuration,
+ @JsonProperty("eval_count") Integer evalCount,
+ @JsonProperty("eval_duration") Duration evalDuration) {
+ }
+
+ /**
+ * Generate the next message in a chat with a provided model.
+ *
+ * This is a streaming endpoint (controlled by the 'stream' request property), so
+ * there will be a series of responses. The final response object will include
+ * statistics and additional data from the request.
+ * @param chatRequest Chat request.
+ * @return Chat response.
+ */
+ public ChatResponse chat(ChatRequest chatRequest) {
+ Assert.notNull(chatRequest, "The request body can not be null.");
+ Assert.isTrue(!chatRequest.stream(), "Stream mode must be disabled.");
+
+ return this.restClient.post()
+ .uri("/api/chat")
+ .body(chatRequest)
+ .retrieve()
+ .onStatus(this.responseErrorHandler)
+ .body(ChatResponse.class);
+ }
+
+ /**
+ * Streaming response for the chat completion request.
+ * @param chatRequest Chat request. The request must set the stream property to true.
+ * @return Chat response as a {@link Flux} stream.
+ */
+ public Flux<ChatResponse> streamingChat(ChatRequest chatRequest) {
+ Assert.notNull(chatRequest, "The request body can not be null.");
+ Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true.");
+
+ return webClient.post()
+ .uri("/api/chat")
+ .body(Mono.just(chatRequest), ChatRequest.class)
+ .retrieve()
+ .bodyToFlux(ChatResponse.class)
+ .handle((data, sink) -> {
+ System.out.println(data);
+ sink.next(data);
+ });
+ }
+
+ // --------------------------------------------------------------------------
+ // Embeddings
+ // --------------------------------------------------------------------------
+ /**
+ * Generate embeddings from a model.
+ *
+ * @param model The name of model to generate embeddings from.
+ * @param prompt The text to generate embeddings for.
+ * @param options Additional model parameters listed in the documentation for the
+ * Modelfile such as temperature.
+ */
+ @JsonInclude(Include.NON_NULL)
+ public record EmbeddingRequest(
+ @JsonProperty("model") String model,
+ @JsonProperty("prompt") String prompt,
+ @JsonProperty("options") Map<String, Object> options) {
+
+ /**
+ * Shortcut constructor to create an EmbeddingRequest without options.
+ * @param model The name of model to generate embeddings from.
+ * @param prompt The text to generate embeddings for.
+ */
+ public EmbeddingRequest(String model, String prompt) {
+ this(model, prompt, null);
+ }
+ }
+
+ /**
+ * The response object returned from the /embedding endpoint.
+ *
+ * @param embedding The embedding generated from the model.
+ */
+ @JsonInclude(Include.NON_NULL)
+ public record EmbeddingResponse(
+ @JsonProperty("embedding") List<Double> embedding) {
+ }
+
+ /**
+ * Generate embeddings from a model.
+ * @param embeddingRequest Embedding request.
+ * @return Embedding response.
+ */
+ public EmbeddingResponse embeddings(EmbeddingRequest embeddingRequest) {
+ Assert.notNull(embeddingRequest, "The request body can not be null.");
+
+ return this.restClient.post()
+ .uri("/api/embeddings")
+ .body(embeddingRequest)
+ .retrieve()
+ .onStatus(this.responseErrorHandler)
+ .body(EmbeddingResponse.class);
+ }
+
+}
+// @formatter:on
\ No newline at end of file
diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApiOptions.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApiOptions.java
new file mode 100644
index 00000000000..2f0beeba49d
--- /dev/null
+++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApiOptions.java
@@ -0,0 +1,395 @@
+/*
+ * Copyright 2023-2023 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.ollama.api;
+
+import java.util.Map;
+
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonInclude.Include;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.ObjectMapper;
+
+/**
+ * Helper class for building strongly typed Ollama request options.
+ *
+ * @author Christian Tzolov
+ * @since 0.8.0
+ */
+// @formatter:off
+public class OllamaApiOptions {
+
+ /**
+ * Runner options which must be set when the model is loaded into memory.
+ *
+ * @param useNUMA Whether to use NUMA.
+ * @param numCtx Sets the size of the context window used to generate the next token.
+ * (Default: 2048)
+ * @param numBatch ???
+ * @param numGQA The number of GQA groups in the transformer layer. Required for some
+ * models, for example it is 8 for llama2:70b.
+ * @param numGPU The number of layers to send to the GPU(s). On macOS it defaults to 1
+ * to enable metal support, 0 to disable.
+ * @param mainGPU ???
+ * @param lowVRAM ???
+ * @param f16KV ???
+ * @param logitsAll ???
+ * @param vocabOnly ???
+ * @param useMMap ???
+ * @param useMLock ???
+ * @param embeddingOnly ???
+ * @param ropeFrequencyBase ???
+ * @param ropeFrequencyScale ???
+ * @param numThread Sets the number of threads to use during computation. By default,
+ * Ollama will detect this for optimal performance. It is recommended to set this
+ * value to the number of physical CPU cores your system has (as opposed to the
+ * logical number of cores).
+ *
+ * Options specified in GenerateRequest.
+ * @param mirostat Enable Mirostat sampling for controlling perplexity. (default: 0, 0
+ * = disabled, 1 = Mirostat, 2 = Mirostat 2.0)
+ * @param mirostatTau Controls the balance between coherence and diversity of the
+ * output. A lower value will result in more focused and coherent text. (Default:
+ * 5.0).
+ * @param mirostatEta Influences how quickly the algorithm responds to feedback from
+ * the generated text. A lower learning rate will result in slower adjustments, while
+ * a higher learning rate will make the algorithm more responsive. (Default: 0.1).
+ * @param numKeep Unknown.
+ * @param seed Sets the random number seed to use for generation. Setting this to a
+ * specific number will make the model generate the same text for the same prompt.
+ * (Default: 0)
+ * @param numPredict Maximum number of tokens to predict when generating text.
+ * (Default: 128, -1 = infinite generation, -2 = fill context)
+ * @param topK Reduces the probability of generating nonsense. A higher value (e.g.
+ * 100) will give more diverse answers, while a lower value (e.g. 10) will be more
+ * conservative. (Default: 40)
+ * @param topP Works together with top-k. A higher value (e.g., 0.95) will lead to
+ * more diverse text, while a lower value (e.g., 0.5) will generate more focused and
+ * conservative text. (Default: 0.9)
+ * @param tfsZ Tail free sampling is used to reduce the impact of less probable tokens
+ * from the output. A higher value (e.g., 2.0) will reduce the impact more, while a
+ * value of 1.0 disables this setting. (default: 1)
+ * @param typicalP Unknown.
+ * @param repeatLastN Sets how far back for the model to look back to prevent
+ * repetition. (Default: 64, 0 = disabled, -1 = num_ctx)
+ * @param temperature The temperature of the model. Increasing the temperature will
+ * make the model answer more creatively. (Default: 0.8)
+ * @param repeatPenalty Sets how strongly to penalize repetitions. A higher value
+ * (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g.,
+ * 0.9) will be more lenient. (Default: 1.1)
+ * @param presencePenalty Unknown.
+ * @param frequencyPenalty Unknown.
+ * @param penalizeNewline Unknown.
+ * @param stop Sets the stop sequences to use. When this pattern is encountered the
+ * LLM will stop generating text and return. Multiple stop patterns may be set by
+ * specifying multiple separate stop parameters in a modelfile.
+ * @see Ollama
+ * Valid Parameters and Values
+ * @see Ollama Go
+ * Types
+ */
+ @JsonInclude(Include.NON_NULL)
+ public record Options(
+ // Runner options which must be set when the model is loaded into memory.
+ @JsonProperty("numa") Boolean useNUMA, @JsonProperty("num_ctx") Integer numCtx,
+ @JsonProperty("num_batch") Integer numBatch, @JsonProperty("num_gqa") Integer numGQA,
+ @JsonProperty("num_gpu") Integer numGPU, @JsonProperty("main_gpu") Integer mainGPU,
+ @JsonProperty("low_vram") Boolean lowVRAM, @JsonProperty("f16_kv") Boolean f16KV,
+ @JsonProperty("logits_all") Boolean logitsAll, @JsonProperty("vocab_only") Boolean vocabOnly,
+ @JsonProperty("use_mmap") Boolean useMMap, @JsonProperty("use_mlock") Boolean useMLock,
+ @JsonProperty("embedding_only") Boolean embeddingOnly,
+ @JsonProperty("rope_frequency_base") Float ropeFrequencyBase,
+ @JsonProperty("rope_frequency_scale") Float ropeFrequencyScale,
+ @JsonProperty("num_thread") Integer numThread,
+
+ // Options specified in GenerateRequest.
+ @JsonProperty("num_keep") Integer numKeep, @JsonProperty("seed") Integer seed,
+ @JsonProperty("num_predict") Integer numPredict, @JsonProperty("top_k") Integer topK,
+ @JsonProperty("top_p") Float topP, @JsonProperty("tfs_z") Float tfsZ,
+ @JsonProperty("typical_p") Float typicalP, @JsonProperty("repeat_last_n") Integer repeatLastN,
+ @JsonProperty("temperature") Float temperature, @JsonProperty("repeat_penalty") Float repeatPenalty,
+ @JsonProperty("presence_penalty") Float presencePenalty,
+ @JsonProperty("frequency_penalty") Float frequencyPenalty, @JsonProperty("mirostat") Integer mirostat,
+ @JsonProperty("mirostat_tau") Float mirostatTau, @JsonProperty("mirostat_eta") Float mirostatEta,
+ @JsonProperty("penalize_newline") Boolean penalizeNewline, @JsonProperty("stop") String[] stop) {
+
+ /**
+ * Convert the {@link Options} object to a {@link Map} of key/value pairs.
+ * @return The {@link Map} of key/value pairs.
+ */
+ public Map<String, Object> toMap() {
+ try {
+ var json = new ObjectMapper().writeValueAsString(this);
+ return new ObjectMapper().readValue(json, new TypeReference