Skip to content

Commit d5bc9c9

Browse files
ThomasVitaletzolov
authored andcommitted
Consolidate Ollama auto-pull logic
Consolidate the Ollama auto-pull logic at startup time, supporting the auto-pull for the default models specified via configuration properties and for optional models specified for initialization. Signed-off-by: Thomas Vitale <[email protected]>
1 parent 1cadc49 commit d5bc9c9

File tree

13 files changed

+140
-155
lines changed

13 files changed

+140
-155
lines changed

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
9797
this.defaultOptions = defaultOptions;
9898
this.observationRegistry = observationRegistry;
9999
this.modelManager = new OllamaModelManager(chatApi, modelManagementOptions);
100-
initializeModelIfEnabled(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy());
100+
initializeModel(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy());
101101
}
102102

103103
public static Builder builder() {
@@ -302,11 +302,6 @@ else if (message instanceof ToolResponseMessage toolMessage) {
302302
}
303303
OllamaOptions mergedOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, OllamaOptions.class);
304304

305-
mergedOptions.setPullModelStrategy(this.defaultOptions.getPullModelStrategy());
306-
if (runtimeOptions != null && runtimeOptions.getPullModelStrategy() != null) {
307-
mergedOptions.setPullModelStrategy(runtimeOptions.getPullModelStrategy());
308-
}
309-
310305
// Override the model.
311306
if (!StringUtils.hasText(mergedOptions.getModel())) {
312307
throw new IllegalArgumentException("Model is not set!");
@@ -331,8 +326,6 @@ else if (message instanceof ToolResponseMessage toolMessage) {
331326
requestBuilder.withTools(this.getFunctionTools(functionsForThisRequest));
332327
}
333328

334-
initializeModelIfEnabled(mergedOptions.getModel(), mergedOptions.getPullModelStrategy());
335-
336329
return requestBuilder.build();
337330
}
338331

