Skip to content

Commit 8e603e6

Browse files
committed
Support Ollama APIs for model management
* Extend the OllamaApi to support listing, copying, deleting, and pulling models programmatically. * Improve setup for integration testing with Ollama and Testcontainers. Enables gh-526 Signed-off-by: Thomas Vitale <[email protected]>
1 parent 6d38c85 commit 8e603e6

File tree

20 files changed

+437
-355
lines changed

20 files changed

+437
-355
lines changed

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@
2929
import org.springframework.ai.observation.conventions.AiProvider;
3030
import org.springframework.boot.context.properties.bind.ConstructorBinding;
3131
import org.springframework.http.HttpHeaders;
32+
import org.springframework.http.HttpMethod;
3233
import org.springframework.http.MediaType;
34+
import org.springframework.http.ResponseEntity;
3335
import org.springframework.http.client.ClientHttpResponse;
3436
import org.springframework.util.Assert;
3537
import org.springframework.util.StreamUtils;
@@ -807,5 +809,159 @@ public EmbeddingResponse embeddings(EmbeddingRequest embeddingRequest) {
807809
.body(EmbeddingResponse.class);
808810
}
809811

812+
// --------------------------------------------------------------------------
813+
// Models
814+
// --------------------------------------------------------------------------
815+
816+
@JsonInclude(Include.NON_NULL)
817+
public record Model(
818+
@JsonProperty("name") String name,
819+
@JsonProperty("model") String model,
820+
@JsonProperty("modified_at") Instant modifiedAt,
821+
@JsonProperty("size") Long size,
822+
@JsonProperty("digest") String digest,
823+
@JsonProperty("details") Details details
824+
) {
825+
@JsonInclude(Include.NON_NULL)
826+
public record Details(
827+
@JsonProperty("parent_model") String parentModel,
828+
@JsonProperty("format") String format,
829+
@JsonProperty("family") String family,
830+
@JsonProperty("families") List<String> families,
831+
@JsonProperty("parameter_size") String parameterSize,
832+
@JsonProperty("quantization_level") String quantizationLevel
833+
) {}
834+
}
835+
836+
@JsonInclude(Include.NON_NULL)
837+
public record ListModelResponse(
838+
@JsonProperty("models") List<Model> models
839+
) {}
840+
841+
/**
842+
* List models that are available locally on the machine where Ollama is running.
843+
*/
844+
public ListModelResponse listModels() {
845+
return this.restClient.get()
846+
.uri("/api/tags")
847+
.retrieve()
848+
.onStatus(this.responseErrorHandler)
849+
.body(ListModelResponse.class);
850+
}
851+
852+
@JsonInclude(Include.NON_NULL)
853+
public record ShowModelRequest(
854+
@JsonProperty("model") String model,
855+
@JsonProperty("system") String system,
856+
@JsonProperty("verbose") Boolean verbose,
857+
@JsonProperty("options") Map<String, Object> options
858+
) {
859+
public ShowModelRequest(String model) {
860+
this(model, null, null, null);
861+
}
862+
}
863+
864+
@JsonInclude(Include.NON_NULL)
865+
public record ShowModelResponse(
866+
@JsonProperty("license") String license,
867+
@JsonProperty("modelfile") String modelfile,
868+
@JsonProperty("parameters") String parameters,
869+
@JsonProperty("template") String template,
870+
@JsonProperty("system") String system,
871+
@JsonProperty("details") Model.Details details,
872+
@JsonProperty("messages") List<Message> messages,
873+
@JsonProperty("model_info") Map<String, Object> modelInfo,
874+
@JsonProperty("projector_info") Map<String, Object> projectorInfo,
875+
@JsonProperty("modified_at") Instant modifiedAt
876+
) {}
877+
878+
/**
879+
* Show information about a model available locally on the machine where Ollama is running.
880+
*/
881+
public ShowModelResponse showModel(ShowModelRequest showModelRequest) {
882+
return this.restClient.post()
883+
.uri("/api/show")
884+
.body(showModelRequest)
885+
.retrieve()
886+
.onStatus(this.responseErrorHandler)
887+
.body(ShowModelResponse.class);
888+
}
889+
890+
@JsonInclude(Include.NON_NULL)
891+
public record CopyModelRequest(
892+
@JsonProperty("source") String source,
893+
@JsonProperty("destination") String destination
894+
) {}
895+
896+
/**
897+
* Copy a model. Creates a model with another name from an existing model.
898+
*/
899+
public ResponseEntity<Void> copyModel(CopyModelRequest copyModelRequest) {
900+
return this.restClient.post()
901+
.uri("/api/copy")
902+
.body(copyModelRequest)
903+
.retrieve()
904+
.onStatus(this.responseErrorHandler)
905+
.toBodilessEntity();
906+
}
907+
908+
@JsonInclude(Include.NON_NULL)
909+
public record DeleteModelRequest(
910+
@JsonProperty("model") String model
911+
) {}
912+
913+
/**
914+
* Delete a model and its data.
915+
*/
916+
public ResponseEntity<Void> deleteModel(DeleteModelRequest deleteModelRequest) {
917+
return this.restClient.method(HttpMethod.DELETE)
918+
.uri("/api/delete")
919+
.body(deleteModelRequest)
920+
.retrieve()
921+
.onStatus(this.responseErrorHandler)
922+
.toBodilessEntity();
923+
}
924+
925+
@JsonInclude(Include.NON_NULL)
926+
public record PullModelRequest(
927+
@JsonProperty("model") String model,
928+
@JsonProperty("insecure") Boolean insecure,
929+
@JsonProperty("username") String username,
930+
@JsonProperty("password") String password,
931+
@JsonProperty("stream") Boolean stream
932+
) {
933+
public PullModelRequest {
934+
if (stream != null && stream) {
935+
logger.warn("Streaming when pulling models is not supported yet");
936+
}
937+
stream = false;
938+
}
939+
940+
public PullModelRequest(String model) {
941+
this(model, null, null, null, null);
942+
}
943+
}
944+
945+
@JsonInclude(Include.NON_NULL)
946+
public record ProgressResponse(
947+
@JsonProperty("status") String status,
948+
@JsonProperty("digest") String digest,
949+
@JsonProperty("total") Long total,
950+
@JsonProperty("completed") Long completed
951+
) {}
952+
953+
/**
954+
* Download a model from the Ollama library. Cancelled pulls are resumed from where they left off,
955+
* and multiple calls will share the same download progress.
956+
*/
957+
public ProgressResponse pullModel(PullModelRequest pullModelRequest) {
958+
return this.restClient.post()
959+
.uri("/api/pull")
960+
.body(pullModelRequest)
961+
.retrieve()
962+
.onStatus(this.responseErrorHandler)
963+
.body(ProgressResponse.class);
964+
}
965+
810966
}
811967
// @formatter:on
Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,17 @@
11
package org.springframework.ai.ollama;
22

