Skip to content

Commit 2a9f9c8

Browse files
committed
Ollama: add model auto-pull feature
- Introduce internal OllamaModelPuller helper for managing model availability - Add pullMissingModel option to OllamaOptions - Implement auto-pull functionality in OllamaChatModel and OllamaEmbeddingModel - Update tests to cover new auto-pull feature - Add reference documentation Resolves #526
1 parent f461bd6 commit 2a9f9c8

File tree

9 files changed

+247
-9
lines changed

9 files changed

+247
-9
lines changed

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
import org.springframework.ai.ollama.api.OllamaApi.Message.Role;
4848
import org.springframework.ai.ollama.api.OllamaApi.Message.ToolCall;
4949
import org.springframework.ai.ollama.api.OllamaApi.Message.ToolCallFunction;
50+
import org.springframework.ai.ollama.api.OllamaModelPuller;
5051
import org.springframework.ai.ollama.api.OllamaOptions;
5152
import org.springframework.ai.ollama.metadata.OllamaChatUsage;
5253
import org.springframework.util.Assert;
@@ -92,6 +93,8 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode
9293
*/
9394
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
9495

96+
private final OllamaModelPuller modelPuller;
97+
9598
public OllamaChatModel(OllamaApi ollamaApi) {
9699
this(ollamaApi, OllamaOptions.create().withModel(OllamaOptions.DEFAULT_MODEL));
97100
}
@@ -120,10 +123,12 @@ public OllamaChatModel(OllamaApi chatApi, OllamaOptions defaultOptions,
120123
this.chatApi = chatApi;
121124
this.defaultOptions = defaultOptions;
122125
this.observationRegistry = observationRegistry;
126+
this.modelPuller = new OllamaModelPuller(chatApi);
123127
}
124128

125129
@Override
126130
public ChatResponse call(Prompt prompt) {
131+
127132
OllamaApi.ChatRequest request = ollamaChatRequest(prompt, false);
128133

129134
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
@@ -319,6 +324,11 @@ else if (message instanceof ToolResponseMessage toolMessage) {
319324
}
320325
OllamaOptions mergedOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, OllamaOptions.class);
321326

327+
mergedOptions.setPullMissingModel(this.defaultOptions.isPullMissingModel());
328+
if (runtimeOptions != null && runtimeOptions.isPullMissingModel() != null) {
329+
mergedOptions.setPullMissingModel(runtimeOptions.isPullMissingModel());
330+
}
331+
322332
// Override the model.
323333
if (!StringUtils.hasText(mergedOptions.getModel())) {
324334
throw new IllegalArgumentException("Model is not set!");
@@ -343,6 +353,10 @@ else if (message instanceof ToolResponseMessage toolMessage) {
343353
requestBuilder.withTools(this.getFunctionTools(functionsForThisRequest));
344354
}
345355

356+
if (mergedOptions.isPullMissingModel()) {
357+
this.modelPuller.pullModel(mergedOptions.getModel(), true);
358+
}
359+
346360
return requestBuilder.build();
347361
}
348362

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.springframework.ai.model.ModelOptionsUtils;
3333
import org.springframework.ai.ollama.api.OllamaApi;
3434
import org.springframework.ai.ollama.api.OllamaApi.EmbeddingsResponse;
35+
import org.springframework.ai.ollama.api.OllamaModelPuller;
3536
import org.springframework.ai.ollama.api.OllamaOptions;
3637
import org.springframework.ai.ollama.metadata.OllamaEmbeddingUsage;
3738
import org.springframework.util.Assert;
@@ -75,6 +76,8 @@ public class OllamaEmbeddingModel extends AbstractEmbeddingModel {
7576
*/
7677
private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
7778

79+
private final OllamaModelPuller modelPuller;
80+
7881
public OllamaEmbeddingModel(OllamaApi ollamaApi) {
7982
this(ollamaApi, OllamaOptions.create().withModel(OllamaOptions.DEFAULT_MODEL));
8083
}
@@ -92,6 +95,7 @@ public OllamaEmbeddingModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
9295
this.ollamaApi = ollamaApi;
9396
this.defaultOptions = defaultOptions;
9497
this.observationRegistry = observationRegistry;
98+
this.modelPuller = new OllamaModelPuller(ollamaApi);
9599
}
96100

