diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java
index 03603951ef3..4ddef5c5d66 100644
--- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java
+++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java
@@ -47,6 +47,7 @@
 import org.springframework.ai.ollama.api.OllamaApi.Message.Role;
 import org.springframework.ai.ollama.api.OllamaApi.Message.ToolCall;
 import org.springframework.ai.ollama.api.OllamaApi.Message.ToolCallFunction;
+import org.springframework.ai.ollama.api.OllamaModelPuller;
 import org.springframework.ai.ollama.api.OllamaOptions;
 import org.springframework.ai.ollama.metadata.OllamaChatUsage;
 import org.springframework.util.Assert;
@@ -92,6 +93,8 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode
	 */
	private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
 
+	private final OllamaModelPuller modelPuller;
+
	public OllamaChatModel(OllamaApi ollamaApi) {
		this(ollamaApi, OllamaOptions.create().withModel(OllamaOptions.DEFAULT_MODEL));
	}
@@ -120,10 +123,12 @@ public OllamaChatModel(OllamaApi chatApi, OllamaOptions defaultOptions,
		this.chatApi = chatApi;
		this.defaultOptions = defaultOptions;
		this.observationRegistry = observationRegistry;
+		this.modelPuller = new OllamaModelPuller(chatApi);
	}
 
	@Override
	public ChatResponse call(Prompt prompt) {
+		OllamaApi.ChatRequest request = ollamaChatRequest(prompt, false);
 
		ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
			.prompt(prompt)
@@ -319,6 +324,11 @@ else if (message instanceof ToolResponseMessage toolMessage) {
		}
 
		OllamaOptions mergedOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, OllamaOptions.class);
+		mergedOptions.setPullMissingModel(this.defaultOptions.isPullMissingModel());
+		if (runtimeOptions != null && runtimeOptions.isPullMissingModel() != null) {
+			mergedOptions.setPullMissingModel(runtimeOptions.isPullMissingModel());
+		}
+
		// Override the model.
		if (!StringUtils.hasText(mergedOptions.getModel())) {
			throw new IllegalArgumentException("Model is not set!");
@@ -343,6 +353,10 @@ else if (message instanceof ToolResponseMessage toolMessage) {
			requestBuilder.withTools(this.getFunctionTools(functionsForThisRequest));
		}
 
+		if (Boolean.TRUE.equals(mergedOptions.isPullMissingModel())) {
+			this.modelPuller.pullModel(mergedOptions.getModel(), true);
+		}
+
		return requestBuilder.build();
	}
 
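A minimal end-to-end sketch of what the chat-model change enables, mirroring the integration test further down (the `tinyllama` model and the default local endpoint are illustrative, not prescribed):

[source,java]
----
OllamaApi ollamaApi = new OllamaApi();
OllamaChatModel chatModel = new OllamaChatModel(ollamaApi);

// With pullMissingModel enabled, ollamaChatRequest() pulls "tinyllama" before the call.
String joke = ChatClient.create(chatModel)
	.prompt("Tell me a joke")
	.options(OllamaOptions.builder().withModel("tinyllama").withPullMissingModel(true).build())
	.call()
	.content();
----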
diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java
index 0984034d23e..e918cc9e662 100644
--- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java
+++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java
@@ -32,6 +32,7 @@
 import org.springframework.ai.model.ModelOptionsUtils;
 import org.springframework.ai.ollama.api.OllamaApi;
 import org.springframework.ai.ollama.api.OllamaApi.EmbeddingsResponse;
+import org.springframework.ai.ollama.api.OllamaModelPuller;
 import org.springframework.ai.ollama.api.OllamaOptions;
 import org.springframework.ai.ollama.metadata.OllamaEmbeddingUsage;
 import org.springframework.util.Assert;
@@ -75,6 +76,8 @@ public class OllamaEmbeddingModel extends AbstractEmbeddingModel {
	 */
	private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
 
+	private final OllamaModelPuller modelPuller;
+
	public OllamaEmbeddingModel(OllamaApi ollamaApi) {
		this(ollamaApi, OllamaOptions.create().withModel(OllamaOptions.DEFAULT_MODEL));
	}
@@ -92,6 +95,7 @@ public OllamaEmbeddingModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
		this.ollamaApi = ollamaApi;
		this.defaultOptions = defaultOptions;
		this.observationRegistry = observationRegistry;
+		this.modelPuller = new OllamaModelPuller(ollamaApi);
	}
 
	@Override
@@ -149,12 +153,21 @@ OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest(List<String> inputContent, Em
 
		OllamaOptions mergedOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, OllamaOptions.class);
+		mergedOptions.setPullMissingModel(this.defaultOptions.isPullMissingModel());
+		if (runtimeOptions != null && runtimeOptions.isPullMissingModel() != null) {
+			mergedOptions.setPullMissingModel(runtimeOptions.isPullMissingModel());
+		}
+
		// Override the model.
		if (!StringUtils.hasText(mergedOptions.getModel())) {
			throw new IllegalArgumentException("Model is not set!");
		}
 
		String model = mergedOptions.getModel();
+		if (Boolean.TRUE.equals(mergedOptions.isPullMissingModel())) {
+			this.modelPuller.pullModel(model, true);
+		}
+
		return new OllamaApi.EmbeddingsRequest(model, inputContent, DurationParser.parse(mergedOptions.getKeepAlive()),
				OllamaOptions.filterNonSupportedFields(mergedOptions.toMap()), mergedOptions.getTruncate());
	}
 
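The embedding path behaves the same way; a minimal sketch, again assuming a local Ollama instance on the default endpoint:

[source,java]
----
OllamaEmbeddingModel embeddingModel = new OllamaEmbeddingModel(new OllamaApi());

// "all-minilm:latest" is downloaded on first use because pullMissingModel is enabled.
EmbeddingResponse embeddingResponse = embeddingModel
	.call(new EmbeddingRequest(List.of("Hello World", "Something else"),
			OllamaOptions.builder()
				.withModel("all-minilm:latest")
				.withPullMissingModel(true)
				.withTruncate(false)
				.build()));
----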
diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModelPuller.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModelPuller.java
new file mode 100644
index 00000000000..28138ce39b5
--- /dev/null
+++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModelPuller.java
@@ -0,0 +1,93 @@
+/*
+ * Copyright 2024 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.ollama.api;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.springframework.ai.ollama.api.OllamaApi.DeleteModelRequest;
+import org.springframework.ai.ollama.api.OllamaApi.ListModelResponse;
+import org.springframework.ai.ollama.api.OllamaApi.PullModelRequest;
+import org.springframework.http.HttpStatus;
+import org.springframework.util.CollectionUtils;
+
+/**
+ * Helper class that checks whether a model is available locally and pulls it if not.
+ *
+ * @author Christian Tzolov
+ * @since 1.0.0
+ */
+public class OllamaModelPuller {
+
+	private final Logger logger = LoggerFactory.getLogger(OllamaModelPuller.class);
+
+	private final OllamaApi ollamaApi;
+
+	public OllamaModelPuller(OllamaApi ollamaApi) {
+		this.ollamaApi = ollamaApi;
+	}
+
+	public boolean isModelAvailable(String modelName) {
+		ListModelResponse modelsResponse = this.ollamaApi.listModels();
+		if (!CollectionUtils.isEmpty(modelsResponse.models())) {
+			return modelsResponse.models().stream().anyMatch(m -> m.name().equals(modelName));
+		}
+		return false;
+	}
+
+	public boolean deleteModel(String modelName) {
+		logger.info("Delete model: {}", modelName);
+		if (!isModelAvailable(modelName)) {
+			logger.info("Model: {} not found!", modelName);
+			return false;
+		}
+		return this.ollamaApi.deleteModel(new DeleteModelRequest(modelName)).getStatusCode().equals(HttpStatus.OK);
+	}
+
+	public String pullModel(String modelName, boolean reTry) {
+		String status = "";
+		do {
+			logger.info("Start pulling model: {}", modelName);
+			var progress = this.ollamaApi.pullModel(new PullModelRequest(modelName));
+			status = progress.status();
+			logger.info("Pulling model: {} - Status: {}", modelName, status);
+			try {
+				Thread.sleep(5000);
+			}
+			catch (InterruptedException e) {
+				// Restore the interrupt flag and stop retrying instead of swallowing the exception.
+				Thread.currentThread().interrupt();
+				break;
+			}
+		}
+		while (reTry && !status.equals("success"));
+		return status;
+	}
+
+	public static void main(String[] args) {
+
+		var utils = new OllamaModelPuller(new OllamaApi());
+
+		System.out.println(utils.isModelAvailable("orca-mini:latest"));
+
+		String model = "hf.co/bartowski/Llama-3.2-3B-Instruct-GGUF:Q8_0";
+
+		if (!utils.isModelAvailable(model)) {
+			utils.pullModel(model, true);
+		}
+	}
+
+}
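Standalone usage of the helper follows the pattern of its own `main` method; a short sketch assuming a reachable local Ollama daemon:

[source,java]
----
var puller = new OllamaModelPuller(new OllamaApi());

String model = "all-minilm:latest";
if (!puller.isModelAvailable(model)) {
	// With reTry=true this blocks, re-issuing the pull request every 5 seconds
	// until the Ollama API reports the "success" status.
	puller.pullModel(model, true);
}
----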
diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java
index 8adee039fd7..d6f2f85cdac 100644
--- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java
+++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java
@@ -304,6 +304,9 @@ public class OllamaOptions implements FunctionCallingOptions, ChatOptions, Embed
	@JsonIgnore
	private Map<String, Object> toolContext;
 
+	@JsonIgnore
+	private Boolean pullMissingModel;
+
	public static OllamaOptions builder() {
		return new OllamaOptions();
	}
@@ -516,6 +519,11 @@ public OllamaOptions withToolContext(Map<String, Object> toolContext) {
		return this;
	}
 
+	public OllamaOptions withPullMissingModel(Boolean pullMissingModel) {
+		this.pullMissingModel = pullMissingModel;
+		return this;
+	}
+
	// -------------------
	// Getters and Setters
	// -------------------
@@ -856,6 +864,14 @@ public void setToolContext(Map<String, Object> toolContext) {
		this.toolContext = toolContext;
	}
 
+	public Boolean isPullMissingModel() {
+		return this.pullMissingModel;
+	}
+
+	public void setPullMissingModel(Boolean pullMissingModel) {
+		this.pullMissingModel = pullMissingModel;
+	}
+
	/**
	 * Convert the {@link OllamaOptions} object to a {@link Map} of key/value pairs.
	 * @return The {@link Map} of key/value pairs.
@@ -926,7 +942,8 @@ public static OllamaOptions fromOptions(OllamaOptions fromOptions) {
			.withFunctions(fromOptions.getFunctions())
			.withProxyToolCalls(fromOptions.getProxyToolCalls())
			.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
-			.withToolContext(fromOptions.getToolContext());
+			.withToolContext(fromOptions.getToolContext())
+			.withPullMissingModel(fromOptions.isPullMissingModel());
	}
	// @formatter:on
 
@@ -956,7 +973,8 @@ public boolean equals(Object o) {
				&& Objects.equals(penalizeNewline, that.penalizeNewline) && Objects.equals(stop, that.stop)
				&& Objects.equals(functionCallbacks, that.functionCallbacks)
				&& Objects.equals(proxyToolCalls, that.proxyToolCalls) && Objects.equals(functions, that.functions)
-				&& Objects.equals(toolContext, that.toolContext);
+				&& Objects.equals(toolContext, that.toolContext)
+				&& Objects.equals(pullMissingModel, that.pullMissingModel);
	}
 
	@Override
@@ -967,7 +985,7 @@ public int hashCode() {
				this.topP, tfsZ, this.typicalP, this.repeatLastN, this.temperature, this.repeatPenalty,
				this.presencePenalty, this.frequencyPenalty, this.mirostat, this.mirostatTau, this.mirostatEta,
				this.penalizeNewline, this.stop, this.functionCallbacks, this.functions, this.proxyToolCalls,
-				this.toolContext);
+				this.toolContext, this.pullMissingModel);
	}
 
}
diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/BaseOllamaIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/BaseOllamaIT.java
index 9618f07e67b..ac78a718dd3 100644
--- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/BaseOllamaIT.java
+++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/BaseOllamaIT.java
@@ -10,7 +10,7 @@ public class BaseOllamaIT {
	private static final Logger logger = LoggerFactory.getLogger(BaseOllamaIT.class);
 
	// Toggle for running tests locally on native Ollama for a faster feedback loop.
-	private static final boolean useTestcontainers = true;
+	private static final boolean useTestcontainers = false;
 
	public static final OllamaContainer ollamaContainer;
 
@@ -30,7 +30,7 @@ public class BaseOllamaIT {
	 * to the file ".testcontainers.properties" located in your home directory
	 */
	public static boolean isDisabled() {
-		return true;
+		return false;
	}
 
	public static OllamaApi buildOllamaApiWithModel(String model) {
diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java
index 0a1f0590394..a622dbcc17b 100644
--- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java
+++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java
@@ -17,6 +17,7 @@
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.condition.DisabledIf;
 
+import org.springframework.ai.chat.client.ChatClient;
 import org.springframework.ai.chat.messages.AssistantMessage;
 import org.springframework.ai.chat.messages.Message;
 import org.springframework.ai.chat.messages.UserMessage;
@@ -32,6 +33,7 @@
 import org.springframework.ai.converter.MapOutputConverter;
 import org.springframework.ai.ollama.api.OllamaApi;
 import org.springframework.ai.ollama.api.OllamaModel;
+import org.springframework.ai.ollama.api.OllamaModelPuller;
 import org.springframework.ai.ollama.api.OllamaOptions;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.boot.SpringBootConfiguration;
@@ -56,6 +58,26 @@ class OllamaChatModelIT extends BaseOllamaIT {
	@Autowired
	private OllamaChatModel chatModel;
 
+	@Autowired
+	private OllamaApi ollamaApi;
+
+	@Test
+	void autoPullModelTest() {
+		var puller = new OllamaModelPuller(ollamaApi);
+		puller.deleteModel("tinyllama:latest");
+
+		assertThat(puller.isModelAvailable("tinyllama:latest")).isFalse();
+
+		String joke = ChatClient.create(chatModel)
+			.prompt("Tell me a joke")
+			.options(OllamaOptions.builder().withModel("tinyllama").withPullMissingModel(true).build())
+			.call()
+			.content();
+
+		assertThat(joke).isNotEmpty();
+		assertThat(puller.isModelAvailable("tinyllama:latest")).isTrue();
+	}
+
	@Test
	void roleTest() {
		Message systemMessage = new SystemPromptTemplate("""
diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelIT.java
index 49c56d43d99..3813017d177 100644
--- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelIT.java
+++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelIT.java
@@ -20,7 +20,8 @@
 import org.springframework.ai.embedding.EmbeddingRequest;
 import org.springframework.ai.embedding.EmbeddingResponse;
 import org.springframework.ai.ollama.api.OllamaApi;
 import org.springframework.ai.ollama.api.OllamaModel;
+import org.springframework.ai.ollama.api.OllamaModelPuller;
 import org.springframework.ai.ollama.api.OllamaOptions;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.boot.SpringBootConfiguration;
@@ -42,6 +43,9 @@ class OllamaEmbeddingModelIT extends BaseOllamaIT {
	@Autowired
	private OllamaEmbeddingModel embeddingModel;
 
+	@Autowired
+	private OllamaApi ollamaApi;
+
	@Test
	void embeddings() {
		assertThat(embeddingModel).isNotNull();
@@ -59,6 +63,37 @@ void embeddings() {
		assertThat(embeddingModel.dimensions()).isEqualTo(768);
	}
 
+	@Test
+	void autoPullModel() {
+		assertThat(embeddingModel).isNotNull();
+
+		var puller = new OllamaModelPuller(ollamaApi);
+		puller.deleteModel("all-minilm:latest");
+
+		assertThat(puller.isModelAvailable("all-minilm:latest")).isFalse();
+
+		EmbeddingResponse embeddingResponse = embeddingModel
+			.call(new EmbeddingRequest(List.of("Hello World", "Something else"),
+					OllamaOptions.builder()
+						.withModel("all-minilm:latest")
+						.withPullMissingModel(true)
+						.withTruncate(false)
+						.build()));
+
+		assertThat(puller.isModelAvailable("all-minilm:latest")).isTrue();
+
+		assertThat(embeddingResponse.getResults()).hasSize(2);
+		assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0);
+		assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
+		assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1);
+		assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty();
+		assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("all-minilm:latest");
+		assertThat(embeddingResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(4);
+		assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(4);
+
+		assertThat(embeddingModel.dimensions()).isEqualTo(768);
+	}
+
	@SpringBootConfiguration
	public static class TestConfiguration {
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc
index a5f573da7e0..00a27a222bf 100644
--- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc
+++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc
@@ -11,8 +11,6 @@ Check the xref:_openai_api_compatibility[OpenAI API compatibility] section to le
 You first need to run Ollama on your local machine.
 Refer to the official Ollama project link:https://github.com/ollama/ollama[README] to get started running models on your local machine.
 
-NOTE: Running `ollama pull mistral` will download a 4.1GB model artifact.
-
 === Add Repositories and BOM
 
 Spring AI artifacts are published in Spring Milestone and Snapshot repositories.
@@ -66,6 +64,7 @@ Here are the advanced request parameter for the Ollama chat model:
 
 | spring.ai.ollama.chat.enabled | Enable Ollama chat model. | true
 | spring.ai.ollama.chat.options.model | The name of the https://github.com/ollama/ollama?tab=readme-ov-file#model-library[supported model] to use. | mistral
+| spring.ai.ollama.chat.options.pull-missing-model | Automatically pull missing models from the Ollama repository. | false
 | spring.ai.ollama.chat.options.format | The format to return a response in. Currently, the only accepted value is `json` | -
 | spring.ai.ollama.chat.options.keep_alive | Controls how long the model will stay loaded into memory following the request | 5m
 |====
@@ -133,6 +132,36 @@ ChatResponse response = chatModel.call(
 
 TIP: In addition to the model specific link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java[OllamaOptions] you can use a portable https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java[ChatOptions] instance, created with https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptionsBuilder.java[ChatOptionsBuilder#builder()].
 
+=== Auto-pulling Models
+
+The `pullMissingModel` option allows you to automatically download and use models that are not currently available on your local Ollama instance.
+This is particularly useful when switching between models or when deploying your application to a new environment.
+
+To enable auto-pulling of missing models, set the `pullMissingModel` option to `true` in your `OllamaOptions`:
+
+[source,java]
+----
+OllamaOptions options = OllamaOptions.builder()
+    .withModel("tinyllama")
+    .withPullMissingModel(true)
+    .build();
+----
+
+You can also configure this option using the following property: `spring.ai.ollama.chat.options.pull-missing-model=true`
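+
+For example, in `application.properties`:
+
+[source,properties]
+----
+spring.ai.ollama.chat.options.model=tinyllama
+spring.ai.ollama.chat.options.pull-missing-model=true
+----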
+
+When `pullMissingModel` is set to `true`, the specified model is downloaded automatically if it is not already available locally.
+This may take some time, depending on the size of the model and your internet connection speed.
+
+CAUTION: Enabling this option may lead to unexpected delays in your application if large model files have to be downloaded. It is recommended to pre-download commonly used models in production environments.
+
 == Function Calling
 
 You can register custom Java functions with the `OllamaChatModel` and have the Ollama model intelligently choose to output a JSON object containing arguments to call one or many of the registered functions.
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/ollama-embeddings.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/ollama-embeddings.adoc
index eb43d0228f5..6729a63a1ea 100644
--- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/ollama-embeddings.adoc
+++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/ollama-embeddings.adoc
@@ -53,19 +53,20 @@ It includes the Ollama request (advanced) parameters such as the `model`, `keep-
 Here are the advanced request parameter for the Ollama embedding model:
 
-[cols="3,6,1"]
+[cols="4,5,1"]
 |====
 | Property | Description | Default
 
 | spring.ai.ollama.embedding.enabled | Enables the Ollama embedding model auto-configuration. | true
 | spring.ai.ollama.embedding.options.model | The name of the https://github.com/ollama/ollama?tab=readme-ov-file#model-library[supported model] to use. You can use dedicated https://ollama.com/search?c=embedding[Embedding Model] types | mistral
+| spring.ai.ollama.embedding.options.pull-missing-model | Automatically pull missing models from the Ollama repository. | false
 | spring.ai.ollama.embedding.options.keep_alive | Controls how long the model will stay loaded into memory following the request | 5m
 | spring.ai.ollama.embedding.options.truncate | Truncates the end of each input to fit within context length. Returns error if false and context length is exceeded. | true
 |====
 
 The remaining `options` properties are based on the link:https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values[Ollama Valid Parameters and Values] and link:https://github.com/ollama/ollama/blob/main/api/types.go[Ollama Types]. The default values are based on: link:https://github.com/ollama/ollama/blob/b538dc3858014f94b099730a592751a5454cab0a/api/types.go#L364[Ollama type defaults].
 
-[cols="3,6,1"]
+[cols="4,5,1"]
 |====
 | Property | Description | Default
 | spring.ai.ollama.embedding.options.numa | Whether to use NUMA. | false
@@ -123,6 +124,39 @@ EmbeddingResponse embeddingResponse = embeddingModel.call(
		.build());
 ----
 
+=== Auto-pulling Models
+
+The `pullMissingModel` option allows you to automatically download and use models that are not currently available on your local Ollama instance.
+This is particularly useful when switching between models or when deploying your application to a new environment.
+
+To enable auto-pulling of missing models, set the `pullMissingModel` option to `true` in your `OllamaOptions`:
+
+[source,java]
+----
+EmbeddingResponse embeddingResponse = embeddingModel
+	.call(new EmbeddingRequest(List.of("Hello World", "Something else"),
+			OllamaOptions.builder()
+				.withModel("all-minilm:latest")
+				.withPullMissingModel(true)
+				.withTruncate(false)
+				.build()));
+----
+
+You can also configure this option using the following property: `spring.ai.ollama.embedding.options.pull-missing-model=true`
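+
+For example, in `application.properties`:
+
+[source,properties]
+----
+spring.ai.ollama.embedding.options.model=all-minilm:latest
+spring.ai.ollama.embedding.options.pull-missing-model=true
+----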
+
+When `pullMissingModel` is set to `true`, the specified model is downloaded automatically if it is not already available locally.
+This may take some time, depending on the size of the model and your internet connection speed.
+
+CAUTION: Enabling this option may lead to unexpected delays in your application if large model files have to be downloaded. It is recommended to pre-download commonly used models in production environments.
+
 == Sample Controller
 
 This will create a `EmbeddingModel` implementation that you can inject into your class.