Skip to content

Commit da8ffe3

Browse files
committed
Support Ollama APIs for model management
* Extend the OllamaApi to support listing, copying, deleting, and pulling models programmatically. * Improve setup for integration testing with Ollama and Testcontainers. Enables gh-526 Signed-off-by: Thomas Vitale <[email protected]>
1 parent fe09b4d commit da8ffe3

File tree

20 files changed

+416
-338
lines changed

20 files changed

+416
-338
lines changed

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@
2929
import org.springframework.ai.observation.conventions.AiProvider;
3030
import org.springframework.boot.context.properties.bind.ConstructorBinding;
3131
import org.springframework.http.HttpHeaders;
32+
import org.springframework.http.HttpMethod;
3233
import org.springframework.http.MediaType;
34+
import org.springframework.http.ResponseEntity;
3335
import org.springframework.http.client.ClientHttpResponse;
3436
import org.springframework.util.Assert;
3537
import org.springframework.util.StreamUtils;
@@ -803,5 +805,159 @@ public EmbeddingResponse embeddings(EmbeddingRequest embeddingRequest) {
803805
.body(EmbeddingResponse.class);
804806
}
805807

808+
// --------------------------------------------------------------------------
809+
// Models
810+
// --------------------------------------------------------------------------
811+
812+
@JsonInclude(Include.NON_NULL)
813+
public record Model(
814+
@JsonProperty("name") String name,
815+
@JsonProperty("model") String model,
816+
@JsonProperty("modified_at") Instant modifiedAt,
817+
@JsonProperty("size") Long size,
818+
@JsonProperty("digest") String digest,
819+
@JsonProperty("details") Details details
820+
) {
821+
@JsonInclude(Include.NON_NULL)
822+
public record Details(
823+
@JsonProperty("parent_model") String parentModel,
824+
@JsonProperty("format") String format,
825+
@JsonProperty("family") String family,
826+
@JsonProperty("families") List<String> families,
827+
@JsonProperty("parameter_size") String parameterSize,
828+
@JsonProperty("quantization_level") String quantizationLevel
829+
) {}
830+
}
831+
832+
@JsonInclude(Include.NON_NULL)
833+
public record ListModelResponse(
834+
@JsonProperty("models") List<Model> models
835+
) {}
836+
837+
/**
838+
* List models that are available locally on the machine where Ollama is running.
839+
*/
840+
public ListModelResponse listModels() {
841+
return this.restClient.get()
842+
.uri("/api/tags")
843+
.retrieve()
844+
.onStatus(this.responseErrorHandler)
845+
.body(ListModelResponse.class);
846+
}
847+
848+
@JsonInclude(Include.NON_NULL)
849+
public record ShowModelRequest(
850+
@JsonProperty("model") String model,
851+
@JsonProperty("system") String system,
852+
@JsonProperty("verbose") Boolean verbose,
853+
@JsonProperty("options") Map<String, Object> options
854+
) {
855+
public ShowModelRequest(String model) {
856+
this(model, null, null, null);
857+
}
858+
}
859+
860+
@JsonInclude(Include.NON_NULL)
861+
public record ShowModelResponse(
862+
@JsonProperty("license") String license,
863+
@JsonProperty("modelfile") String modelfile,
864+
@JsonProperty("parameters") String parameters,
865+
@JsonProperty("template") String template,
866+
@JsonProperty("system") String system,
867+
@JsonProperty("details") Model.Details details,
868+
@JsonProperty("messages") List<Message> messages,
869+
@JsonProperty("model_info") Map<String, Object> modelInfo,
870+
@JsonProperty("projector_info") Map<String, Object> projectorInfo,
871+
@JsonProperty("modified_at") Instant modifiedAt
872+
) {}
873+
874+
/**
875+
* Show information about a model available locally on the machine where Ollama is running.
876+
*/
877+
public ShowModelResponse showModel(ShowModelRequest showModelRequest) {
878+
return this.restClient.post()
879+
.uri("/api/show")
880+
.body(showModelRequest)
881+
.retrieve()
882+
.onStatus(this.responseErrorHandler)
883+
.body(ShowModelResponse.class);
884+
}
885+
886+
@JsonInclude(Include.NON_NULL)
887+
public record CopyModelRequest(
888+
@JsonProperty("source") String source,
889+
@JsonProperty("destination") String destination
890+
) {}
891+
892+
/**
893+
* Copy a model. Creates a model with another name from an existing model.
894+
*/
895+
public ResponseEntity<Void> copyModel(CopyModelRequest copyModelRequest) {
896+
return this.restClient.post()
897+
.uri("/api/copy")
898+
.body(copyModelRequest)
899+
.retrieve()
900+
.onStatus(this.responseErrorHandler)
901+
.toBodilessEntity();
902+
}
903+
904+
@JsonInclude(Include.NON_NULL)
905+
public record DeleteModelRequest(
906+
@JsonProperty("model") String model
907+
) {}
908+
909+
/**
910+
* Delete a model and its data.
911+
*/
912+
public ResponseEntity<Void> deleteModel(DeleteModelRequest deleteModelRequest) {
913+
return this.restClient.method(HttpMethod.DELETE)
914+
.uri("/api/delete")
915+
.body(deleteModelRequest)
916+
.retrieve()
917+
.onStatus(this.responseErrorHandler)
918+
.toBodilessEntity();
919+
}
920+
921+
@JsonInclude(Include.NON_NULL)
922+
public record PullModelRequest(
923+
@JsonProperty("model") String model,
924+
@JsonProperty("insecure") Boolean insecure,
925+
@JsonProperty("username") String username,
926+
@JsonProperty("password") String password,
927+
@JsonProperty("stream") Boolean stream
928+
) {
929+
public PullModelRequest {
930+
if (stream != null && stream) {
931+
logger.warn("Streaming when pulling models is not supported yet");
932+
}
933+
stream = false;
934+
}
935+
936+
public PullModelRequest(String model) {
937+
this(model, null, null, null, null);
938+
}
939+
}
940+
941+
@JsonInclude(Include.NON_NULL)
942+
public record ProgressResponse(
943+
@JsonProperty("status") String status,
944+
@JsonProperty("digest") String digest,
945+
@JsonProperty("total") Long total,
946+
@JsonProperty("completed") Long completed
947+
) {}
948+
949+
/**
950+
* Download a model from the Ollama library. Cancelled pulls are resumed from where they left off,
951+
* and multiple calls will share the same download progress.
952+
*/
953+
public ProgressResponse pullModel(PullModelRequest pullModelRequest) {
954+
return this.restClient.post()
955+
.uri("/api/pull")
956+
.body(pullModelRequest)
957+
.retrieve()
958+
.onStatus(this.responseErrorHandler)
959+
.body(ProgressResponse.class);
960+
}
961+
806962
}
807963
// @formatter:on
Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,17 @@
11
package org.springframework.ai.ollama;
22