97101
@Override
@@ -149,12 +153,21 @@ OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest(List<String> inputContent, Em
149153

150154
OllamaOptions mergedOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, OllamaOptions.class);
151155

156+
mergedOptions.setPullMissingModel(this.defaultOptions.isPullMissingModel());
157+
if (runtimeOptions != null && runtimeOptions.isPullMissingModel() != null) {
158+
mergedOptions.setPullMissingModel(runtimeOptions.isPullMissingModel());
159+
}
160+
152161
// Override the model.
153162
if (!StringUtils.hasText(mergedOptions.getModel())) {
154163
throw new IllegalArgumentException("Model is not set!");
155164
}
156165
String model = mergedOptions.getModel();
157166

167+
if (mergedOptions.isPullMissingModel()) {
168+
this.modelPuller.pullModel(model, true);
169+
}
170+
158171
return new OllamaApi.EmbeddingsRequest(model, inputContent, DurationParser.parse(mergedOptions.getKeepAlive()),
159172
OllamaOptions.filterNonSupportedFields(mergedOptions.toMap()), mergedOptions.getTruncate());
160173
}
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/*
2+
* Copyright 2024 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.ollama.api;
17+
18+
import org.slf4j.Logger;
19+
import org.slf4j.LoggerFactory;
20+
import org.springframework.ai.ollama.api.OllamaApi.DeleteModelRequest;
21+
import org.springframework.ai.ollama.api.OllamaApi.ListModelResponse;
22+
import org.springframework.ai.ollama.api.OllamaApi.PullModelRequest;
23+
import org.springframework.http.HttpStatus;
24+
import org.springframework.util.CollectionUtils;
25+
26+
/**
27+
* Helper class that allow to check if a model is available locally and pull it if not.
28+
*
29+
* @author Christian Tzolov
30+
* @since 1.0.0
31+
*/
32+
public class OllamaModelPuller {
33+
34+
private final Logger logger = LoggerFactory.getLogger(OllamaModelPuller.class);
35+
36+
private OllamaApi ollamaApi;
37+
38+
public OllamaModelPuller(OllamaApi ollamaApi) {
39+
this.ollamaApi = ollamaApi;
40+
}
41+
42+
public boolean isModelAvailable(String modelName) {
43+
ListModelResponse modelsResponse = ollamaApi.listModels();
44+
if (!CollectionUtils.isEmpty(modelsResponse.models())) {
45+
return modelsResponse.models().stream().anyMatch(m -> m.name().equals(modelName));
46+
}
47+
return false;
48+
}
49+
50+
public boolean deleteModel(String modelName) {
51+
logger.info("Delete model: {}", modelName);
52+
if (!isModelAvailable(modelName)) {
53+
logger.info("Model: {} not found!", modelName);
54+
return false;
55+
}
56+
return this.ollamaApi.deleteModel(new DeleteModelRequest(modelName)).getStatusCode().equals(HttpStatus.OK);
57+
}
58+
59+
public String pullModel(String modelName, boolean reTry) {
60+
String status = "";
61+
do {
62+
logger.info("Start Pulling model: {}", modelName);
63+
var progress = this.ollamaApi.pullModel(new PullModelRequest(modelName));
64+
status = progress.status();
65+
logger.info("Pulling model: {} - Status: {}", modelName, status);
66+
try {
67+
Thread.sleep(5000);
68+
}
69+
catch (InterruptedException e) {
70+
e.printStackTrace();
71+
}
72+
}
73+
while (reTry && !status.equals("success"));
74+
return status;
75+
}
76+
77+
public static void main(String[] args) {
78+
79+
var utils = new OllamaModelPuller(new OllamaApi());
80+
81+
System.out.println(utils.isModelAvailable("orca-mini:latest"));
82+
83+
String model = "hf.co/bartowski/Llama-3.2-3B-Instruct-GGUF:Q8_0";
84+
85+
if (!utils.isModelAvailable(model)) {
86+
utils.pullModel(model, true);
87+
}
88+
}
89+
90+
}

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,9 @@ public class OllamaOptions implements FunctionCallingOptions, ChatOptions, Embed
304304
@JsonIgnore
305305
private Map<String, Object> toolContext;
306306