@@ -379,7 +372,7 @@ public ChatOptions getDefaultOptions() {
379372
/**
380373
* Pull the given model into Ollama based on the specified strategy.
381374
*/
382-
private void initializeModelIfEnabled(String model, PullModelStrategy pullModelStrategy) {
375+
private void initializeModel(String model, PullModelStrategy pullModelStrategy) {
383376
if (pullModelStrategy != null && !PullModelStrategy.NEVER.equals(pullModelStrategy)) {
384377
this.modelManager.pullModel(model, pullModelStrategy);
385378
}

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ public OllamaEmbeddingModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
7777
this.observationRegistry = observationRegistry;
7878
this.modelManager = new OllamaModelManager(ollamaApi, modelManagementOptions);
7979

80-
initializeModelIfEnabled(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy());
80+
initializeModel(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy());
8181
}
8282

8383
public static Builder builder() {
@@ -139,19 +139,12 @@ OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest(List<String> inputContent, Em
139139

140140
OllamaOptions mergedOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, OllamaOptions.class);
141141

142-
mergedOptions.setPullModelStrategy(this.defaultOptions.getPullModelStrategy());
143-
if (runtimeOptions != null && runtimeOptions.getPullModelStrategy() != null) {
144-
mergedOptions.setPullModelStrategy(runtimeOptions.getPullModelStrategy());
145-
}
146-
147142
// Override the model.
148143
if (!StringUtils.hasText(mergedOptions.getModel())) {
149144
throw new IllegalArgumentException("Model is not set!");
150145
}
151146
String model = mergedOptions.getModel();
152147

153-
initializeModelIfEnabled(mergedOptions.getModel(), mergedOptions.getPullModelStrategy());
154-
155148
return new OllamaApi.EmbeddingsRequest(model, inputContent, DurationParser.parse(mergedOptions.getKeepAlive()),
156149
OllamaOptions.filterNonSupportedFields(mergedOptions.toMap()), mergedOptions.getTruncate());
157150
}
@@ -163,7 +156,7 @@ private EmbeddingOptions buildRequestOptions(OllamaApi.EmbeddingsRequest request
163156
/**
164157
* Pull the given model into Ollama based on the specified strategy.
165158
*/
166-
private void initializeModelIfEnabled(String model, PullModelStrategy pullModelStrategy) {
159+
private void initializeModel(String model, PullModelStrategy pullModelStrategy) {
167160
if (pullModelStrategy != null && !PullModelStrategy.NEVER.equals(pullModelStrategy)) {
168161
this.modelManager.pullModel(model, pullModelStrategy);
169162
}

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
import org.springframework.ai.model.ModelOptionsUtils;
2929
import org.springframework.ai.model.function.FunctionCallback;
3030
import org.springframework.ai.model.function.FunctionCallingOptions;
31-
import org.springframework.ai.ollama.management.PullModelStrategy;
3231
import org.springframework.boot.context.properties.NestedConfigurationProperty;
3332
import org.springframework.util.Assert;
3433

@@ -303,12 +302,6 @@ public class OllamaOptions implements FunctionCallingOptions, ChatOptions, Embed
303302
@JsonIgnore
304303
private Map<String, Object> toolContext;
305304

306-
/**
307-
* Strategy for pulling models at run-time.
308-
*/
309-
@JsonIgnore
310-
private PullModelStrategy pullModelStrategy;
311-
312305
public static OllamaOptions builder() {
313306
return new OllamaOptions();
314307
}
@@ -521,11 +514,6 @@ public OllamaOptions withToolContext(Map<String, Object> toolContext) {
521514
return this;
522515
}
523516

524-
public OllamaOptions withPullModelStrategy(PullModelStrategy pullModelStrategy) {
525-
this.pullModelStrategy = pullModelStrategy;
526-
return this;
527-
}
528-
529517
// -------------------
530518
// Getters and Setters
531519
// -------------------
@@ -866,14 +854,6 @@ public void setToolContext(Map<String, Object> toolContext) {
866854
this.toolContext = toolContext;
867855
}
868856

869-
public PullModelStrategy getPullModelStrategy() {
870-
return this.pullModelStrategy;
871-
}
872-
873-
public void setPullModelStrategy(PullModelStrategy pullModelStrategy) {
874-
this.pullModelStrategy = pullModelStrategy;
875-
}
876-
877857
/**
878858
* Convert the {@link OllamaOptions} object to a {@link Map} of key/value pairs.
879859
* @return The {@link Map} of key/value pairs.
@@ -944,8 +924,7 @@ public static OllamaOptions fromOptions(OllamaOptions fromOptions) {
944924
.withFunctions(fromOptions.getFunctions())
945925
.withProxyToolCalls(fromOptions.getProxyToolCalls())
946926
.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
947-
.withToolContext(fromOptions.getToolContext())
948-
.withPullModelStrategy(fromOptions.getPullModelStrategy());
927+
.withToolContext(fromOptions.getToolContext());
949928
}
950929
// @formatter:on
951930

@@ -975,8 +954,7 @@ public boolean equals(Object o) {
975954
&& Objects.equals(penalizeNewline, that.penalizeNewline) && Objects.equals(stop, that.stop)
976955
&& Objects.equals(functionCallbacks, that.functionCallbacks)
977956
&& Objects.equals(proxyToolCalls, that.proxyToolCalls) && Objects.equals(functions, that.functions)
978-
&& Objects.equals(toolContext, that.toolContext)
979-
&& Objects.equals(pullModelStrategy, that.pullModelStrategy);
957+
&& Objects.equals(toolContext, that.toolContext);
980958
}
981959

982960
@Override
@@ -987,7 +965,7 @@ public int hashCode() {
987965
this.topP, tfsZ, this.typicalP, this.repeatLastN, this.temperature, this.repeatPenalty,
988966
this.presencePenalty, this.frequencyPenalty, this.mirostat, this.mirostatTau, this.mirostatEta,
989967
this.penalizeNewline, this.stop, this.functionCallbacks, this.functions, this.proxyToolCalls,
990-
this.toolContext, this.pullModelStrategy);
968+
this.toolContext);
991969
}
992970

993971
}

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/ModelManagementOptions.java

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,49 @@
2626
*/
2727
public record ModelManagementOptions(PullModelStrategy pullModelStrategy, List<String> additionalModels,
2828
Duration timeout, Integer maxRetries) {
29+
2930
public static ModelManagementOptions defaults() {
3031
return new ModelManagementOptions(PullModelStrategy.NEVER, List.of(), Duration.ofMinutes(5), 0);
3132
}
33+
34+
public static Builder builder() {
35+
return new Builder();
36+
}
37+
38+
public static class Builder {
39+
40+
private PullModelStrategy pullModelStrategy = PullModelStrategy.NEVER;
41+
42+
private List<String> additionalModels = List.of();
43+
44+
private Duration timeout = Duration.ofMinutes(5);
45+
46+
private Integer maxRetries = 0;
47+
48+
public Builder withPullModelStrategy(PullModelStrategy pullModelStrategy) {
49+
this.pullModelStrategy = pullModelStrategy;
50+
return this;
51+
}
52+
53+
public Builder withAdditionalModels(List<String> additionalModels) {
54+
this.additionalModels = additionalModels;
55+
return this;
56+
}
57+
58+
public Builder withTimeout(Duration timeout) {
59+
this.timeout = timeout;
60+
return this;
61+
}
62+
63+
public Builder withMaxRetries(Integer maxRetries) {
64+
this.maxRetries = maxRetries;
65+
return this;
66+
}
67+
68+
public ModelManagementOptions build() {
69+
return new ModelManagementOptions(pullModelStrategy, additionalModels, timeout, maxRetries);
70+
}
71+
72+
}
73+
3274
}

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.springframework.ai.converter.MapOutputConverter;
3434
import org.springframework.ai.ollama.api.OllamaApi;
3535
import org.springframework.ai.ollama.api.OllamaModel;
36+
import org.springframework.ai.ollama.management.ModelManagementOptions;
3637
import org.springframework.ai.ollama.management.OllamaModelManager;
3738
import org.springframework.ai.ollama.api.OllamaOptions;
3839
import org.springframework.ai.ollama.management.PullModelStrategy;
@@ -56,6 +57,8 @@ class OllamaChatModelIT extends BaseOllamaIT {
5657

5758
private static final String MODEL = OllamaModel.LLAMA3_2.getName();
5859

60+
private static final String ADDITIONAL_MODEL = "tinyllama";
61+
5962
@Autowired
6063
private OllamaChatModel chatModel;
6164

@@ -65,23 +68,17 @@ class OllamaChatModelIT extends BaseOllamaIT {
6568
@Test
6669
void autoPullModelTest() {
6770
var modelManager = new OllamaModelManager(ollamaApi);
68-
var model = "tinyllama";
69-
modelManager.deleteModel(model);
70-
assertThat(modelManager.isModelAvailable(model)).isFalse();
71+
assertThat(modelManager.isModelAvailable(ADDITIONAL_MODEL)).isTrue();
7172

7273
String joke = ChatClient.create(chatModel)
7374
.prompt("Tell me a joke")
74-
.options(OllamaOptions.builder()
75-
.withModel(model)
76-
.withPullModelStrategy(PullModelStrategy.WHEN_MISSING)
77-
.build())
75+
.options(OllamaOptions.builder().withModel(ADDITIONAL_MODEL).build())
7876
.call()
7977
.content();
8078

8179
assertThat(joke).isNotEmpty();
82-
assertThat(modelManager.isModelAvailable(model)).isTrue();
8380

84-
modelManager.deleteModel(model);
81+
modelManager.deleteModel(ADDITIONAL_MODEL);
8582
}
8683

8784
@Test
@@ -249,6 +246,10 @@ public OllamaChatModel ollamaChat(OllamaApi ollamaApi) {
249246
return OllamaChatModel.builder()
250247
.withOllamaApi(ollamaApi)
251248
.withDefaultOptions(OllamaOptions.create().withModel(MODEL).withTemperature(0.9))
249+
.withModelManagementOptions(ModelManagementOptions.builder()
250+
.withPullModelStrategy(PullModelStrategy.WHEN_MISSING)
251+
.withAdditionalModels(List.of(ADDITIONAL_MODEL))
252+
.build())
252253
.build();
253254
}
254255

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelIT.java

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import org.springframework.ai.embedding.EmbeddingResponse;
2222
import org.springframework.ai.ollama.api.OllamaApi;
2323
import org.springframework.ai.ollama.api.OllamaModel;
24+
import org.springframework.ai.ollama.management.ModelManagementOptions;
2425
import org.springframework.ai.ollama.management.OllamaModelManager;
2526
import org.springframework.ai.ollama.api.OllamaOptions;
2627
import org.springframework.ai.ollama.management.PullModelStrategy;
@@ -41,6 +42,8 @@ class OllamaEmbeddingModelIT extends BaseOllamaIT {
4142

4243
private static final String MODEL = OllamaModel.NOMIC_EMBED_TEXT.getName();
4344

45+
private static final String ADDITIONAL_MODEL = "all-minilm";
46+
4447
@Autowired
4548
private OllamaEmbeddingModel embeddingModel;
4649

@@ -65,36 +68,29 @@ void embeddings() {
6568
}
6669

6770
@Test
68-
void autoPullModel() {
71+
void autoPullModelAtStartupTime() {
6972
var model = "all-minilm";
7073
assertThat(embeddingModel).isNotNull();
7174

7275
var modelManager = new OllamaModelManager(ollamaApi);
73-
modelManager.deleteModel(model);
74-
assertThat(modelManager.isModelAvailable(model)).isFalse();
76+
assertThat(modelManager.isModelAvailable(ADDITIONAL_MODEL)).isTrue();
7577

7678
EmbeddingResponse embeddingResponse = embeddingModel
7779
.call(new EmbeddingRequest(List.of("Hello World", "Something else"),
78-
OllamaOptions.builder()
79-
.withModel(model)
80-
.withPullModelStrategy(PullModelStrategy.WHEN_MISSING)
81-
.withTruncate(false)
82-
.build()));
83-
84-
assertThat(modelManager.isModelAvailable(model)).isTrue();
80+
OllamaOptions.builder().withModel(model).withTruncate(false).build()));
8581

8682
assertThat(embeddingResponse.getResults()).hasSize(2);
8783
assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0);
8884
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
8985
assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1);
9086
assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty();
91-
assertThat(embeddingResponse.getMetadata().getModel()).contains(model);
87+
assertThat(embeddingResponse.getMetadata().getModel()).contains(ADDITIONAL_MODEL);
9288
assertThat(embeddingResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(4);
9389
assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(4);
9490

9591
assertThat(embeddingModel.dimensions()).isEqualTo(768);
9692

97-
modelManager.deleteModel(model);
93+
modelManager.deleteModel(ADDITIONAL_MODEL);
9894
}
9995

10096
@SpringBootConfiguration
@@ -110,6 +106,10 @@ public OllamaEmbeddingModel ollamaEmbedding(OllamaApi ollamaApi) {
110106
return OllamaEmbeddingModel.builder()
111107
.withOllamaApi(ollamaApi)
112108
.withDefaultOptions(OllamaOptions.create().withModel(MODEL))
109+
.withModelManagementOptions(ModelManagementOptions.builder()
110+
.withPullModelStrategy(PullModelStrategy.WHEN_MISSING)
111+
.withAdditionalModels(List.of(ADDITIONAL_MODEL))
112+
.build())
113113
.build();
114114
}
115115

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,6 @@
2222
*/
2323
public class OllamaImage {
2424

25-
public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ollama/ollama:0.3.13");
25+
public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ollama/ollama:0.3.14");
2626

2727
}

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiModelsIT.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,18 @@
1515
*/
1616
package org.springframework.ai.ollama.api;
1717

18+
import static org.assertj.core.api.Assertions.assertThat;
19+
20+
import java.io.IOException;
21+
import java.time.Duration;
22+
1823
import org.junit.jupiter.api.BeforeAll;
1924
import org.junit.jupiter.api.Test;
2025
import org.junit.jupiter.api.condition.DisabledIf;
2126
import org.springframework.ai.ollama.BaseOllamaIT;
2227
import org.springframework.http.HttpStatus;
2328
import org.testcontainers.junit.jupiter.Testcontainers;
2429

25-
import java.io.IOException;
26-
import java.time.Duration;
27-
28-
import static org.assertj.core.api.Assertions.assertThat;
29-
3030
/**
3131
* Integration tests for the Ollama APIs to manage models.
3232
*
@@ -36,7 +36,7 @@
3636
@DisabledIf("isDisabled")
3737
public class OllamaApiModelsIT extends BaseOllamaIT {
3838

39-
private static final String MODEL = OllamaModel.NOMIC_EMBED_TEXT.getName();
39+
private static final String MODEL = "all-minilm";
4040

4141
static OllamaApi ollamaApi;
4242

@@ -60,7 +60,7 @@ public void showModel() {
6060
var showModelResponse = ollamaApi.showModel(showModelRequest);
6161

6262
assertThat(showModelResponse).isNotNull();
63-
assertThat(showModelResponse.details().family()).isEqualTo("nomic-bert");
63+
assertThat(showModelResponse.details().family()).isEqualTo("bert");
6464
}
6565

6666
@Test

spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/comparison.adoc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
// :YES: image::yes.svg[width=16]
44
// :NO: image::no.svg[width=12]
5-
// [%autowidth]
5+
66

77
This table compares various Chat Models supported by Spring AI, detailing their capabilities:
88

@@ -39,6 +39,5 @@ This table compares various Chat Models supported by Spring AI, detailing their
3939
| xref::api/chat/bedrock/bedrock-llama.adoc[Amazon Bedrock/Llama] | text ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12]
4040
| xref::api/chat/bedrock/bedrock-titan.adoc[Amazon Bedrock/Titan] | text ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12]
4141
| xref::api/chat/bedrock/bedrock-anthropic3.adoc[Amazon Bedrock/Anthropic 3] | text ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12]
42-
4342
|====
4443

0 commit comments

Comments
 (0)