3+
import org.slf4j.Logger;
4+
import org.slf4j.LoggerFactory;
5+
import org.springframework.ai.ollama.api.OllamaApi;
36
import org.testcontainers.ollama.OllamaContainer;
47

58
public class BaseOllamaIT {
69

10+
private static final Logger logger = LoggerFactory.getLogger(BaseOllamaIT.class);
11+
12+
// Toggle for running tests locally on native Ollama for a faster feedback loop.
13+
private static final boolean useTestcontainers = true;
14+
715
public static final OllamaContainer ollamaContainer;
816

917
static {
@@ -13,14 +21,34 @@ public class BaseOllamaIT {
1321

1422
/**
1523
* Change the return value to false in order to run multiple Ollama IT tests locally
16-
* reusing the same container image Also add the entry
24+
* reusing the same container image.
25+
*
26+
* Also, add the entry
1727
*
1828
* testcontainers.reuse.enable=true
1929
*
20-
* to the file .testcontainers.properties located in your home directory
30+
* to the file ".testcontainers.properties" located in your home directory
2131
*/
2232
public static boolean isDisabled() {
2333
return true;
2434
}
2535

36+
public static OllamaApi buildOllamaApiWithModel(String model) {
37+
var baseUrl = "http://localhost:11434";
38+
if (useTestcontainers) {
39+
baseUrl = "http://" + ollamaContainer.getHost() + ":" + ollamaContainer.getMappedPort(11434);
40+
}
41+
var ollamaApi = new OllamaApi(baseUrl);
42+
43+
ensureModelIsPresent(ollamaApi, model);
44+
45+
return ollamaApi;
46+
}
47+
48+
public static void ensureModelIsPresent(OllamaApi ollamaApi, String model) {
49+
logger.info("Start pulling the '{}' model. The operation can take several minutes...", model);
50+
ollamaApi.pullModel(new OllamaApi.PullModelRequest(model));
51+
logger.info("Completed pulling the '{}' model", model);
52+
}
53+
2654
}

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
*/
1616
package org.springframework.ai.ollama;
1717

18-
import org.junit.jupiter.api.BeforeAll;
1918
import org.junit.jupiter.api.Disabled;
2019
import org.junit.jupiter.api.Test;
2120
import org.junit.jupiter.api.condition.DisabledIf;
@@ -40,7 +39,6 @@
4039
import org.testcontainers.junit.jupiter.Testcontainers;
4140
import reactor.core.publisher.Flux;
4241

43-
import java.io.IOException;
4442
import java.util.ArrayList;
4543
import java.util.List;
4644
import java.util.stream.Collectors;
@@ -54,25 +52,13 @@ class OllamaChatModelFunctionCallingIT extends BaseOllamaIT {
5452

5553
private static final Logger logger = LoggerFactory.getLogger(OllamaChatModelFunctionCallingIT.class);
5654

57-
private static final String MODEL = OllamaModel.MISTRAL.getName();
58-
59-
static String baseUrl = "http://localhost:11434";
60-
61-
@BeforeAll
62-
public static void beforeAll() throws IOException, InterruptedException {
63-
logger.info("Start pulling the '" + MODEL + " ' generative ... would take several minutes ...");
64-
ollamaContainer.execInContainer("ollama", "pull", MODEL);
65-
logger.info(MODEL + " pulling competed!");
66-
67-
baseUrl = "http://" + ollamaContainer.getHost() + ":" + ollamaContainer.getMappedPort(11434);
68-
}
55+
private static final String MODEL = OllamaModel.LLAMA3_2.getName();
6956

7057
@Autowired
7158
ChatModel chatModel;
7259

7360
@Test
7461
void functionCallTest() {
75-
7662
UserMessage userMessage = new UserMessage(
7763
"What's the weather like in San Francisco, Tokyo, and Paris? Return temperatures in Celsius.");
7864

@@ -97,7 +83,6 @@ void functionCallTest() {
9783
@Disabled("Ollama API does not support streaming function calls yet")
9884
@Test
9985
void streamFunctionCallTest() {
100-
10186
UserMessage userMessage = new UserMessage(
10287
"What's the weather like in San Francisco, Tokyo, and Paris? Return temperatures in Celsius.");
10388

@@ -132,7 +117,7 @@ static class Config {
132117

133118
@Bean
134119
public OllamaApi ollamaApi() {
135-
return new OllamaApi(baseUrl);
120+
return buildOllamaApiWithModel(MODEL);
136121
}
137122

138123
@Bean

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,6 @@
1515
*/
1616
package org.springframework.ai.ollama;
1717

18-
import org.apache.commons.logging.Log;
19-
import org.apache.commons.logging.LogFactory;
20-
import org.junit.jupiter.api.BeforeAll;
2118
import org.junit.jupiter.api.Test;
2219
import org.junit.jupiter.api.condition.DisabledIf;
2320
import org.springframework.ai.chat.messages.AssistantMessage;
@@ -43,7 +40,6 @@
4340
import org.springframework.core.convert.support.DefaultConversionService;
4441
import org.testcontainers.junit.jupiter.Testcontainers;
4542

46-
import java.io.IOException;
4743
import java.util.Arrays;
4844
import java.util.List;
4945
import java.util.Map;
@@ -58,19 +54,6 @@ class OllamaChatModelIT extends BaseOllamaIT {
5854

5955
private static final String MODEL = OllamaModel.MISTRAL.getName();
6056

61-
private static final Log logger = LogFactory.getLog(OllamaChatModelIT.class);
62-
63-
static String baseUrl = "http://localhost:11434";
64-
65-
@BeforeAll
66-
public static void beforeAll() throws IOException, InterruptedException {
67-
logger.info("Start pulling the '" + MODEL + " ' generative ... would take several minutes ...");
68-
ollamaContainer.execInContainer("ollama", "pull", MODEL);
69-
logger.info(MODEL + " pulling competed!");
70-
71-
baseUrl = "http://" + ollamaContainer.getHost() + ":" + ollamaContainer.getMappedPort(11434);
72-
}
73-
7457
@Autowired
7558
private OllamaChatModel chatModel;
7659

@@ -98,12 +81,10 @@ void roleTest() {
9881

9982
response = chatModel.call(new Prompt(List.of(userMessage, systemMessage), ollamaOptions));
10083
assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard");
101-
10284
}
10385

10486
@Test
10587
void testMessageHistory() {
106-
10788
Message systemMessage = new SystemPromptTemplate("""
10889
You are a helpful AI assistant. Your name is {name}.
10990
You are an AI assistant that helps people find information.
@@ -117,13 +98,13 @@ void testMessageHistory() {
11798
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
11899

119100
ChatResponse response = chatModel.call(prompt);
120-
assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Bonny");
101+
assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard");
121102

122103
var promptWithMessageHistory = new Prompt(List.of(new UserMessage("Dummy"), response.getResult().getOutput(),
123104
new UserMessage("Repeat the last assistant message.")));
124105
response = chatModel.call(promptWithMessageHistory);
125106

126-
assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Bonny");
107+
assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard");
127108
}
128109

129110
@Test
@@ -175,15 +156,13 @@ void mapOutputConvert() {
175156

176157
Map<String, Object> result = outputConverter.convert(generation.getOutput().getContent());
177158
assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
178-
179159
}
180160

181161
record ActorsFilmsRecord(String actor, List<String> movies) {
182162
}
183163

184164
@Test
185165
void beanOutputConverterRecords() {
186-
187166
BeanOutputConverter<ActorsFilmsRecord> outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class);
188167

189168
String format = outputConverter.getFormat();
@@ -202,7 +181,6 @@ void beanOutputConverterRecords() {
202181

203182
@Test
204183
void beanStreamOutputConverterRecords() {
205-
206184
BeanOutputConverter<ActorsFilmsRecord> outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class);
207185

208186
String format = outputConverter.getFormat();
@@ -235,7 +213,7 @@ public static class TestConfiguration {
235213

236214
@Bean
237215
public OllamaApi ollamaApi() {
238-
return new OllamaApi(baseUrl);
216+
return buildOllamaApiWithModel(MODEL);
239217
}
240218

241219
@Bean

0 commit comments

Comments
 (0)