307+
@JsonIgnore
308+
private boolean pullMissingModel;
309+
307310
public static OllamaOptions builder() {
308311
return new OllamaOptions();
309312
}
@@ -516,6 +519,11 @@ public OllamaOptions withToolContext(Map<String, Object> toolContext) {
516519
return this;
517520
}
518521

522+
public OllamaOptions withPullMissingModel(boolean pullMissingModel) {
523+
this.pullMissingModel = pullMissingModel;
524+
return this;
525+
}
526+
519527
// -------------------
520528
// Getters and Setters
521529
// -------------------
@@ -856,6 +864,14 @@ public void setToolContext(Map<String, Object> toolContext) {
856864
this.toolContext = toolContext;
857865
}
858866

867+
public Boolean isPullMissingModel() {
868+
return this.pullMissingModel;
869+
}
870+
871+
public void setPullMissingModel(boolean pullMissingModel) {
872+
this.pullMissingModel = pullMissingModel;
873+
}
874+
859875
/**
860876
* Convert the {@link OllamaOptions} object to a {@link Map} of key/value pairs.
861877
* @return The {@link Map} of key/value pairs.
@@ -926,7 +942,8 @@ public static OllamaOptions fromOptions(OllamaOptions fromOptions) {
926942
.withFunctions(fromOptions.getFunctions())
927943
.withProxyToolCalls(fromOptions.getProxyToolCalls())
928944
.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
929-
.withToolContext(fromOptions.getToolContext());
945+
.withToolContext(fromOptions.getToolContext())
946+
.withPullMissingModel(fromOptions.isPullMissingModel());
930947
}
931948
// @formatter:on
932949

@@ -956,7 +973,8 @@ public boolean equals(Object o) {
956973
&& Objects.equals(penalizeNewline, that.penalizeNewline) && Objects.equals(stop, that.stop)
957974
&& Objects.equals(functionCallbacks, that.functionCallbacks)
958975
&& Objects.equals(proxyToolCalls, that.proxyToolCalls) && Objects.equals(functions, that.functions)
959-
&& Objects.equals(toolContext, that.toolContext);
976+
&& Objects.equals(toolContext, that.toolContext)
977+
&& Objects.equals(pullMissingModel, that.pullMissingModel);
960978
}
961979

962980
@Override
@@ -967,7 +985,7 @@ public int hashCode() {
967985
this.topP, tfsZ, this.typicalP, this.repeatLastN, this.temperature, this.repeatPenalty,
968986
this.presencePenalty, this.frequencyPenalty, this.mirostat, this.mirostatTau, this.mirostatEta,
969987
this.penalizeNewline, this.stop, this.functionCallbacks, this.functions, this.proxyToolCalls,
970-
this.toolContext);
988+
this.toolContext, this.pullMissingModel);
971989
}
972990

973991
}

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/BaseOllamaIT.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ public class BaseOllamaIT {
1010
private static final Logger logger = LoggerFactory.getLogger(BaseOllamaIT.class);
1111

1212
// Toggle for running tests locally on native Ollama for a faster feedback loop.
13-
private static final boolean useTestcontainers = true;
13+
private static final boolean useTestcontainers = false;
1414

1515
public static final OllamaContainer ollamaContainer;
1616

@@ -30,7 +30,7 @@ public class BaseOllamaIT {
3030
* to the file ".testcontainers.properties" located in your home directory
3131
*/
3232
public static boolean isDisabled() {
33-
return true;
33+
return false;
3434
}
3535

