diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java
index 4ddef5c5d66..750546d88ad 100644
--- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java
+++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java
@@ -47,8 +47,10 @@
import org.springframework.ai.ollama.api.OllamaApi.Message.Role;
import org.springframework.ai.ollama.api.OllamaApi.Message.ToolCall;
import org.springframework.ai.ollama.api.OllamaApi.Message.ToolCallFunction;
-import org.springframework.ai.ollama.api.OllamaModelPuller;
+import org.springframework.ai.ollama.management.ModelManagementOptions;
+import org.springframework.ai.ollama.management.OllamaModelManager;
import org.springframework.ai.ollama.api.OllamaOptions;
+import org.springframework.ai.ollama.management.PullModelStrategy;
import org.springframework.ai.ollama.metadata.OllamaChatUsage;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
@@ -59,10 +61,9 @@
/**
* {@link ChatModel} implementation for {@literal Ollama}. Ollama allows developers to run
* large language models and generate embeddings locally. It supports open-source models
- * available on [Ollama AI Library](...). - Llama
- * 2 (7B parameters, 3.8GB size) - Mistral (7B parameters, 4.1GB size) Please refer to the
- * official Ollama website for the most up-to-date
- * information on available models.
+ * available on [Ollama AI Library](...) and on
+ * Hugging Face. Please refer to the official Ollama
+ * website for the most up-to-date information on available models.
*
* @author Christian Tzolov
* @author luocongqiu
@@ -73,57 +74,33 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode
private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();
- /**
- * Low-level Ollama API library.
- */
private final OllamaApi chatApi;
- /**
- * Default options to be used for all chat requests.
- */
private final OllamaOptions defaultOptions;
- /**
- * Observation registry used for instrumentation.
- */
private final ObservationRegistry observationRegistry;
- /**
- * Conventions to use for generating observations.
- */
- private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
-
- private final OllamaModelPuller modelPuller;
-
- public OllamaChatModel(OllamaApi ollamaApi) {
- this(ollamaApi, OllamaOptions.create().withModel(OllamaOptions.DEFAULT_MODEL));
- }
-
- public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions) {
- this(ollamaApi, defaultOptions, null);
- }
+ private final OllamaModelManager modelManager;
- public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
- FunctionCallbackContext functionCallbackContext) {
- this(ollamaApi, defaultOptions, functionCallbackContext, List.of());
- }
+ private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
- FunctionCallbackContext functionCallbackContext, List toolFunctionCallbacks) {
- this(ollamaApi, defaultOptions, functionCallbackContext, toolFunctionCallbacks, ObservationRegistry.NOOP);
- }
-
- public OllamaChatModel(OllamaApi chatApi, OllamaOptions defaultOptions,
FunctionCallbackContext functionCallbackContext, List toolFunctionCallbacks,
- ObservationRegistry observationRegistry) {
+ ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
super(functionCallbackContext, defaultOptions, toolFunctionCallbacks);
- Assert.notNull(chatApi, "ollamaApi must not be null");
+ Assert.notNull(ollamaApi, "ollamaApi must not be null");
Assert.notNull(defaultOptions, "defaultOptions must not be null");
- Assert.notNull(observationRegistry, "ObservationRegistry must not be null");
- this.chatApi = chatApi;
+ Assert.notNull(observationRegistry, "observationRegistry must not be null");
+ Assert.notNull(observationRegistry, "modelManagementOptions must not be null");
+ this.chatApi = ollamaApi;
this.defaultOptions = defaultOptions;
this.observationRegistry = observationRegistry;
- this.modelPuller = new OllamaModelPuller(chatApi);
+ this.modelManager = new OllamaModelManager(chatApi, modelManagementOptions);
+ initializeModelIfEnabled(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy());
+ }
+
+ public static Builder builder() {
+ return new Builder();
}
@Override
@@ -324,9 +301,9 @@ else if (message instanceof ToolResponseMessage toolMessage) {
}
OllamaOptions mergedOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, OllamaOptions.class);
- mergedOptions.setPullMissingModel(this.defaultOptions.isPullMissingModel());
- if (runtimeOptions != null && runtimeOptions.isPullMissingModel() != null) {
- mergedOptions.setPullMissingModel(runtimeOptions.isPullMissingModel());
+ mergedOptions.setPullModelStrategy(this.defaultOptions.getPullModelStrategy());
+ if (runtimeOptions != null && runtimeOptions.getPullModelStrategy() != null) {
+ mergedOptions.setPullModelStrategy(runtimeOptions.getPullModelStrategy());
}
// Override the model.
@@ -353,9 +330,7 @@ else if (message instanceof ToolResponseMessage toolMessage) {
requestBuilder.withTools(this.getFunctionTools(functionsForThisRequest));
}
- if (mergedOptions.isPullMissingModel()) {
- this.modelPuller.pullModel(mergedOptions.getModel(), true);
- }
+ initializeModelIfEnabled(mergedOptions.getModel(), mergedOptions.getPullModelStrategy());
return requestBuilder.build();
}
@@ -400,6 +375,15 @@ public ChatOptions getDefaultOptions() {
return OllamaOptions.fromOptions(this.defaultOptions);
}
+ /**
+ * Pull the given model into Ollama based on the specified strategy.
+ */
+ private void initializeModelIfEnabled(String model, PullModelStrategy pullModelStrategy) {
+ if (!PullModelStrategy.NEVER.equals(pullModelStrategy)) {
+ this.modelManager.pullModel(model, pullModelStrategy);
+ }
+ }
+
/**
* Use the provided convention for reporting observation data
* @param observationConvention The provided convention
@@ -409,4 +393,58 @@ public void setObservationConvention(ChatModelObservationConvention observationC
this.observationConvention = observationConvention;
}
+ public static class Builder {
+
+ private OllamaApi ollamaApi;
+
+ private OllamaOptions defaultOptions = OllamaOptions.create().withModel(OllamaOptions.DEFAULT_MODEL);
+
+ private FunctionCallbackContext functionCallbackContext;
+
+ private List toolFunctionCallbacks = List.of();
+
+ private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
+
+ private ModelManagementOptions modelManagementOptions = ModelManagementOptions.defaults();
+
+ private Builder() {
+ }
+
+ public Builder withOllamaApi(OllamaApi ollamaApi) {
+ this.ollamaApi = ollamaApi;
+ return this;
+ }
+
+ public Builder withDefaultOptions(OllamaOptions defaultOptions) {
+ this.defaultOptions = defaultOptions;
+ return this;
+ }
+
+ public Builder withFunctionCallbackContext(FunctionCallbackContext functionCallbackContext) {
+ this.functionCallbackContext = functionCallbackContext;
+ return this;
+ }
+
+ public Builder withToolFunctionCallbacks(List toolFunctionCallbacks) {
+ this.toolFunctionCallbacks = toolFunctionCallbacks;
+ return this;
+ }
+
+ public Builder withObservationRegistry(ObservationRegistry observationRegistry) {
+ this.observationRegistry = observationRegistry;
+ return this;
+ }
+
+ public Builder withModelManagementOptions(ModelManagementOptions modelManagementOptions) {
+ this.modelManagementOptions = modelManagementOptions;
+ return this;
+ }
+
+ public OllamaChatModel build() {
+ return new OllamaChatModel(ollamaApi, defaultOptions, functionCallbackContext, toolFunctionCallbacks,
+ observationRegistry, modelManagementOptions);
+ }
+
+ }
+
}
\ No newline at end of file
diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java
index e918cc9e662..6da6714ee80 100644
--- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java
+++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java
@@ -22,7 +22,6 @@
import java.util.regex.Pattern;
import io.micrometer.observation.ObservationRegistry;
-import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.*;
import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention;
@@ -32,24 +31,20 @@
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaApi.EmbeddingsResponse;
-import org.springframework.ai.ollama.api.OllamaModelPuller;
+import org.springframework.ai.ollama.management.ModelManagementOptions;
+import org.springframework.ai.ollama.management.OllamaModelManager;
import org.springframework.ai.ollama.api.OllamaOptions;
+import org.springframework.ai.ollama.management.PullModelStrategy;
import org.springframework.ai.ollama.metadata.OllamaEmbeddingUsage;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
/**
- * {@link EmbeddingModel} implementation for {@literal Ollama}.
- *
- * Ollama allows developers to run large language models and generate embeddings locally.
- * It supports open-source models available on [Ollama AI
- * Library](https://ollama.ai/library).
- *
- * Examples of models supported: - Llama 2 (7B parameters, 3.8GB size) - Mistral (7B
- * parameters, 4.1GB size)
- *
- * Please refer to the official Ollama website for the
- * most up-to-date information on available models.
+ * {@link EmbeddingModel} implementation for {@literal Ollama}. Ollama allows developers
+ * to run large language models and generate embeddings locally. It supports open-source
+ * models available on [Ollama AI Library](...)
+ * and on Hugging Face. Please refer to the official Ollama
+ * website for the most up-to-date information on available models.
*
* @author Christian Tzolov
* @author Thomas Vitale
@@ -61,41 +56,31 @@ public class OllamaEmbeddingModel extends AbstractEmbeddingModel {
private final OllamaApi ollamaApi;
- /**
- * Default options to be used for all chat requests.
- */
private final OllamaOptions defaultOptions;
- /**
- * Observation registry used for instrumentation.
- */
private final ObservationRegistry observationRegistry;
- /**
- * Conventions to use for generating observations.
- */
- private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
-
- private final OllamaModelPuller modelPuller;
+ private final OllamaModelManager modelManager;
- public OllamaEmbeddingModel(OllamaApi ollamaApi) {
- this(ollamaApi, OllamaOptions.create().withModel(OllamaOptions.DEFAULT_MODEL));
- }
-
- public OllamaEmbeddingModel(OllamaApi ollamaApi, OllamaOptions defaultOptions) {
- this(ollamaApi, defaultOptions, ObservationRegistry.NOOP);
- }
+ private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
public OllamaEmbeddingModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
- ObservationRegistry observationRegistry) {
- Assert.notNull(ollamaApi, "openAiApi must not be null");
+ ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
+ Assert.notNull(ollamaApi, "ollamaApi must not be null");
Assert.notNull(defaultOptions, "options must not be null");
Assert.notNull(observationRegistry, "observationRegistry must not be null");
+ Assert.notNull(modelManagementOptions, "modelManagementOptions must not be null");
this.ollamaApi = ollamaApi;
this.defaultOptions = defaultOptions;
this.observationRegistry = observationRegistry;
- this.modelPuller = new OllamaModelPuller(ollamaApi);
+ this.modelManager = new OllamaModelManager(ollamaApi, modelManagementOptions);
+
+ initializeModelIfEnabled(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy());
+ }
+
+ public static Builder builder() {
+ return new Builder();
}
@Override
@@ -153,9 +138,9 @@ OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest(List inputContent, Em
OllamaOptions mergedOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, OllamaOptions.class);
- mergedOptions.setPullMissingModel(this.defaultOptions.isPullMissingModel());
- if (runtimeOptions != null && runtimeOptions.isPullMissingModel() != null) {
- mergedOptions.setPullMissingModel(runtimeOptions.isPullMissingModel());
+ mergedOptions.setPullModelStrategy(this.defaultOptions.getPullModelStrategy());
+ if (runtimeOptions != null && runtimeOptions.getPullModelStrategy() != null) {
+ mergedOptions.setPullModelStrategy(runtimeOptions.getPullModelStrategy());
}
// Override the model.
@@ -164,9 +149,7 @@ OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest(List inputContent, Em
}
String model = mergedOptions.getModel();
- if (mergedOptions.isPullMissingModel()) {
- this.modelPuller.pullModel(model, true);
- }
+ initializeModelIfEnabled(mergedOptions.getModel(), mergedOptions.getPullModelStrategy());
return new OllamaApi.EmbeddingsRequest(model, inputContent, DurationParser.parse(mergedOptions.getKeepAlive()),
OllamaOptions.filterNonSupportedFields(mergedOptions.toMap()), mergedOptions.getTruncate());
@@ -176,6 +159,15 @@ private EmbeddingOptions buildRequestOptions(OllamaApi.EmbeddingsRequest request
return EmbeddingOptionsBuilder.builder().withModel(request.model()).build();
}
+ /**
+ * Pull the given model into Ollama based on the specified strategy.
+ */
+ private void initializeModelIfEnabled(String model, PullModelStrategy pullModelStrategy) {
+ if (!PullModelStrategy.NEVER.equals(pullModelStrategy)) {
+ this.modelManager.pullModel(model, pullModelStrategy);
+ }
+ }
+
/**
* Use the provided convention for reporting observation data
* @param observationConvention The provided convention
@@ -216,4 +208,43 @@ public static Duration parse(String input) {
}
+ public static class Builder {
+
+ private OllamaApi ollamaApi;
+
+ private OllamaOptions defaultOptions = OllamaOptions.create().withModel(OllamaOptions.DEFAULT_MODEL);
+
+ private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
+
+ private ModelManagementOptions modelManagementOptions = ModelManagementOptions.defaults();
+
+ private Builder() {
+ }
+
+ public Builder withOllamaApi(OllamaApi ollamaApi) {
+ this.ollamaApi = ollamaApi;
+ return this;
+ }
+
+ public Builder withDefaultOptions(OllamaOptions defaultOptions) {
+ this.defaultOptions = defaultOptions;
+ return this;
+ }
+
+ public Builder withObservationRegistry(ObservationRegistry observationRegistry) {
+ this.observationRegistry = observationRegistry;
+ return this;
+ }
+
+ public Builder withModelManagementOptions(ModelManagementOptions modelManagementOptions) {
+ this.modelManagementOptions = modelManagementOptions;
+ return this;
+ }
+
+ public OllamaEmbeddingModel build() {
+ return new OllamaEmbeddingModel(ollamaApi, defaultOptions, observationRegistry, modelManagementOptions);
+ }
+
+ }
+
}
\ No newline at end of file
diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java
index 7cac43455fd..acd9028d1b7 100644
--- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java
+++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java
@@ -879,6 +879,7 @@ public record ShowModelResponse(
* Show information about a model available locally on the machine where Ollama is running.
*/
public ShowModelResponse showModel(ShowModelRequest showModelRequest) {
+ Assert.notNull(showModelRequest, "showModelRequest must not be null");
return this.restClient.post()
.uri("/api/show")
.body(showModelRequest)
@@ -897,6 +898,7 @@ public record CopyModelRequest(
* Copy a model. Creates a model with another name from an existing model.
*/
public ResponseEntity copyModel(CopyModelRequest copyModelRequest) {
+ Assert.notNull(copyModelRequest, "copyModelRequest must not be null");
return this.restClient.post()
.uri("/api/copy")
.body(copyModelRequest)
@@ -914,6 +916,7 @@ public record DeleteModelRequest(
* Delete a model and its data.
*/
public ResponseEntity deleteModel(DeleteModelRequest deleteModelRequest) {
+ Assert.notNull(deleteModelRequest, "deleteModelRequest must not be null");
return this.restClient.method(HttpMethod.DELETE)
.uri("/api/delete")
.body(deleteModelRequest)
@@ -925,20 +928,20 @@ public ResponseEntity deleteModel(DeleteModelRequest deleteModelRequest) {
@JsonInclude(Include.NON_NULL)
public record PullModelRequest(
@JsonProperty("model") String model,
- @JsonProperty("insecure") Boolean insecure,
+ @JsonProperty("insecure") boolean insecure,
@JsonProperty("username") String username,
@JsonProperty("password") String password,
- @JsonProperty("stream") Boolean stream
+ @JsonProperty("stream") boolean stream
) {
public PullModelRequest {
- if (stream != null && stream) {
- logger.warn("Streaming when pulling models is not supported yet");
+ if (!stream) {
+ logger.warn("Enforcing streaming of the model pull request");
}
- stream = false;
+ stream = true;
}
public PullModelRequest(String model) {
- this(model, null, null, null, null);
+ this(model, false, null, null, true);
}
}
@@ -954,13 +957,15 @@ public record ProgressResponse(
* Download a model from the Ollama library. Cancelled pulls are resumed from where they left off,
* and multiple calls will share the same download progress.
*/
- public ProgressResponse pullModel(PullModelRequest pullModelRequest) {
- return this.restClient.post()
+ public Flux pullModel(PullModelRequest pullModelRequest) {
+ Assert.notNull(pullModelRequest, "pullModelRequest must not be null");
+ Assert.isTrue(pullModelRequest.stream(), "Request must set the stream property to true.");
+
+ return this.webClient.post()
.uri("/api/pull")
- .body(pullModelRequest)
+ .bodyValue(pullModelRequest)
.retrieve()
- .onStatus(this.responseErrorHandler)
- .body(ProgressResponse.class);
+ .bodyToFlux(ProgressResponse.class);
}
}
diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModelPuller.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModelPuller.java
deleted file mode 100644
index bbf8fefdb31..00000000000
--- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModelPuller.java
+++ /dev/null
@@ -1,86 +0,0 @@
-/*
-* Copyright 2024 - 2024 the original author or authors.
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* https://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/
-package org.springframework.ai.ollama.api;
-
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import org.springframework.ai.ollama.api.OllamaApi.DeleteModelRequest;
-import org.springframework.ai.ollama.api.OllamaApi.ListModelResponse;
-import org.springframework.ai.ollama.api.OllamaApi.PullModelRequest;
-import org.springframework.http.HttpStatus;
-import org.springframework.util.CollectionUtils;
-
-/**
- * Helper class that allow to check if a model is available locally and pull it if not.
- *
- * @author Christian Tzolov
- * @since 1.0.0
- */
-public class OllamaModelPuller {
-
- private final Logger logger = LoggerFactory.getLogger(OllamaModelPuller.class);
-
- private OllamaApi ollamaApi;
-
- private final long pullRetryTimeoutMs;
-
- public OllamaModelPuller(OllamaApi ollamaApi) {
- this(ollamaApi, 5000);
- }
-
- public OllamaModelPuller(OllamaApi ollamaApi, long retryTimeoutMs) {
- this.ollamaApi = ollamaApi;
- this.pullRetryTimeoutMs = retryTimeoutMs;
- }
-
- public boolean isModelAvailable(String modelName) {
- ListModelResponse modelsResponse = ollamaApi.listModels();
- if (!CollectionUtils.isEmpty(modelsResponse.models())) {
- return modelsResponse.models().stream().anyMatch(m -> m.name().equals(modelName));
- }
- return false;
- }
-
- public boolean deleteModel(String modelName) {
- logger.info("Delete model: {}", modelName);
- if (!isModelAvailable(modelName)) {
- logger.info("Model: {} not found!", modelName);
- return false;
- }
- return this.ollamaApi.deleteModel(new DeleteModelRequest(modelName)).getStatusCode().equals(HttpStatus.OK);
- }
-
- public String pullModel(String modelName, boolean enablePullRetry) {
- String status = "";
- do {
- logger.info("Start Pulling model: {}", modelName);
- var progress = this.ollamaApi.pullModel(new PullModelRequest(modelName));
- status = progress.status();
- logger.info("Pulling model: {} - Status: {}", modelName, status);
-
- try {
- Thread.sleep(this.pullRetryTimeoutMs);
- }
- catch (InterruptedException e) {
- e.printStackTrace();
- }
- }
- while (enablePullRetry && !status.equals("success"));
-
- return status;
- }
-
-}
diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java
index d6f2f85cdac..b25b5310230 100644
--- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java
+++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java
@@ -28,6 +28,7 @@
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallingOptions;
+import org.springframework.ai.ollama.management.PullModelStrategy;
import org.springframework.boot.context.properties.NestedConfigurationProperty;
import org.springframework.util.Assert;
@@ -304,8 +305,11 @@ public class OllamaOptions implements FunctionCallingOptions, ChatOptions, Embed
@JsonIgnore
private Map toolContext;
+ /**
+ * Strategy for pulling models at run-time.
+ */
@JsonIgnore
- private boolean pullMissingModel;
+ private PullModelStrategy pullModelStrategy = PullModelStrategy.NEVER;
public static OllamaOptions builder() {
return new OllamaOptions();
@@ -519,8 +523,8 @@ public OllamaOptions withToolContext(Map toolContext) {
return this;
}
- public OllamaOptions withPullMissingModel(boolean pullMissingModel) {
- this.pullMissingModel = pullMissingModel;
+ public OllamaOptions withPullModelStrategy(PullModelStrategy pullModelStrategy) {
+ this.pullModelStrategy = pullModelStrategy;
return this;
}
@@ -864,12 +868,12 @@ public void setToolContext(Map toolContext) {
this.toolContext = toolContext;
}
- public Boolean isPullMissingModel() {
- return this.pullMissingModel;
+ public PullModelStrategy getPullModelStrategy() {
+ return this.pullModelStrategy;
}
- public void setPullMissingModel(boolean pullMissingModel) {
- this.pullMissingModel = pullMissingModel;
+ public void setPullModelStrategy(PullModelStrategy pullModelStrategy) {
+ this.pullModelStrategy = pullModelStrategy;
}
/**
@@ -943,7 +947,7 @@ public static OllamaOptions fromOptions(OllamaOptions fromOptions) {
.withProxyToolCalls(fromOptions.getProxyToolCalls())
.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
.withToolContext(fromOptions.getToolContext())
- .withPullMissingModel(fromOptions.isPullMissingModel());
+ .withPullModelStrategy(fromOptions.getPullModelStrategy());
}
// @formatter:on
@@ -974,7 +978,7 @@ public boolean equals(Object o) {
&& Objects.equals(functionCallbacks, that.functionCallbacks)
&& Objects.equals(proxyToolCalls, that.proxyToolCalls) && Objects.equals(functions, that.functions)
&& Objects.equals(toolContext, that.toolContext)
- && Objects.equals(pullMissingModel, that.pullMissingModel);
+ && Objects.equals(pullModelStrategy, that.pullModelStrategy);
}
@Override
@@ -985,7 +989,7 @@ public int hashCode() {
this.topP, tfsZ, this.typicalP, this.repeatLastN, this.temperature, this.repeatPenalty,
this.presencePenalty, this.frequencyPenalty, this.mirostat, this.mirostatTau, this.mirostatEta,
this.penalizeNewline, this.stop, this.functionCallbacks, this.functions, this.proxyToolCalls,
- this.toolContext, this.pullMissingModel);
+ this.toolContext, this.pullModelStrategy);
}
}
diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/ModelManagementOptions.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/ModelManagementOptions.java
new file mode 100644
index 00000000000..92676d6e8ea
--- /dev/null
+++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/ModelManagementOptions.java
@@ -0,0 +1,30 @@
+/*
+ * Copyright 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.ollama.management;
+
+import java.time.Duration;
+
+/**
+ * Options for managing models in Ollama.
+ *
+ * @author Thomas Vitale
+ * @since 1.0.0
+ */
+public record ModelManagementOptions(PullModelStrategy pullModelStrategy, Duration timeout, Integer maxRetries) {
+ public static ModelManagementOptions defaults() {
+ return new ModelManagementOptions(PullModelStrategy.NEVER, Duration.ofMinutes(5), 0);
+ }
+}
diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/OllamaModelManager.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/OllamaModelManager.java
new file mode 100644
index 00000000000..01d444b94d5
--- /dev/null
+++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/OllamaModelManager.java
@@ -0,0 +1,108 @@
+/*
+* Copyright 2024 - 2024 the original author or authors.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* https://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+package org.springframework.ai.ollama.management;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.ai.ollama.api.OllamaApi;
+import org.springframework.ai.ollama.api.OllamaApi.DeleteModelRequest;
+import org.springframework.ai.ollama.api.OllamaApi.ListModelResponse;
+import org.springframework.ai.ollama.api.OllamaApi.PullModelRequest;
+import org.springframework.util.CollectionUtils;
+import reactor.util.retry.Retry;
+
+import java.time.Duration;
+
+/**
+ * Manage the lifecycle of models in Ollama.
+ *
+ * @author Christian Tzolov
+ * @author Thomas Vitale
+ * @since 1.0.0
+ */
+public class OllamaModelManager {
+
+ private final Logger logger = LoggerFactory.getLogger(OllamaModelManager.class);
+
+ private final OllamaApi ollamaApi;
+
+ private final ModelManagementOptions options;
+
+ public OllamaModelManager(OllamaApi ollamaApi) {
+ this(ollamaApi, ModelManagementOptions.defaults());
+ }
+
+ public OllamaModelManager(OllamaApi ollamaApi, ModelManagementOptions options) {
+ this.ollamaApi = ollamaApi;
+ this.options = options;
+ }
+
+ public boolean isModelAvailable(String modelName) {
+ ListModelResponse listModelResponse = ollamaApi.listModels();
+ if (!CollectionUtils.isEmpty(listModelResponse.models())) {
+ // Not an equality check to support the implicit ":latest" tag.
+ return listModelResponse.models().stream().anyMatch(m -> m.name().contains(modelName));
+ }
+ return false;
+ }
+
+ public void deleteModel(String modelName) {
+ logger.info("Start deletion of model: {}", modelName);
+ if (!isModelAvailable(modelName)) {
+ logger.info("Model {} not found", modelName);
+ return;
+ }
+ this.ollamaApi.deleteModel(new DeleteModelRequest(modelName));
+ logger.info("Completed deletion of model: {}", modelName);
+ }
+
+ public void pullModel(String modelName) {
+ pullModel(modelName, options.pullModelStrategy());
+ }
+
+ public void pullModel(String modelName, PullModelStrategy pullModelStrategy) {
+ if (PullModelStrategy.NEVER.equals(pullModelStrategy)) {
+ return;
+ }
+
+ if (PullModelStrategy.WHEN_MISSING.equals(pullModelStrategy)) {
+ if (isModelAvailable(modelName)) {
+ logger.debug("Model '{}' already available. Skipping pull operation.", modelName);
+ return;
+ }
+ }
+
+ // @formatter:off
+
+ logger.info("Start pulling model: {}", modelName);
+ this.ollamaApi.pullModel(new PullModelRequest(modelName))
+ .bufferUntilChanged(OllamaApi.ProgressResponse::status)
+ .doOnEach(signal -> {
+ var progressResponses = signal.get();
+ if (!CollectionUtils.isEmpty(progressResponses) && progressResponses.get(progressResponses.size() - 1) != null) {
+ logger.info("Pulling the '{}' model - Status: {}", modelName, progressResponses.get(progressResponses.size() - 1).status());
+ }
+ })
+ .takeUntil(progressResponses -> progressResponses.get(0) != null && progressResponses.get(0).status().equals("success"))
+ .timeout(options.timeout())
+ .retryWhen(Retry.backoff(options.maxRetries(), Duration.ofSeconds(5)))
+ .blockLast();
+ logger.info("Completed pulling the '{}' model", modelName);
+
+ // @formatter:on
+ }
+
+}
diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/PullModelStrategy.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/PullModelStrategy.java
new file mode 100644
index 00000000000..11be453aaba
--- /dev/null
+++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/PullModelStrategy.java
@@ -0,0 +1,43 @@
+/*
+ * Copyright 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.ollama.management;
+
+/**
+ * Strategy for pulling Ollama models.
+ *
+ * @author Thomas Vitale
+ * @since 1.0.0
+ */
+public enum PullModelStrategy {
+
+ /**
+ * Always pull the model, even if it's already available. Useful to ensure you're
+ * using the latest version of that model.
+ */
+ ALWAYS,
+
+ /**
+ * Only pull the model if it's not already available. It might be an older version of
+ * the model.
+ */
+ WHEN_MISSING,
+
+ /**
+ * Never pull the model.
+ */
+ NEVER;
+
+}
diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/package-info.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/package-info.java
new file mode 100644
index 00000000000..dc7eed369f4
--- /dev/null
+++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Copyright 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+@NonNullApi
+@NonNullFields
+package org.springframework.ai.ollama.management;
+
+import org.springframework.lang.NonNullApi;
+import org.springframework.lang.NonNullFields;
diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/BaseOllamaIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/BaseOllamaIT.java
index be5982d5fdc..f3d4ef07aa3 100644
--- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/BaseOllamaIT.java
+++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/BaseOllamaIT.java
@@ -1,14 +1,12 @@
package org.springframework.ai.ollama;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
import org.springframework.ai.ollama.api.OllamaApi;
+import org.springframework.ai.ollama.management.OllamaModelManager;
+import org.springframework.ai.ollama.management.PullModelStrategy;
import org.testcontainers.ollama.OllamaContainer;
public class BaseOllamaIT {
- private static final Logger logger = LoggerFactory.getLogger(BaseOllamaIT.class);
-
// Toggle for running tests locally on native Ollama for a faster feedback loop.
private static final boolean useTestcontainers = true;
@@ -46,9 +44,8 @@ public static OllamaApi buildOllamaApiWithModel(String model) {
}
public static void ensureModelIsPresent(OllamaApi ollamaApi, String model) {
- logger.info("Start pulling the '{}' model. The operation can take several minutes...", model);
- ollamaApi.pullModel(new OllamaApi.PullModelRequest(model));
- logger.info("Completed pulling the '{}' model", model);
+ var ollamaModelManager = new OllamaModelManager(ollamaApi);
+ ollamaModelManager.pullModel(model, PullModelStrategy.WHEN_MISSING);
}
}
diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java
index a3933b047d3..ae9d51fac8e 100644
--- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java
+++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java
@@ -52,7 +52,7 @@ class OllamaChatModelFunctionCallingIT extends BaseOllamaIT {
private static final Logger logger = LoggerFactory.getLogger(OllamaChatModelFunctionCallingIT.class);
- private static final String MODEL = OllamaModel.LLAMA3_1.getName();
+ private static final String MODEL = "qwen2.5:3b";
@Autowired
ChatModel chatModel;
@@ -60,7 +60,7 @@ class OllamaChatModelFunctionCallingIT extends BaseOllamaIT {
@Test
void functionCallTest() {
UserMessage userMessage = new UserMessage(
- "What's the weather like in San Francisco, Tokyo, and Paris? Return temperatures in Celsius.");
+ "What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.");
List messages = new ArrayList<>(List.of(userMessage));
@@ -68,7 +68,8 @@ void functionCallTest() {
.withModel(MODEL)
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
.withName("getCurrentWeather")
- .withDescription("Get the weather in location")
+ .withDescription(
+ "Find the weather conditions, forecasts, and temperatures for a location, like a city or state.")
.withResponseConverter((response) -> "" + response.temp() + response.unit())
.build()))
.build();
@@ -84,7 +85,7 @@ void functionCallTest() {
@Test
void streamFunctionCallTest() {
UserMessage userMessage = new UserMessage(
- "What's the weather like in San Francisco, Tokyo, and Paris? Return temperatures in Celsius.");
+ "What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.");
List messages = new ArrayList<>(List.of(userMessage));
@@ -92,7 +93,8 @@ void streamFunctionCallTest() {
.withModel(MODEL)
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
.withName("getCurrentWeather")
- .withDescription("Get the weather in location")
+ .withDescription(
+ "Find the weather conditions, forecasts, and temperatures for a location, like a city or state.")
.withResponseConverter((response) -> "" + response.temp() + response.unit())
.build()))
.build();
@@ -122,7 +124,10 @@ public OllamaApi ollamaApi() {
@Bean
public OllamaChatModel ollamaChat(OllamaApi ollamaApi) {
- return new OllamaChatModel(ollamaApi, OllamaOptions.create().withModel(MODEL).withTemperature(0.9));
+ return OllamaChatModel.builder()
+ .withOllamaApi(ollamaApi)
+ .withDefaultOptions(OllamaOptions.create().withModel(MODEL).withTemperature(0.9))
+ .build();
}
}
diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java
index a622dbcc17b..3b23f39b123 100644
--- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java
+++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java
@@ -33,8 +33,9 @@
import org.springframework.ai.converter.MapOutputConverter;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaModel;
-import org.springframework.ai.ollama.api.OllamaModelPuller;
+import org.springframework.ai.ollama.management.OllamaModelManager;
import org.springframework.ai.ollama.api.OllamaOptions;
+import org.springframework.ai.ollama.management.PullModelStrategy;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
@@ -63,19 +64,24 @@ class OllamaChatModelIT extends BaseOllamaIT {
@Test
void autoPullModelTest() {
- var puller = new OllamaModelPuller(ollamaApi);
- puller.deleteModel("tinyllama");
-
- assertThat(puller.isModelAvailable("tinyllama")).isFalse();
+ var modelManager = new OllamaModelManager(ollamaApi);
+ var model = "tinyllama";
+ modelManager.deleteModel(model);
+ assertThat(modelManager.isModelAvailable(model)).isFalse();
String joke = ChatClient.create(chatModel)
.prompt("Tell me a joke")
- .options(OllamaOptions.builder().withModel("tinyllama").withPullMissingModel(true).build())
+ .options(OllamaOptions.builder()
+ .withModel(model)
+ .withPullModelStrategy(PullModelStrategy.WHEN_MISSING)
+ .build())
.call()
.content();
assertThat(joke).isNotEmpty();
- assertThat(puller.isModelAvailable("tinyllamaf")).isFalse();
+ assertThat(modelManager.isModelAvailable(model)).isTrue();
+
+ modelManager.deleteModel(model);
}
@Test
@@ -240,7 +246,10 @@ public OllamaApi ollamaApi() {
@Bean
public OllamaChatModel ollamaChat(OllamaApi ollamaApi) {
- return new OllamaChatModel(ollamaApi, OllamaOptions.create().withModel(MODEL).withTemperature(0.9));
+ return OllamaChatModel.builder()
+ .withOllamaApi(ollamaApi)
+ .withDefaultOptions(OllamaOptions.create().withModel(MODEL).withTemperature(0.9))
+ .build();
}
}
diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java
index 8536d61a716..5d8956552c4 100644
--- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java
+++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java
@@ -45,7 +45,7 @@ class OllamaChatModelMultimodalIT extends BaseOllamaIT {
private static final Logger logger = LoggerFactory.getLogger(OllamaChatModelMultimodalIT.class);
- private static final String MODEL = OllamaModel.MOONDREAM.getName();
+ private static final String MODEL = "llava-phi3";
@Autowired
private OllamaChatModel chatModel;
@@ -54,7 +54,7 @@ class OllamaChatModelMultimodalIT extends BaseOllamaIT {
void unsupportedMediaType() {
var imageData = new ClassPathResource("/norway.webp");
- var userMessage = new UserMessage("Explain what do you see on this picture?",
+ var userMessage = new UserMessage("Explain what do you see in this picture?",
List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData)));
assertThrows(RuntimeException.class, () -> chatModel.call(new Prompt(List.of(userMessage))));
@@ -64,7 +64,7 @@ void unsupportedMediaType() {
void multiModalityTest() {
var imageData = new ClassPathResource("/test.png");
- var userMessage = new UserMessage("Explain what do you see on this picture?",
+ var userMessage = new UserMessage("Explain what do you see in this picture?",
List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData)));
var response = chatModel.call(new Prompt(List.of(userMessage)));
@@ -83,7 +83,10 @@ public OllamaApi ollamaApi() {
@Bean
public OllamaChatModel ollamaChat(OllamaApi ollamaApi) {
- return new OllamaChatModel(ollamaApi, OllamaOptions.create().withModel(MODEL).withTemperature(0.9));
+ return OllamaChatModel.builder()
+ .withOllamaApi(ollamaApi)
+ .withDefaultOptions(OllamaOptions.create().withModel(MODEL).withTemperature(0.9))
+ .build();
}
}
diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java
index 327418a4afe..4e9f803132c 100644
--- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java
+++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java
@@ -24,7 +24,6 @@
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.Prompt;
-import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.observation.conventions.AiOperationType;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.ai.ollama.api.OllamaApi;
@@ -169,9 +168,11 @@ public OllamaApi openAiApi() {
}
@Bean
- public OllamaChatModel openAiChatModel(OllamaApi openAiApi, TestObservationRegistry observationRegistry) {
- return new OllamaChatModel(openAiApi, OllamaOptions.create(), new FunctionCallbackContext(), List.of(),
- observationRegistry);
+ public OllamaChatModel openAiChatModel(OllamaApi ollamaApi, TestObservationRegistry observationRegistry) {
+ return OllamaChatModel.builder()
+ .withOllamaApi(ollamaApi)
+ .withObservationRegistry(observationRegistry)
+ .build();
}
}
diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java
index b9b7bc2e6ef..cdb829953be 100644
--- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java
+++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java
@@ -26,11 +26,15 @@
/**
* @author Christian Tzolov
+ * @author Thomas Vitale
*/
public class OllamaChatRequestTests {
- OllamaChatModel chatModel = new OllamaChatModel(new OllamaApi(),
- new OllamaOptions().withModel("MODEL_NAME").withTopK(99).withTemperature(66.6).withNumGPU(1));
+ OllamaChatModel chatModel = OllamaChatModel.builder()
+ .withOllamaApi(new OllamaApi())
+ .withDefaultOptions(
+ OllamaOptions.create().withModel("MODEL_NAME").withTopK(99).withTemperature(66.6).withNumGPU(1))
+ .build();
@Test
public void createRequestWithDefaultOptions() {
@@ -104,8 +108,10 @@ public void createRequestWithPromptOptionsModelOverride() {
@Test
public void createRequestWithDefaultOptionsModelOverride() {
- OllamaChatModel chatModel = new OllamaChatModel(new OllamaApi(),
- new OllamaOptions().withModel("DEFAULT_OPTIONS_MODEL"));
+ OllamaChatModel chatModel = OllamaChatModel.builder()
+ .withOllamaApi(new OllamaApi())
+ .withDefaultOptions(OllamaOptions.create().withModel("DEFAULT_OPTIONS_MODEL"))
+ .build();
var request = chatModel.ollamaChatRequest(new Prompt("Test message content"), true);
diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelIT.java
index 3813017d177..f7ce804e0e9 100644
--- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelIT.java
+++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelIT.java
@@ -20,10 +20,10 @@
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.ollama.api.OllamaApi;
-import org.springframework.ai.ollama.api.OllamaApi.DeleteModelRequest;
import org.springframework.ai.ollama.api.OllamaModel;
-import org.springframework.ai.ollama.api.OllamaModelPuller;
+import org.springframework.ai.ollama.management.OllamaModelManager;
import org.springframework.ai.ollama.api.OllamaOptions;
+import org.springframework.ai.ollama.management.PullModelStrategy;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
@@ -66,33 +66,35 @@ void embeddings() {
@Test
void autoPullModel() {
+ var model = "all-minilm";
assertThat(embeddingModel).isNotNull();
- var puller = new OllamaModelPuller(ollamaApi);
- puller.deleteModel("all-minilm:latest");
-
- assertThat(puller.isModelAvailable("all-minilm")).isFalse();
+ var modelManager = new OllamaModelManager(ollamaApi);
+ modelManager.deleteModel(model);
+ assertThat(modelManager.isModelAvailable(model)).isFalse();
EmbeddingResponse embeddingResponse = embeddingModel
.call(new EmbeddingRequest(List.of("Hello World", "Something else"),
OllamaOptions.builder()
- .withModel("all-minilm:latest")
- .withPullMissingModel(true)
+ .withModel(model)
+ .withPullModelStrategy(PullModelStrategy.WHEN_MISSING)
.withTruncate(false)
.build()));
- assertThat(puller.isModelAvailable("all-minilm:latest")).isTrue();
+ assertThat(modelManager.isModelAvailable(model)).isTrue();
assertThat(embeddingResponse.getResults()).hasSize(2);
assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0);
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1);
assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty();
- assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("all-minilm:latest");
+ assertThat(embeddingResponse.getMetadata().getModel()).contains(model);
assertThat(embeddingResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(4);
assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(4);
assertThat(embeddingModel.dimensions()).isEqualTo(768);
+
+ modelManager.deleteModel(model);
}
@SpringBootConfiguration
@@ -105,7 +107,10 @@ public OllamaApi ollamaApi() {
@Bean
public OllamaEmbeddingModel ollamaEmbedding(OllamaApi ollamaApi) {
- return new OllamaEmbeddingModel(ollamaApi, OllamaOptions.create().withModel(MODEL));
+ return OllamaEmbeddingModel.builder()
+ .withOllamaApi(ollamaApi)
+ .withDefaultOptions(OllamaOptions.create().withModel(MODEL))
+ .build();
}
}
diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelObservationIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelObservationIT.java
index 6f43b2a17bb..aaf786ff24e 100644
--- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelObservationIT.java
+++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelObservationIT.java
@@ -103,9 +103,12 @@ public OllamaApi openAiApi() {
}
@Bean
- public OllamaEmbeddingModel openAiEmbeddingModel(OllamaApi openAiApi,
+ public OllamaEmbeddingModel openAiEmbeddingModel(OllamaApi ollamaApi,
TestObservationRegistry observationRegistry) {
- return new OllamaEmbeddingModel(openAiApi, OllamaOptions.builder().build(), observationRegistry);
+ return OllamaEmbeddingModel.builder()
+ .withOllamaApi(ollamaApi)
+ .withObservationRegistry(observationRegistry)
+ .build();
}
}
diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelTests.java
index 6b6569b8f08..b77be0a6049 100644
--- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelTests.java
+++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelTests.java
@@ -39,6 +39,7 @@
/**
* @author Christian Tzolov
+ * @author Thomas Vitale
* @since 1.0.0
*/
@ExtendWith(MockitoExtension.class)
@@ -62,7 +63,10 @@ public void options() {
// Tests default options
var defaultOptions = OllamaOptions.builder().withModel("DEFAULT_MODEL").build();
- var embeddingModel = new OllamaEmbeddingModel(ollamaApi, defaultOptions);
+ var embeddingModel = OllamaEmbeddingModel.builder()
+ .withOllamaApi(ollamaApi)
+ .withDefaultOptions(defaultOptions)
+ .build();
EmbeddingResponse response = embeddingModel.call(
new EmbeddingRequest(List.of("Input1", "Input2", "Input3"), EmbeddingOptionsBuilder.builder().build()));
diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingRequestTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingRequestTests.java
index cff91717b19..dfa8c9f22d4 100644
--- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingRequestTests.java
+++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingRequestTests.java
@@ -29,8 +29,11 @@
*/
public class OllamaEmbeddingRequestTests {
- OllamaEmbeddingModel embeddingModel = new OllamaEmbeddingModel(new OllamaApi(),
- new OllamaOptions().withModel("DEFAULT_MODEL").withMainGPU(11).withUseMMap(true).withNumGPU(1));
+ OllamaEmbeddingModel embeddingModel = OllamaEmbeddingModel.builder()
+ .withOllamaApi(new OllamaApi())
+ .withDefaultOptions(
+ OllamaOptions.create().withModel("DEFAULT_MODEL").withMainGPU(11).withUseMMap(true).withNumGPU(1))
+ .build();
@Test
public void ollamaEmbeddingRequestDefaultOptions() {
diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiModelsIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiModelsIT.java
index 91752bdb6dc..c6a7c67e1a7 100644
--- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiModelsIT.java
+++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiModelsIT.java
@@ -23,6 +23,7 @@
import org.testcontainers.junit.jupiter.Testcontainers;
import java.io.IOException;
+import java.time.Duration;
import static org.assertj.core.api.Assertions.assertThat;
@@ -84,8 +85,14 @@ public void pullModel() {
assertThat(listModelResponse.models().stream().anyMatch(model -> model.name().contains(MODEL))).isFalse();
var pullModelRequest = new OllamaApi.PullModelRequest(MODEL);
- var progressResponse = ollamaApi.pullModel(pullModelRequest);
- assertThat(progressResponse.status()).contains("success");
+ var progressResponses = ollamaApi.pullModel(pullModelRequest)
+ .timeout(Duration.ofMinutes(5))
+ .collectList()
+ .block();
+
+ assertThat(progressResponses).isNotNull();
+ assertThat(progressResponses.get(progressResponses.size() - 1))
+ .isEqualTo(new OllamaApi.ProgressResponse("success", null, null, null));
listModelResponse = ollamaApi.listModels();
assertThat(listModelResponse.models().stream().anyMatch(model -> model.name().contains(MODEL))).isTrue();
diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/tool/OllamaApiToolFunctionCallIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/tool/OllamaApiToolFunctionCallIT.java
index db374b5feea..dab9bd1799b 100644
--- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/tool/OllamaApiToolFunctionCallIT.java
+++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/tool/OllamaApiToolFunctionCallIT.java
@@ -46,7 +46,7 @@
@DisabledIf("isDisabled")
public class OllamaApiToolFunctionCallIT extends BaseOllamaIT {
- private static final String MODEL = OllamaModel.LLAMA3_1.getName();
+ private static final String MODEL = "qwen2.5:3b";
private static final Logger logger = LoggerFactory.getLogger(OllamaApiToolFunctionCallIT.class);
@@ -64,11 +64,13 @@ public static void beforeAll() throws IOException, InterruptedException {
public void toolFunctionCall() {
// Step 1: send the conversation and available functions to the model
var message = Message.builder(Role.USER)
- .withContent("What's the weather like in San Francisco, Tokyo, and Paris? Return temperature in Celsius.")
+ .withContent(
+ "What's the weather like in San Francisco, Tokyo, and Paris? Return a list with the temperature in Celsius for each of the three locations.")
.build();
var functionTool = new OllamaApi.ChatRequest.Tool(new OllamaApi.ChatRequest.Tool.Function("getCurrentWeather",
- "Get the weather in location like city names.", ModelOptionsUtils.jsonToMap("""
+ "Find the current weather conditions, forecasts, and temperatures for a location, like a city or state.",
+ ModelOptionsUtils.jsonToMap("""
{
"type": "object",
"properties": {
diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/management/OllamaModelManagerIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/management/OllamaModelManagerIT.java
new file mode 100644
index 00000000000..dac65820f82
--- /dev/null
+++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/management/OllamaModelManagerIT.java
@@ -0,0 +1,84 @@
+/*
+ * Copyright 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.ollama.management;
+
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.DisabledIf;
+import org.springframework.ai.ollama.BaseOllamaIT;
+import org.springframework.ai.ollama.api.OllamaModel;
+import org.testcontainers.junit.jupiter.Testcontainers;
+
+import java.io.IOException;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Integration tests for {@link OllamaModelManager}.
+ *
+ * @author Thomas Vitale
+ */
+@Testcontainers
+@DisabledIf("isDisabled")
+class OllamaModelManagerIT extends BaseOllamaIT {
+
+ private static final String MODEL = OllamaModel.NOMIC_EMBED_TEXT.getName();
+
+ static OllamaModelManager modelManager;
+
+ @BeforeAll
+ public static void beforeAll() throws IOException, InterruptedException {
+ var ollamaApi = buildOllamaApiWithModel(MODEL);
+ modelManager = new OllamaModelManager(ollamaApi);
+ }
+
+ @Test
+ public void whenModelAvailableReturnTrue() {
+ var isModelAvailable = modelManager.isModelAvailable(MODEL);
+ assertThat(isModelAvailable).isTrue();
+
+ isModelAvailable = modelManager.isModelAvailable(MODEL + ":latest");
+ assertThat(isModelAvailable).isTrue();
+ }
+
+ @Test
+ public void whenModelNotAvailableReturnFalse() {
+ var isModelAvailable = modelManager.isModelAvailable("aleph");
+ assertThat(isModelAvailable).isFalse();
+ }
+
+ @Test
+ public void pullAndDeleteModel() {
+ var model = "all-minilm";
+ modelManager.pullModel(model, PullModelStrategy.WHEN_MISSING);
+ var isModelAvailable = modelManager.isModelAvailable(model);
+ assertThat(isModelAvailable).isTrue();
+
+ modelManager.deleteModel(model);
+ isModelAvailable = modelManager.isModelAvailable(model);
+ assertThat(isModelAvailable).isFalse();
+
+ model = "all-minilm:latest";
+ modelManager.pullModel(model, PullModelStrategy.WHEN_MISSING);
+ isModelAvailable = modelManager.isModelAvailable(model);
+ assertThat(isModelAvailable).isTrue();
+
+ modelManager.deleteModel(model);
+ isModelAvailable = modelManager.isModelAvailable(model);
+ assertThat(isModelAvailable).isFalse();
+ }
+
+}
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc
index df4e088ad0a..80d6ac15bce 100644
--- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc
+++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc
@@ -8,12 +8,27 @@ The xref:_openai_api_compatibility[OpenAI API compatibility] section explains ho
== Prerequisites
-You first need to xref:https://ollama.com/download[Download and install Ollama] on your local machine.
+You first need access to an Ollama instance. There are a few options, including the following:
-Also you can pull the models you want to use from the xref:https://ollama.com/library[Ollama model repository]: `ollama pull `.
-Alternatively, you can enable the `pullMissingModel` option to automatically download missing models: xref:auto-pulling-models[Auto-pulling Models].
+* xref:https://ollama.com/download[Download and install Ollama] on your local machine.
+* Configure and xref:api/testcontainers.adoc[run Ollama via Testcontainers].
+* Bind to an Ollama instance via xref:api/cloud-bindings.adoc[Kubernetes Service Bindings].
-TIP: you can also pull, by name, any of the thousands, free, xref:https://huggingface.co/models?library=gguf&sort=trending[GGUF HuggingFace Models]
+You can pull the models you want to use in your application from the xref:https://ollama.com/library[Ollama model library]:
+
+[source,shellscript]
+----
+ollama pull
+----
+
+You can also pull any of the thousands, free, xref:https://huggingface.co/models?library=gguf&sort=trending[GGUF Hugging Face Models]:
+
+[source,shellscript]
+----
+ollama pull hf.co//
+----
+
+Alternatively, you can enable the option to download automatically any needed model: xref:auto-pulling-models[Auto-pulling Models].
== Auto-configuration
@@ -39,18 +54,29 @@ dependencies {
TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file.
-=== Chat Properties
+=== Base Properties
The prefix `spring.ai.ollama` is the property prefix to configure the connection to Ollama.
[cols="3,6,1"]
|====
| Property | Description | Default
-
| spring.ai.ollama.base-url | Base URL where Ollama API server is running. | `http://localhost:11434`
|====
-The prefix `spring.ai.ollama.chat.options` is the property prefix that configures the Ollama chat model .
+Here are the properties for initializing the Ollama integration and xref:auto-pulling-models[auto-pulling models].
+
+[cols="3,6,1"]
+|====
+| Property | Description | Default
+| spring.ai.ollama.init.pull-model-strategy | Whether to pull models at startup-time and how. | `never`
+| spring.ai.ollama.init.timeout | How long to wait for a model to be pulled. | `5m`
+| spring.ai.ollama.init.max-retries | Maximum number of retries for the model pull operation. | `0`
+|====
+
+=== Chat Properties
+
+The prefix `spring.ai.ollama.chat.options` is the property prefix that configures the Ollama chat model.
It includes the Ollama request (advanced) parameters such as the `model`, `keep-alive`, and `format` as well as the Ollama model `options` properties.
Here are the advanced request parameter for the Ollama chat model:
@@ -58,12 +84,11 @@ Here are the advanced request parameter for the Ollama chat model:
[cols="3,6,1"]
|====
| Property | Description | Default
-
| spring.ai.ollama.chat.enabled | Enable Ollama chat model. | true
| spring.ai.ollama.chat.options.model | The name of the https://github.com/ollama/ollama?tab=readme-ov-file#model-library[supported model] to use. | mistral
-| spring.ai.ollama.chat.options.pull-missing-model | Automatically pull missing models from Ollama repository | false
| spring.ai.ollama.chat.options.format | The format to return a response in. Currently, the only accepted value is `json` | -
| spring.ai.ollama.chat.options.keep_alive | Controls how long the model will stay loaded into memory following the request | 5m
+| spring.ai.ollama.chat.options.pull-model-strategy | Strategy for pulling models at run-time. | `never`
|====
The remaining `options` properties are based on the link:https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values[Ollama Valid Parameters and Values] and link:https://github.com/ollama/ollama/blob/main/api/types.go[Ollama Types]. The default values are based on the link:https://github.com/ollama/ollama/blob/b538dc3858014f94b099730a592751a5454cab0a/api/types.go#L364[Ollama Types Defaults].
@@ -101,7 +126,7 @@ The remaining `options` properties are based on the link:https://github.com/olla
| spring.ai.ollama.chat.options.penalize-newline | - | true
| spring.ai.ollama.chat.options.stop | Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return. Multiple stop patterns may be set by specifying multiple separate stop parameters in a modelfile. | -
| spring.ai.ollama.chat.options.functions | List of functions, identified by their names, to enable for function calling in a single prompt requests. Functions with those names must exist in the functionCallbacks registry. | -
-| spring.ai.ollama.chat.options.proxy-tool-calls | If true, the Spring AI will not handle the function calls internally, but will proxy them to the client. Then is the client's responsibility to handle the function calls, dispatch them to the appropriate function, and return the results. If false (the default), the Spring AI will handle the function calls internally. Applicable only for chat models with function calling support | false
+| spring.ai.ollama.chat.options.proxy-tool-calls | If true, the Spring AI will not handle the function calls internally, but will proxy them to the client. Then is the client's responsibility to handle the function calls, dispatch them to the appropriate function, and return the results. If false (the default), the Spring AI will handle the function calls internally. Applicable only for chat models with function calling support | false
|====
TIP: All properties prefixed with `spring.ai.ollama.chat.options` can be overridden at runtime by adding request-specific <> to the `Prompt` call.
@@ -130,28 +155,57 @@ ChatResponse response = chatModel.call(
TIP: In addition to the model specific link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java[OllamaOptions] you can use a portable https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java[ChatOptions] instance, created with https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptionsBuilder.java[ChatOptionsBuilder#builder()].
[[auto-pulling-models]]
-=== Auto-pulling Models
+== Auto-pulling Models
-The `pullMissingModel` option allows you to automatically download and use models that are not currently available on your local Ollama instance.
+Spring AI Ollama can automatically pull models when not available in your Ollama instance.
This feature is particularly useful when working with different models or when deploying your application to new environments.
-To enable auto-pulling of missing models, you can set the `pullMissingModel` option to `true` in your `OllamaOptions`:
+TIP: you can also pull, by name, any of the thousands, free, xref:https://huggingface.co/models?library=gguf&sort=trending[GGUF Hugging Face Models].
-[source,java]
+There are three strategies for pulling models:
+
+* `always` (defined in `PullModelStrategy.ALWAYS`). Always pull the model, even if it's already available. Useful to ensure you're using the latest version of that model.
+* `when_missing` (defined in `PullModelStrategy.WHEN_MISSING`). Only pull the model if it's not already available. It might be an older version of the model.
+* `never` (defined in `PullModelStrategy.NEVER`). Never pull the model.
+
+CAUTION: Due to the unexpected delays while downloading models, this feature is not recommended for production environments. Instead, consider to assess and pre-download the necessary models in advance.
+
+=== Pulling models at startup time
+
+All models defined via configuration properties and default options can be automatically pulled at startup time.
+You can configure strategy, timeout, and max number of retries via configuration properties.
+
+[source,yaml]
----
-OllamaOptions options = OllamaOptions.builder()
- .withModel("all-minilm:latest")
- .withPullMissingModel(true)
- .build();
+spring:
+ ai:
+ ollama:
+ init:
+ pull-model-strategy: always
+ timeout: 60s
+ max-retries: 1
----
-TIP: you can also pull, by name, any of the thousands, free, xref:https://huggingface.co/models?library=gguf&sort=trending[GGUF HuggingFace Models]
+CAUTION: The application will not complete its initialization until all the models become available in Ollama. Depending on the model size and the speed of the Internet connection, your application might be slow at starting up.
+
+=== Pulling models at runtime
-You can also configure this option using the following property: `spring.ai.ollama.chat.options.pull-missing-model=true`
+To enable auto-pulling of models at runtime, you can configure the `pullModelStrategy` option in your `OllamaOptions`:
+
+[source,java]
+----
+ChatResponse response = chatModel.call(new Prompt(
+ "Generate the names of 5 famous pirates.",
+ OllamaOptions.builder()
+ .withModel("llama3.2")
+ .withPullModelStrategy(PullModelStrategy.ALWAYS)
+ .build()
+ ));
+----
-When `pullMissingModel` is set to `true`, the system will attempt to download the specified model if it's not already available locally. This process may take some time depending on the size of the model and your internet connection speed.
+You can also configure this option using the following property: `spring.ai.ollama.chat.options.pull-model-strategy=always`.
-CAUTION: Be aware that enabling this option may lead to unexpected delays in your application if it needs to download large model files. It's recommended to pre-download commonly used models in production environments.
+CAUTION: The time to process an incoming request might incur unexpected delays, waiting for the needed model to become available in Ollama. Depending on the model size and the speed of the Internet connection, your application might be slow at processing requests.
== Function Calling
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/ollama-embeddings.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/ollama-embeddings.adoc
index 15182bedd77..393147ea521 100644
--- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/ollama-embeddings.adoc
+++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/ollama-embeddings.adoc
@@ -11,10 +11,27 @@ TIP: you can also pull, by name, any of the thousands, free, xref:https://huggin
== Prerequisites
-You first need to xref:https://ollama.com/download[Download and install Ollama] on your local machine.
+You first need access to an Ollama instance. There are a few options, including the following:
-Also you can pull the models you want to use from the https://ollama.com/search?c=embedding[Ollama Embedding Models]: `ollama pull `.
-Alternatively, you can enable the `pullMissingModel` option to automatically download missing models: xref:auto-pulling-models[Auto-pulling Models].
+* xref:https://ollama.com/download[Download and install Ollama] on your local machine.
+* Configure and xref:api/testcontainers.adoc[run Ollama via Testcontainers].
+* Bind to an Ollama instance via xref:api/cloud-bindings.adoc[Kubernetes Service Bindings].
+
+You can pull the models you want to use in your application from the https://ollama.com/search?c=embedding[Ollama model library]:
+
+[source,shellscript]
+----
+ollama pull
+----
+
+You can also pull any of the thousands, free, xref:https://huggingface.co/models?library=gguf&sort=trending[GGUF Hugging Face Models]:
+
+[source,shellscript]
+----
+ollama pull hf.co//
+----
+
+Alternatively, you can enable the option to download automatically any needed model: xref:auto-pulling-models[Auto-pulling Models].
== Auto-configuration
@@ -41,18 +58,29 @@ dependencies {
TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file.
Spring AI artifacts are published in Spring Milestone and Snapshot repositories. Refer to the Repositories section to add these repositories to your build system.
-=== Embedding Properties
+=== Base Properties
The prefix `spring.ai.ollama` is the property prefix to configure the connection to Ollama
[cols="3,6,1"]
|====
| Property | Description | Default
-
| spring.ai.ollama.base-url | Base URL where Ollama API server is running. | `http://localhost:11434`
|====
-The prefix `spring.ai.ollama.embedding.options` is the property prefix that configures the Ollama embedding model .
+Here are the properties for initializing the Ollama integration and xref:auto-pulling-models[auto-pulling models].
+
+[cols="3,6,1"]
+|====
+| Property | Description | Default
+| spring.ai.ollama.init.pull-model-strategy | Whether to pull models at startup-time and how. | `never`
+| spring.ai.ollama.init.timeout | How long to wait for a model to be pulled. | `5m`
+| spring.ai.ollama.init.max-retries | Maximum number of retries for the model pull operation. | `0`
+|====
+
+=== Embedding Properties
+
+The prefix `spring.ai.ollama.embedding.options` is the property prefix that configures the Ollama embedding model.
It includes the Ollama request (advanced) parameters such as the `model`, `keep-alive`, and `truncate` as well as the Ollama model `options` properties.
Here are the advanced request parameter for the Ollama embedding model:
@@ -63,9 +91,9 @@ Here are the advanced request parameter for the Ollama embedding model:
| spring.ai.ollama.embedding.enabled | Enables the Ollama embedding model auto-configuration. | true
| spring.ai.ollama.embedding.options.model | The name of the https://github.com/ollama/ollama?tab=readme-ov-file#model-library[supported model] to use.
You can use dedicated https://ollama.com/search?c=embedding[Embedding Model] types | mistral
-| spring.ai.ollama.embedding.options.pull-missing-model | Automatically pull missing models from Ollama repository | false
| spring.ai.ollama.embedding.options.keep_alive | Controls how long the model will stay loaded into memory following the request | 5m
| spring.ai.ollama.embedding.options.truncate | Truncates the end of each input to fit within context length. Returns error if false and context length is exceeded. | true
+| spring.ai.ollama.embedding.options.pull-model-strategy | Strategy for pulling models at run-time. | `never`
|====
The remaining `options` properties are based on the link:https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values[Ollama Valid Parameters and Values] and link:https://github.com/ollama/ollama/blob/main/api/types.go[Ollama Types]. The default values are based on: link:https://github.com/ollama/ollama/blob/b538dc3858014f94b099730a592751a5454cab0a/api/types.go#L364[Ollama type defaults].
@@ -129,31 +157,56 @@ EmbeddingResponse embeddingResponse = embeddingModel.call(
----
[[auto-pulling-models]]
-=== Auto-pulling Models
+== Auto-pulling Models
-The `pullMissingModel` option allows you to automatically download and use models that are not currently available on your local Ollama instance.
+Spring AI Ollama can automatically pull models when not available in your Ollama instance.
This feature is particularly useful when working with different models or when deploying your application to new environments.
-To enable auto-pulling of missing models, you can set the `pullMissingModel` option to `true` in your `OllamaOptions`:
+TIP: you can also pull, by name, any of the thousands, free, xref:https://huggingface.co/models?library=gguf&sort=trending[GGUF Hugging Face Models].
+
+There are three strategies for pulling models:
+
+* `always` (defined in `PullModelStrategy.ALWAYS`). Always pull the model, even if it's already available. Useful to ensure you're using the latest version of that model.
+* `when_missing` (defined in `PullModelStrategy.WHEN_MISSING`). Only pull the model if it's not already available. It might be an older version of the model.
+* `never` (defined in `PullModelStrategy.NEVER`). Never pull the model.
+
+CAUTION: Due to the unexpected delays while downloading models, this feature is not recommended for production environments. Instead, consider to assess and pre-download the necessary models in advance.
+
+=== Pulling models at startup time
+
+All models defined via configuration properties and default options can be automatically pulled at startup time.
+You can configure strategy, timeout, and max number of retries via configuration properties.
+
+[source,yaml]
+----
+spring:
+ ai:
+ ollama:
+ init:
+ pull-model-strategy: always
+ timeout: 60s
+ max-retries: 1
+----
+
+CAUTION: The application will not complete its initialization until all the models become available in Ollama. Depending on the model size and the speed of the Internet connection, your application might be slow at starting up.
+
+=== Pulling models at runtime
+
+To enable auto-pulling of models at runtime, you can configure the `pullModelStrategy` option in your `OllamaOptions`:
[source,java]
----
EmbeddingResponse embeddingResponse = embeddingModel
.call(new EmbeddingRequest(List.of("Hello World", "Something else"),
OllamaOptions.builder()
- .withModel("all-minilm:latest")
- .withPullMissingModel(true)
- .withTruncate(false)
+ .withModel("all-minilm")
+ .withPullModelStrategy(PullModelStrategy.ALWAYS)
.build()));
----
-TIP: you can also pull, by name, any of the thousands, free, xref:https://huggingface.co/models?library=gguf&sort=trending[GGUF HuggingFace Models]
-
-You can also configure this option using the following property: `spring.ai.ollama.embedding.options.pull-missing-model=true`
-
-When `pullMissingModel` is set to `true`, the system will attempt to download the specified model if it's not already available locally. This process may take some time depending on the size of the model and your internet connection speed.
+You can also configure this option using the following property: `spring.ai.ollama.embedding.options.pull-model-strategy=always`.
-CAUTION: Be aware that enabling this option may lead to unexpected delays in your application if it needs to download large model files. It's recommended to pre-download commonly used models in production environments.
+CAUTION: The time to process an incoming request might incur unexpected delays, waiting for the needed model to become available in Ollama. Depending on the model size and the speed of the Internet connection, your application might be slow at processing requests.
== Sample Controller
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java
index 0e6bccb4816..46f6a292ddd 100644
--- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java
@@ -24,6 +24,7 @@
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.ollama.OllamaEmbeddingModel;
import org.springframework.ai.ollama.api.OllamaApi;
+import org.springframework.ai.ollama.management.ModelManagementOptions;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.ImportAutoConfiguration;
@@ -51,7 +52,7 @@
@AutoConfiguration(after = RestClientAutoConfiguration.class)
@ConditionalOnClass(OllamaApi.class)
@EnableConfigurationProperties({ OllamaChatProperties.class, OllamaEmbeddingProperties.class,
- OllamaConnectionProperties.class })
+ OllamaConnectionProperties.class, OllamaInitializationProperties.class })
@ImportAutoConfiguration(classes = { RestClientAutoConfiguration.class, WebClientAutoConfiguration.class })
public class OllamaAutoConfiguration {
@@ -76,11 +77,18 @@ public OllamaApi ollamaApi(OllamaConnectionDetails connectionDetails,
@ConditionalOnProperty(prefix = OllamaChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true",
matchIfMissing = true)
public OllamaChatModel ollamaChatModel(OllamaApi ollamaApi, OllamaChatProperties properties,
- List toolFunctionCallbacks, FunctionCallbackContext functionCallbackContext,
- ObjectProvider observationRegistry,
+ OllamaInitializationProperties initProperties, List toolFunctionCallbacks,
+ FunctionCallbackContext functionCallbackContext, ObjectProvider observationRegistry,
ObjectProvider observationConvention) {
- var chatModel = new OllamaChatModel(ollamaApi, properties.getOptions(), functionCallbackContext,
- toolFunctionCallbacks, observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP));
+ var chatModel = OllamaChatModel.builder()
+ .withOllamaApi(ollamaApi)
+ .withDefaultOptions(properties.getOptions())
+ .withFunctionCallbackContext(functionCallbackContext)
+ .withToolFunctionCallbacks(toolFunctionCallbacks)
+ .withObservationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
+ .withModelManagementOptions(new ModelManagementOptions(initProperties.getPullModelStrategy(),
+ initProperties.getTimeout(), initProperties.getMaxRetries()))
+ .build();
observationConvention.ifAvailable(chatModel::setObservationConvention);
@@ -92,10 +100,15 @@ public OllamaChatModel ollamaChatModel(OllamaApi ollamaApi, OllamaChatProperties
@ConditionalOnProperty(prefix = OllamaEmbeddingProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true",
matchIfMissing = true)
public OllamaEmbeddingModel ollamaEmbeddingModel(OllamaApi ollamaApi, OllamaEmbeddingProperties properties,
- ObjectProvider observationRegistry,
+ OllamaInitializationProperties initProperties, ObjectProvider observationRegistry,
ObjectProvider observationConvention) {
- var embeddingModel = new OllamaEmbeddingModel(ollamaApi, properties.getOptions(),
- observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP));
+ var embeddingModel = OllamaEmbeddingModel.builder()
+ .withOllamaApi(ollamaApi)
+ .withDefaultOptions(properties.getOptions())
+ .withObservationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
+ .withModelManagementOptions(new ModelManagementOptions(initProperties.getPullModelStrategy(),
+ initProperties.getTimeout(), initProperties.getMaxRetries()))
+ .build();
observationConvention.ifAvailable(embeddingModel::setObservationConvention);
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaInitializationProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaInitializationProperties.java
new file mode 100644
index 00000000000..572b1b442a2
--- /dev/null
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaInitializationProperties.java
@@ -0,0 +1,73 @@
+/*
+ * Copyright 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.autoconfigure.ollama;
+
+import org.springframework.ai.ollama.management.PullModelStrategy;
+import org.springframework.boot.context.properties.ConfigurationProperties;
+
+import java.time.Duration;
+
+/**
+ * Ollama initialization configuration properties.
+ *
+ * @author Thomas Vitale
+ * @since 1.0.0
+ */
+@ConfigurationProperties(OllamaInitializationProperties.CONFIG_PREFIX)
+public class OllamaInitializationProperties {
+
+ public static final String CONFIG_PREFIX = "spring.ai.ollama.init";
+
+ /**
+ * Whether to pull models at startup-time and how.
+ */
+ private PullModelStrategy pullModelStrategy = PullModelStrategy.NEVER;
+
+ /**
+ * How long to wait for a model to be pulled.
+ */
+ private Duration timeout = Duration.ofMinutes(5);
+
+ /**
+ * Maximum number of retries for the model pull operation.
+ */
+ private int maxRetries = 0;
+
+ public PullModelStrategy getPullModelStrategy() {
+ return pullModelStrategy;
+ }
+
+ public void setPullModelStrategy(PullModelStrategy pullModelStrategy) {
+ this.pullModelStrategy = pullModelStrategy;
+ }
+
+ public Duration getTimeout() {
+ return timeout;
+ }
+
+ public void setTimeout(Duration timeout) {
+ this.timeout = timeout;
+ }
+
+ public int getMaxRetries() {
+ return maxRetries;
+ }
+
+ public void setMaxRetries(int maxRetries) {
+ this.maxRetries = maxRetries;
+ }
+
+}
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/BaseOllamaIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/BaseOllamaIT.java
index b216badcf84..f4323403958 100644
--- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/BaseOllamaIT.java
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/BaseOllamaIT.java
@@ -1,15 +1,12 @@
package org.springframework.ai.autoconfigure.ollama;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
+import org.springframework.ai.ollama.api.OllamaApi;
+import org.springframework.ai.ollama.management.OllamaModelManager;
+import org.springframework.ai.ollama.management.PullModelStrategy;
import org.testcontainers.ollama.OllamaContainer;
-import java.io.IOException;
-
public class BaseOllamaIT {
- private static final Logger logger = LoggerFactory.getLogger(BaseOllamaIT.class);
-
// Toggle for running tests locally on native Ollama for a faster feedback loop.
private static final boolean useTestcontainers = true;
@@ -31,18 +28,16 @@ public class BaseOllamaIT {
* to the file ".testcontainers.properties" located in your home directory
*/
public static boolean isDisabled() {
- return true;
+ return false;
}
- public static String buildConnectionWithModel(String model) throws IOException, InterruptedException {
+ public static String buildConnectionWithModel(String model) {
var baseUrl = "http://localhost:11434";
if (useTestcontainers) {
baseUrl = ollamaContainer.getEndpoint();
-
- logger.info("Start pulling the '{}' model. The operation can take several minutes...", model);
- ollamaContainer.execInContainer("ollama", "pull", model);
- logger.info("Completed pulling the '{}' model", model);
}
+ var ollamaModelManager = new OllamaModelManager(new OllamaApi(baseUrl));
+ ollamaModelManager.pullModel(model, PullModelStrategy.WHEN_MISSING);
return baseUrl;
}
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationIT.java
index 8048bfc2487..10b07fb4559 100644
--- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationIT.java
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationIT.java
@@ -30,7 +30,9 @@
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatModel;
+import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaModel;
+import org.springframework.ai.ollama.management.OllamaModelManager;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.testcontainers.junit.jupiter.Testcontainers;
@@ -98,6 +100,23 @@ public void chatCompletionStreaming() {
});
}
+ @Test
+ public void chatCompletionWithPull() {
+ contextRunner.withPropertyValues("spring.ai.ollama.init.pull-model-strategy=when_missing")
+ .withPropertyValues("spring.ai.ollama.chat.options.model=tinyllama")
+ .run(context -> {
+ var model = "tinyllama";
+ OllamaApi ollamaApi = context.getBean(OllamaApi.class);
+ var modelManager = new OllamaModelManager(ollamaApi);
+ assertThat(modelManager.isModelAvailable(model)).isTrue();
+
+ OllamaChatModel chatModel = context.getBean(OllamaChatModel.class);
+ ChatResponse response = chatModel.call(new Prompt(userMessage));
+ assertThat(response.getResult().getOutput().getContent()).contains("Copenhagen");
+ modelManager.deleteModel(model);
+ });
+ }
+
@Test
void chatActivation() {
contextRunner.withPropertyValues("spring.ai.ollama.chat.enabled=false").run(context -> {
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationIT.java
index 9e1bd8b8dec..0ea701a6708 100644
--- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationIT.java
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationIT.java
@@ -21,9 +21,10 @@
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.DisabledIf;
+import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaModel;
+import org.springframework.ai.ollama.management.OllamaModelManager;
import org.testcontainers.junit.jupiter.Testcontainers;
-
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.ollama.OllamaEmbeddingModel;
import org.springframework.boot.autoconfigure.AutoConfigurations;
@@ -67,6 +68,23 @@ public void singleTextEmbedding() {
});
}
+ @Test
+ public void embeddingWithPull() {
+ contextRunner.withPropertyValues("spring.ai.ollama.init.pull-model-strategy=when_missing")
+ .withPropertyValues("spring.ai.ollama.embedding.options.model=all-minilm")
+ .run(context -> {
+ var model = "all-minilm";
+ OllamaApi ollamaApi = context.getBean(OllamaApi.class);
+ var modelManager = new OllamaModelManager(ollamaApi);
+ assertThat(modelManager.isModelAvailable(model)).isTrue();
+
+ OllamaEmbeddingModel embeddingModel = context.getBean(OllamaEmbeddingModel.class);
+ EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World"));
+ assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
+ modelManager.deleteModel(model);
+ });
+ }
+
@Test
void embeddingActivation() {
contextRunner.withPropertyValues("spring.ai.ollama.embedding.enabled=false").run(context -> {
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackInPromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackInPromptIT.java
index 2167e44b4d0..b7bc4e408bf 100644
--- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackInPromptIT.java
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackInPromptIT.java
@@ -17,7 +17,6 @@
import static org.assertj.core.api.Assertions.assertThat;
-import java.io.IOException;
import java.util.List;
import java.util.stream.Collectors;
@@ -36,7 +35,6 @@
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.function.FunctionCallbackWrapper;
import org.springframework.ai.ollama.OllamaChatModel;
-import org.springframework.ai.ollama.api.OllamaModel;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
@@ -50,12 +48,12 @@ public class FunctionCallbackInPromptIT extends BaseOllamaIT {
private static final Logger logger = LoggerFactory.getLogger(FunctionCallbackInPromptIT.class);
- private static final String MODEL_NAME = OllamaModel.LLAMA3_1.getName();
+ private static final String MODEL_NAME = "qwen2.5:3b";
static String baseUrl;
@BeforeAll
- public static void beforeAll() throws IOException, InterruptedException {
+ public static void beforeAll() {
baseUrl = buildConnectionWithModel(MODEL_NAME);
}
@@ -75,12 +73,13 @@ void functionCallTest() {
OllamaChatModel chatModel = context.getBean(OllamaChatModel.class);
UserMessage userMessage = new UserMessage(
- "What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.");
+ "What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.");
var promptOptions = OllamaOptions.builder()
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
.withName("CurrentWeatherService")
- .withDescription("Get the weather in location")
+ .withDescription(
+ "Find the weather conditions, forecasts, and temperatures for a location, like a city or state.")
.withResponseConverter((response) -> "" + response.temp() + response.unit())
.build()))
.build();
@@ -100,12 +99,14 @@ void streamingFunctionCallTest() {
OllamaChatModel chatModel = context.getBean(OllamaChatModel.class);
- UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?");
+ UserMessage userMessage = new UserMessage(
+ "What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.");
var promptOptions = OllamaOptions.builder()
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
.withName("CurrentWeatherService")
- .withDescription("Get the weather in location")
+ .withDescription(
+ "Find the weather conditions, forecasts, and temperatures for a location, like a city or state.")
.withResponseConverter((response) -> "" + response.temp() + response.unit())
.build()))
.build();
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackWrapperIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackWrapperIT.java
index 83d27c5437a..82fd7eb119d 100644
--- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackWrapperIT.java
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackWrapperIT.java
@@ -17,7 +17,6 @@
import static org.assertj.core.api.Assertions.assertThat;
-import java.io.IOException;
import java.util.List;
import java.util.stream.Collectors;
@@ -55,12 +54,12 @@ public class FunctionCallbackWrapperIT extends BaseOllamaIT {
private static final Logger logger = LoggerFactory.getLogger(FunctionCallbackWrapperIT.class);
- private static final String MODEL_NAME = OllamaModel.LLAMA3_1.getName();
+ private static final String MODEL_NAME = "qwen2.5:3b";
static String baseUrl;
@BeforeAll
- public static void beforeAll() throws IOException, InterruptedException {
+ public static void beforeAll() {
baseUrl = buildConnectionWithModel(MODEL_NAME);
}
@@ -81,7 +80,7 @@ void functionCallTest() {
OllamaChatModel chatModel = context.getBean(OllamaChatModel.class);
UserMessage userMessage = new UserMessage(
- "What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.");
+ "What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.");
ChatResponse response = chatModel
.call(new Prompt(List.of(userMessage), OllamaOptions.builder().withFunction("WeatherInfo").build()));
@@ -100,7 +99,7 @@ void streamFunctionCallTest() {
OllamaChatModel chatModel = context.getBean(OllamaChatModel.class);
UserMessage userMessage = new UserMessage(
- "What's the weather like in San Francisco, Tokyo, and Paris? You can call the following functions 'WeatherInfo'");
+ "What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.");
Flux response = chatModel
.stream(new Prompt(List.of(userMessage), OllamaOptions.builder().withFunction("WeatherInfo").build()));
@@ -126,7 +125,8 @@ void functionCallWithPortableFunctionCallingOptions() {
OllamaChatModel chatModel = context.getBean(OllamaChatModel.class);
// Test weatherFunction
- UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?");
+ UserMessage userMessage = new UserMessage(
+ "What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.");
PortableFunctionCallingOptions functionOptions = FunctionCallingOptions.builder()
.withFunction("WeatherInfo")
@@ -148,7 +148,8 @@ public FunctionCallback weatherFunctionInfo() {
return FunctionCallbackWrapper.builder(new MockWeatherService())
.withName("WeatherInfo")
- .withDescription("Get the weather in location")
+ .withDescription(
+ "Find the weather conditions, forecasts, and temperatures for a location, like a city or state.")
.withResponseConverter((response) -> "" + response.temp() + response.unit())
.build();
}