3+
import org.slf4j.Logger;
4+
import org.slf4j.LoggerFactory;
5+
import org.springframework.ai.ollama.api.OllamaApi;
36
import org.testcontainers.ollama.OllamaContainer;
47

58
public class BaseOllamaIT {
69

10+
private static final Logger logger = LoggerFactory.getLogger(BaseOllamaIT.class);
11+
12+
// Toggle for running tests locally on native Ollama for a faster feedback loop.
13+
private static final boolean useTestcontainers = false;
14+
715
public static final OllamaContainer ollamaContainer;
816

917
static {
@@ -13,14 +21,34 @@ public class BaseOllamaIT {
1321

1422
/**
1523
* Change the return value to false in order to run multiple Ollama IT tests locally
16-
* reusing the same container image Also add the entry
24+
* reusing the same container image.
25+
*
26+
* Also, add the entry
1727
*
1828
* testcontainers.reuse.enable=true
1929
*
20-
* to the file .testcontainers.properties located in your home directory
30+
* to the file ".testcontainers.properties" located in your home directory
2131
*/
2232
public static boolean isDisabled() {
23-
return true;
33+
return false;
34+
}
35+
36+
public static OllamaApi buildOllamaApiWithModel(String model) {
37+
var baseUrl = "http://localhost:11434";
38+
if (useTestcontainers) {
39+
baseUrl = "http://" + ollamaContainer.getHost() + ":" + ollamaContainer.getMappedPort(11434);
40+
}
41+
var ollamaApi = new OllamaApi(baseUrl);
42+
43+
ensureModelIsPresent(ollamaApi, model);
44+
45+
return ollamaApi;
46+
}
47+
48+
public static void ensureModelIsPresent(OllamaApi ollamaApi, String model) {
49+
logger.info("Start pulling the '{}' model. The operation can take several minutes...", model);
50+
ollamaApi.pullModel(new OllamaApi.PullModelRequest(model));
51+
logger.info("Completed pulling the '{}' model", model);
2452
}
2553

2654
}

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
*/
1616
package org.springframework.ai.ollama;
1717

18-
import org.junit.jupiter.api.BeforeAll;
1918
import org.junit.jupiter.api.Disabled;
2019
import org.junit.jupiter.api.Test;
2120
import org.junit.jupiter.api.condition.DisabledIf;
@@ -40,7 +39,6 @@
4039
import org.testcontainers.junit.jupiter.Testcontainers;
4140
import reactor.core.publisher.Flux;
4241

43-
import java.io.IOException;
4442
import java.util.ArrayList;
4543
import java.util.List;
4644
import java.util.stream.Collectors;
@@ -54,25 +52,13 @@ class OllamaChatModelFunctionCallingIT extends BaseOllamaIT {
5452

5553
private static final Logger logger = LoggerFactory.getLogger(OllamaChatModelFunctionCallingIT.class);
5654

57-
private static final String MODEL = OllamaModel.MISTRAL.getName();
58-
59-
static String baseUrl = "http://localhost:11434";
60-
61-
@BeforeAll
62-
public static void beforeAll() throws IOException, InterruptedException {
63-
logger.info("Start pulling the '" + MODEL + " ' generative ... would take several minutes ...");
64-
ollamaContainer.execInContainer("ollama", "pull", MODEL);
65-
logger.info(MODEL + " pulling competed!");
66-
67-
baseUrl = "http://" + ollamaContainer.getHost() + ":" + ollamaContainer.getMappedPort(11434);
68-
}
55+
private static final String MODEL = OllamaModel.LLAMA3_2.getName();
6956

7057
@Autowired
7158
ChatModel chatModel;
7259

7360
@Test
7461
void functionCallTest() {
75-
7662
UserMessage userMessage = new UserMessage(
7763
"What's the weather like in San Francisco, Tokyo, and Paris? Return temperatures in Celsius.");
7864

@@ -97,7 +83,6 @@ void functionCallTest() {
9783
@Disabled("Ollama API does not support streaming function calls yet")
9884
@Test
9985
void streamFunctionCallTest() {
100-
10186
UserMessage userMessage = new UserMessage(
10287
"What's the weather like in San Francisco, Tokyo, and Paris? Return temperatures in Celsius.");
10388

@@ -132,7 +117,7 @@ static class Config {
132117

133118
@Bean
134119
public OllamaApi ollamaApi() {
135-
return new OllamaApi(baseUrl);
120+
return buildOllamaApiWithModel(MODEL);
136121
}
137122

138123
@Bean

0 commit comments

Comments
 (0)