Skip to content

Commit 0ff6035

Browse files
committed
Ollama: Pull models automatically at startup
* Introduce support for Ollama model auto-pull at startup time
* Enhance support for Ollama model auto-pull at run time
* Update documentation about integrating with Ollama and managing models
* Adopt Builder pattern in Ollama Model classes for better code readability
* Unify Ollama model auto-pull functionality in production and test code
* Improve integration tests for Ollama with Testcontainers
1 parent e970c26 commit 0ff6035

31 files changed

+878
-327
lines changed

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java

Lines changed: 86 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,10 @@
4747
import org.springframework.ai.ollama.api.OllamaApi.Message.Role;
4848
import org.springframework.ai.ollama.api.OllamaApi.Message.ToolCall;
4949
import org.springframework.ai.ollama.api.OllamaApi.Message.ToolCallFunction;
50-
import org.springframework.ai.ollama.api.OllamaModelPuller;
50+
import org.springframework.ai.ollama.management.ModelManagementOptions;
51+
import org.springframework.ai.ollama.management.OllamaModelManager;
5152
import org.springframework.ai.ollama.api.OllamaOptions;
53+
import org.springframework.ai.ollama.management.PullModelStrategy;
5254
import org.springframework.ai.ollama.metadata.OllamaChatUsage;
5355
import org.springframework.util.Assert;
5456
import org.springframework.util.CollectionUtils;
@@ -59,10 +61,9 @@
5961
/**
6062
* {@link ChatModel} implementation for {@literal Ollama}. Ollama allows developers to run
6163
* large language models and generate embeddings locally. It supports open-source models
62-
* available on [Ollama AI Library](<a href="https://ollama.ai/library">...</a>). - Llama
63-
* 2 (7B parameters, 3.8GB size) - Mistral (7B parameters, 4.1GB size) Please refer to the
64-
* <a href="https://ollama.ai/">official Ollama website</a> for the most up-to-date
65-
* information on available models.
64+
* available on the <a href="https://ollama.ai/library">Ollama AI Library</a> and on
65+
* Hugging Face. Please refer to the <a href="https://ollama.ai/">official Ollama
66+
* website</a> for the most up-to-date information on available models.
6667
*
6768
* @author Christian Tzolov
6869
* @author luocongqiu
@@ -73,57 +74,33 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode
7374

7475
private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();
7576

76-
/**
77-
* Low-level Ollama API library.
78-
*/
7977
private final OllamaApi chatApi;
8078

81-
/**
82-
* Default options to be used for all chat requests.
83-
*/
8479
private final OllamaOptions defaultOptions;
8580

86-
/**
87-
* Observation registry used for instrumentation.
88-
*/
8981
private final ObservationRegistry observationRegistry;
9082

91-
/**
92-
* Conventions to use for generating observations.
93-
*/
94-
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
95-
96-
private final OllamaModelPuller modelPuller;
97-
98-
public OllamaChatModel(OllamaApi ollamaApi) {
99-
this(ollamaApi, OllamaOptions.create().withModel(OllamaOptions.DEFAULT_MODEL));
100-
}
101-
102-
public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions) {
103-
this(ollamaApi, defaultOptions, null);
104-
}
83+
private final OllamaModelManager modelManager;
10584

106-
public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
107-
FunctionCallbackContext functionCallbackContext) {
108-
this(ollamaApi, defaultOptions, functionCallbackContext, List.of());
109-
}
85+
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
11086

11187
/**
 * Creates a chat model on top of the given Ollama API client and then delegates to
 * {@code initializeModelIfEnabled} so the default model can be pulled according to the
 * strategy carried by the model management options.
 * @param ollamaApi low-level Ollama API client; must not be null
 * @param defaultOptions options used as defaults for chat requests; must not be null
 * @param functionCallbackContext context passed to the tool-call support superclass
 * @param toolFunctionCallbacks callbacks passed to the tool-call support superclass
 * @param observationRegistry registry used for instrumentation; must not be null
 * @param modelManagementOptions settings handed to the {@link OllamaModelManager},
 * including the pull strategy applied at construction time; must not be null
 */