3636
public static OllamaApi buildOllamaApiWithModel(String model) {

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import org.junit.jupiter.api.Test;
1919
import org.junit.jupiter.api.condition.DisabledIf;
20+
import org.springframework.ai.chat.client.ChatClient;
2021
import org.springframework.ai.chat.messages.AssistantMessage;
2122
import org.springframework.ai.chat.messages.Message;
2223
import org.springframework.ai.chat.messages.UserMessage;
@@ -32,6 +33,7 @@
3233
import org.springframework.ai.converter.MapOutputConverter;
3334
import org.springframework.ai.ollama.api.OllamaApi;
3435
import org.springframework.ai.ollama.api.OllamaModel;
36+
import org.springframework.ai.ollama.api.OllamaModelPuller;
3537
import org.springframework.ai.ollama.api.OllamaOptions;
3638
import org.springframework.beans.factory.annotation.Autowired;
3739
import org.springframework.boot.SpringBootConfiguration;
@@ -56,6 +58,26 @@ class OllamaChatModelIT extends BaseOllamaIT {
5658
@Autowired
5759
private OllamaChatModel chatModel;
5860

61+
@Autowired
62+
private OllamaApi ollamaApi;
63+
64+
@Test
65+
void autoPullModelTest() {
66+
var puller = new OllamaModelPuller(ollamaApi);
67+
puller.deleteModel("tinyllama");
68+
69+
assertThat(puller.isModelAvailable("tinyllama")).isFalse();
70+
71+
String joke = ChatClient.create(chatModel)
72+
.prompt("Tell me a joke")
73+
.options(OllamaOptions.builder().withModel("tinyllama").withPullMissingModel(true).build())
74+
.call()
75+
.content();
76+
77+
assertThat(joke).isNotEmpty();
78+
assertThat(puller.isModelAvailable("tinyllamaf")).isFalse();
79+
}
80+
5981
@Test
6082
void roleTest() {
6183
Message systemMessage = new SystemPromptTemplate("""

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelIT.java

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
import org.springframework.ai.embedding.EmbeddingRequest;
2121
import org.springframework.ai.embedding.EmbeddingResponse;
2222
import org.springframework.ai.ollama.api.OllamaApi;
23+
import org.springframework.ai.ollama.api.OllamaApi.DeleteModelRequest;
2324
import org.springframework.ai.ollama.api.OllamaModel;
25+
import org.springframework.ai.ollama.api.OllamaModelPuller;
2426
import org.springframework.ai.ollama.api.OllamaOptions;
2527
import org.springframework.beans.factory.annotation.Autowired;
2628
import org.springframework.boot.SpringBootConfiguration;
@@ -42,6 +44,9 @@ class OllamaEmbeddingModelIT extends BaseOllamaIT {
4244
@Autowired
4345
private OllamaEmbeddingModel embeddingModel;
4446

47+
@Autowired
48+
private OllamaApi ollamaApi;
49+
4550
@Test
4651
void embeddings() {
4752
assertThat(embeddingModel).isNotNull();
@@ -59,6 +64,37 @@ void embeddings() {
5964
assertThat(embeddingModel.dimensions()).isEqualTo(768);
6065
}
6166

67+
@Test
68+
void autoPullModel() {
69+
assertThat(embeddingModel).isNotNull();
70+
71+
var puller = new OllamaModelPuller(ollamaApi);
72+
puller.deleteModel("all-minilm:latest");
73+
74+
assertThat(puller.isModelAvailable("all-minilm")).isFalse();
75+
76+
EmbeddingResponse embeddingResponse = embeddingModel
77+
.call(new EmbeddingRequest(List.of("Hello World", "Something else"),
78+
OllamaOptions.builder()
79+
.withModel("all-minilm:latest")
80+
.withPullMissingModel(true)
81+
.withTruncate(false)
82+
.build()));
83+
84+
assertThat(puller.isModelAvailable("all-minilm:latest")).isTrue();
85+
86+
assertThat(embeddingResponse.getResults()).hasSize(2);
87+
assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0);
88+
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
89+
assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1);
90+
assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty();
91+
assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("all-minilm:latest");
92+
assertThat(embeddingResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(4);
93+
assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(4);
94+
95+
assertThat(embeddingModel.dimensions()).isEqualTo(768);
96+
}
97+
6298
@SpringBootConfiguration
6399
public static class TestConfiguration {
64100

0 commit comments

Comments
 (0)