public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
		FunctionCallbackContext functionCallbackContext, List<FunctionCallback> toolFunctionCallbacks,
		ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
	super(functionCallbackContext, defaultOptions, toolFunctionCallbacks);
	Assert.notNull(ollamaApi, "ollamaApi must not be null");
	Assert.notNull(defaultOptions, "defaultOptions must not be null");
	Assert.notNull(observationRegistry, "observationRegistry must not be null");
	// Check the management options themselves. Asserting observationRegistry a second
	// time would let a null value through until pullModelStrategy() is dereferenced below.
	Assert.notNull(modelManagementOptions, "modelManagementOptions must not be null");
	this.chatApi = ollamaApi;
	this.defaultOptions = defaultOptions;
	this.observationRegistry = observationRegistry;
	this.modelManager = new OllamaModelManager(this.chatApi, modelManagementOptions);
	initializeModelIfEnabled(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy());
}
101+
102+
public static Builder builder() {
103+
return new Builder();
127104
}
128105

129106
@Override
@@ -324,9 +301,9 @@ else if (message instanceof ToolResponseMessage toolMessage) {
324301
}
325302
OllamaOptions mergedOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, OllamaOptions.class);
326303

327-
mergedOptions.setPullMissingModel(this.defaultOptions.isPullMissingModel());
328-
if (runtimeOptions != null && runtimeOptions.isPullMissingModel() != null) {
329-
mergedOptions.setPullMissingModel(runtimeOptions.isPullMissingModel());
304+
mergedOptions.setPullModelStrategy(this.defaultOptions.getPullModelStrategy());
305+
if (runtimeOptions != null && runtimeOptions.getPullModelStrategy() != null) {
306+
mergedOptions.setPullModelStrategy(runtimeOptions.getPullModelStrategy());
330307
}
331308

332309
// Override the model.
@@ -353,9 +330,7 @@ else if (message instanceof ToolResponseMessage toolMessage) {
353330
requestBuilder.withTools(this.getFunctionTools(functionsForThisRequest));
354331
}
355332

356-
if (mergedOptions.isPullMissingModel()) {
357-
this.modelPuller.pullModel(mergedOptions.getModel(), true);
358-
}
333+
initializeModelIfEnabled(mergedOptions.getModel(), mergedOptions.getPullModelStrategy());
359334

360335
return requestBuilder.build();
361336
}
@@ -400,6 +375,15 @@ public ChatOptions getDefaultOptions() {
400375
return OllamaOptions.fromOptions(this.defaultOptions);
401376
}
402377

378+
/**
379+
* Pull the given model into Ollama based on the specified strategy.
380+
*/
381+
private void initializeModelIfEnabled(String model, PullModelStrategy pullModelStrategy) {
382+
if (!PullModelStrategy.NEVER.equals(pullModelStrategy)) {
383+
this.modelManager.pullModel(model, pullModelStrategy);
384+
}
385+
}
386+
403387
/**
404388
* Use the provided convention for reporting observation data
405389
* @param observationConvention The provided convention
@@ -409,4 +393,58 @@ public void setObservationConvention(ChatModelObservationConvention observationC
409393
this.observationConvention = observationConvention;
410394
}
411395

396+
public static class Builder {
397+
398+
private OllamaApi ollamaApi;
399+
400+
private OllamaOptions defaultOptions = OllamaOptions.create().withModel(OllamaOptions.DEFAULT_MODEL);
401+
402+
private FunctionCallbackContext functionCallbackContext;
403+
404+
private List<FunctionCallback> toolFunctionCallbacks = List.of();
405+
406+
private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
407+
408+
private ModelManagementOptions modelManagementOptions = ModelManagementOptions.defaults();
409+
410+
private Builder() {
411+
}
412+
413+
public Builder withOllamaApi(OllamaApi ollamaApi) {
414+
this.ollamaApi = ollamaApi;
415+
return this;
416+
}
417+
418+
public Builder withDefaultOptions(OllamaOptions defaultOptions) {
419+
this.defaultOptions = defaultOptions;
420+
return this;
421+
}
422+
423+
public Builder withFunctionCallbackContext(FunctionCallbackContext functionCallbackContext) {
424+
this.functionCallbackContext = functionCallbackContext;
425+
return this;
426+
}
427+
428+
public Builder withToolFunctionCallbacks(List<FunctionCallback> toolFunctionCallbacks) {
429+
this.toolFunctionCallbacks = toolFunctionCallbacks;
430+
return this;
431+
}
432+
433+
public Builder withObservationRegistry(ObservationRegistry observationRegistry) {
434+
this.observationRegistry = observationRegistry;
435+
return this;
436+
}
437+
438+
public Builder withModelManagementOptions(ModelManagementOptions modelManagementOptions) {
439+
this.modelManagementOptions = modelManagementOptions;
440+
return this;
441+
}
442+
443+
public OllamaChatModel build() {
444+
return new OllamaChatModel(ollamaApi, defaultOptions, functionCallbackContext, toolFunctionCallbacks,
445+
observationRegistry, modelManagementOptions);
446+
}
447+
448+
}
449+
412450
}

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java

Lines changed: 72 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import java.util.regex.Pattern;
2323

2424
import io.micrometer.observation.ObservationRegistry;
25-
import org.springframework.ai.chat.metadata.EmptyUsage;
2625
import org.springframework.ai.document.Document;
2726
import org.springframework.ai.embedding.*;
2827
import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention;
@@ -32,24 +31,20 @@
3231
import org.springframework.ai.model.ModelOptionsUtils;
3332
import org.springframework.ai.ollama.api.OllamaApi;
3433
import org.springframework.ai.ollama.api.OllamaApi.EmbeddingsResponse;
35-
import org.springframework.ai.ollama.api.OllamaModelPuller;
34+
import org.springframework.ai.ollama.management.ModelManagementOptions;
35+
import org.springframework.ai.ollama.management.OllamaModelManager;
3636
import org.springframework.ai.ollama.api.OllamaOptions;
37+
import org.springframework.ai.ollama.management.PullModelStrategy;
3738
import org.springframework.ai.ollama.metadata.OllamaEmbeddingUsage;
3839
import org.springframework.util.Assert;
3940
import org.springframework.util.StringUtils;
4041

4142
/**
42-
* {@link EmbeddingModel} implementation for {@literal Ollama}.
43-
*
44-
* Ollama allows developers to run large language models and generate embeddings locally.
45-
* It supports open-source models available on [Ollama AI
46-
* Library](https://ollama.ai/library).
47-
*
48-
* Examples of models supported: - Llama 2 (7B parameters, 3.8GB size) - Mistral (7B
49-
* parameters, 4.1GB size)
50-
*
51-
* Please refer to the <a href="https://ollama.ai/">official Ollama website</a> for the
52-
* most up-to-date information on available models.
43+
* {@link EmbeddingModel} implementation for {@literal Ollama}. Ollama allows developers
44+
* to run large language models and generate embeddings locally. It supports open-source
45+
* models available on the <a href="https://ollama.ai/library">Ollama AI Library</a>
46+
* and on Hugging Face. Please refer to the <a href="https://ollama.ai/">official Ollama
47+
* website</a> for the most up-to-date information on available models.
5348
*
5449
* @author Christian Tzolov
5550
* @author Thomas Vitale
@@ -61,41 +56,31 @@ public class OllamaEmbeddingModel extends AbstractEmbeddingModel {
6156

6257
private final OllamaApi ollamaApi;
6358

64-
/**
65-
* Default options to be used for all chat requests.
66-
*/
6759
private final OllamaOptions defaultOptions;
6860

69-
/**
70-
* Observation registry used for instrumentation.
71-
*/
7261
private final ObservationRegistry observationRegistry;
7362

74-
/**
75-
* Conventions to use for generating observations.
76-
*/
77-
private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
78-
79-
private final OllamaModelPuller modelPuller;
63+
private final OllamaModelManager modelManager;
8064

81-
public OllamaEmbeddingModel(OllamaApi ollamaApi) {
82-
this(ollamaApi, OllamaOptions.create().withModel(OllamaOptions.DEFAULT_MODEL));
83-
}
84-
85-
public OllamaEmbeddingModel(OllamaApi ollamaApi, OllamaOptions defaultOptions) {
86-
this(ollamaApi, defaultOptions, ObservationRegistry.NOOP);
87-
}
65+
private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
8866

8967
/**
 * Creates an embedding model on top of the given Ollama API client and then delegates
 * to {@code initializeModelIfEnabled} so the default model can be pulled according to
 * the strategy carried by the model management options.
 * @param ollamaApi low-level Ollama API client; must not be null
 * @param defaultOptions options used as defaults for embedding requests; must not be null
 * @param observationRegistry registry used for instrumentation; must not be null
 * @param modelManagementOptions settings handed to the {@link OllamaModelManager},
 * including the pull strategy applied at construction time; must not be null
 */
public OllamaEmbeddingModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
		ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
	Assert.notNull(ollamaApi, "ollamaApi must not be null");
	// The message names the actual parameter, matching the other assertions.
	Assert.notNull(defaultOptions, "defaultOptions must not be null");
	Assert.notNull(observationRegistry, "observationRegistry must not be null");
	Assert.notNull(modelManagementOptions, "modelManagementOptions must not be null");

	this.ollamaApi = ollamaApi;
	this.defaultOptions = defaultOptions;
	this.observationRegistry = observationRegistry;
	this.modelManager = new OllamaModelManager(ollamaApi, modelManagementOptions);

	initializeModelIfEnabled(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy());
}
81+
82+
public static Builder builder() {
83+
return new Builder();
9984
}
10085

10186
@Override
@@ -153,9 +138,9 @@ OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest(List<String> inputContent, Em
153138

154139
OllamaOptions mergedOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, OllamaOptions.class);
155140

156-
mergedOptions.setPullMissingModel(this.defaultOptions.isPullMissingModel());
157-
if (runtimeOptions != null && runtimeOptions.isPullMissingModel() != null) {
158-
mergedOptions.setPullMissingModel(runtimeOptions.isPullMissingModel());
141+
mergedOptions.setPullModelStrategy(this.defaultOptions.getPullModelStrategy());
142+
if (runtimeOptions != null && runtimeOptions.getPullModelStrategy() != null) {
143+
mergedOptions.setPullModelStrategy(runtimeOptions.getPullModelStrategy());
159144
}
160145

161146
// Override the model.
@@ -164,9 +149,7 @@ OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest(List<String> inputContent, Em
164149
}
165150
String model = mergedOptions.getModel();
166151

167-
if (mergedOptions.isPullMissingModel()) {
168-
this.modelPuller.pullModel(model, true);
169-
}
152+
initializeModelIfEnabled(mergedOptions.getModel(), mergedOptions.getPullModelStrategy());
170153

171154
return new OllamaApi.EmbeddingsRequest(model, inputContent, DurationParser.parse(mergedOptions.getKeepAlive()),
172155
OllamaOptions.filterNonSupportedFields(mergedOptions.toMap()), mergedOptions.getTruncate());
@@ -176,6 +159,15 @@ private EmbeddingOptions buildRequestOptions(OllamaApi.EmbeddingsRequest request
176159
return EmbeddingOptionsBuilder.builder().withModel(request.model()).build();
177160
}
178161

162+
/**
163+
* Pull the given model into Ollama based on the specified strategy.
164+
*/
165+
private void initializeModelIfEnabled(String model, PullModelStrategy pullModelStrategy) {
166+
if (!PullModelStrategy.NEVER.equals(pullModelStrategy)) {
167+
this.modelManager.pullModel(model, pullModelStrategy);
168+
}
169+
}
170+
179171
/**
180172
* Use the provided convention for reporting observation data
181173
* @param observationConvention The provided convention
@@ -216,4 +208,43 @@ public static Duration parse(String input) {
216208

217209
}
218210

211+
public static class Builder {
212+
213+
private OllamaApi ollamaApi;
214+
215+
private OllamaOptions defaultOptions = OllamaOptions.create().withModel(OllamaOptions.DEFAULT_MODEL);
216+
217+
private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
218+
219+
private ModelManagementOptions modelManagementOptions = ModelManagementOptions.defaults();
220+
221+
private Builder() {
222+
}
223+
224+
public Builder withOllamaApi(OllamaApi ollamaApi) {
225+
this.ollamaApi = ollamaApi;
226+
return this;
227+
}
228+
229+
public Builder withDefaultOptions(OllamaOptions defaultOptions) {
230+
this.defaultOptions = defaultOptions;
231+
return this;
232+
}
233+
234+
public Builder withObservationRegistry(ObservationRegistry observationRegistry) {
235+
this.observationRegistry = observationRegistry;
236+
return this;
237+
}
238+
239+
public Builder withModelManagementOptions(ModelManagementOptions modelManagementOptions) {
240+
this.modelManagementOptions = modelManagementOptions;
241+
return this;
242+
}
243+
244+
public OllamaEmbeddingModel build() {
245+
return new OllamaEmbeddingModel(ollamaApi, defaultOptions, observationRegistry, modelManagementOptions);
246+
}
247+
248+
}
249+
219250
}

0 commit comments

Comments
 (0)