diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java
index b20e206bc8c..7cac43455fd 100644
--- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java
+++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java
@@ -29,7 +29,9 @@
 import org.springframework.ai.observation.conventions.AiProvider;
 import org.springframework.boot.context.properties.bind.ConstructorBinding;
 import org.springframework.http.HttpHeaders;
+import org.springframework.http.HttpMethod;
 import org.springframework.http.MediaType;
+import org.springframework.http.ResponseEntity;
 import org.springframework.http.client.ClientHttpResponse;
 import org.springframework.util.Assert;
 import org.springframework.util.StreamUtils;
@@ -807,5 +809,159 @@ public EmbeddingResponse embeddings(EmbeddingRequest embeddingRequest) {
             .body(EmbeddingResponse.class);
     }

+    // --------------------------------------------------------------------------
+    // Models
+    // --------------------------------------------------------------------------
+
+    @JsonInclude(Include.NON_NULL)
+    public record Model(
+            @JsonProperty("name") String name,
+            @JsonProperty("model") String model,
+            @JsonProperty("modified_at") Instant modifiedAt,
+            @JsonProperty("size") Long size,
+            @JsonProperty("digest") String digest,
+            @JsonProperty("details") Details details
+    ) {
+        @JsonInclude(Include.NON_NULL)
+        public record Details(
+                @JsonProperty("parent_model") String parentModel,
+                @JsonProperty("format") String format,
+                @JsonProperty("family") String family,
+                @JsonProperty("families") List<String> families,
+                @JsonProperty("parameter_size") String parameterSize,
+                @JsonProperty("quantization_level") String quantizationLevel
+        ) {}
+    }
+
+    @JsonInclude(Include.NON_NULL)
+    public record ListModelResponse(
+            @JsonProperty("models") List<Model> models
+    ) {}
+
+    /**
+     * List models that are available locally on the machine where Ollama is running.
+     */
+    public ListModelResponse listModels() {
+        return this.restClient.get()
+            .uri("/api/tags")
+            .retrieve()
+            .onStatus(this.responseErrorHandler)
+            .body(ListModelResponse.class);
+    }
+
+    @JsonInclude(Include.NON_NULL)
+    public record ShowModelRequest(
+            @JsonProperty("model") String model,
+            @JsonProperty("system") String system,
+            @JsonProperty("verbose") Boolean verbose,
+            @JsonProperty("options") Map<String, Object> options
+    ) {
+        public ShowModelRequest(String model) {
+            this(model, null, null, null);
+        }
+    }
+
+    @JsonInclude(Include.NON_NULL)
+    public record ShowModelResponse(
+            @JsonProperty("license") String license,
+            @JsonProperty("modelfile") String modelfile,
+            @JsonProperty("parameters") String parameters,
+            @JsonProperty("template") String template,
+            @JsonProperty("system") String system,
+            @JsonProperty("details") Model.Details details,
+            @JsonProperty("messages") List<Message> messages,
+            @JsonProperty("model_info") Map<String, Object> modelInfo,
+            @JsonProperty("projector_info") Map<String, Object> projectorInfo,
+            @JsonProperty("modified_at") Instant modifiedAt
+    ) {}
+
+    /**
+     * Show information about a model available locally on the machine where Ollama is running.
+     */
+    public ShowModelResponse showModel(ShowModelRequest showModelRequest) {
+        return this.restClient.post()
+            .uri("/api/show")
+            .body(showModelRequest)
+            .retrieve()
+            .onStatus(this.responseErrorHandler)
+            .body(ShowModelResponse.class);
+    }
+
+    @JsonInclude(Include.NON_NULL)
+    public record CopyModelRequest(
+            @JsonProperty("source") String source,
+            @JsonProperty("destination") String destination
+    ) {}
+
+    /**
+     * Copy a model. Creates a model with another name from an existing model.
+     */
+    public ResponseEntity<Void> copyModel(CopyModelRequest copyModelRequest) {
+        return this.restClient.post()
+            .uri("/api/copy")
+            .body(copyModelRequest)
+            .retrieve()
+            .onStatus(this.responseErrorHandler)
+            .toBodilessEntity();
+    }
+
+    @JsonInclude(Include.NON_NULL)
+    public record DeleteModelRequest(
+            @JsonProperty("model") String model
+    ) {}
+
+    /**
+     * Delete a model and its data.
+     */
+    public ResponseEntity<Void> deleteModel(DeleteModelRequest deleteModelRequest) {
+        return this.restClient.method(HttpMethod.DELETE)
+            .uri("/api/delete")
+            .body(deleteModelRequest)
+            .retrieve()
+            .onStatus(this.responseErrorHandler)
+            .toBodilessEntity();
+    }
+
+    @JsonInclude(Include.NON_NULL)
+    public record PullModelRequest(
+            @JsonProperty("model") String model,
+            @JsonProperty("insecure") Boolean insecure,
+            @JsonProperty("username") String username,
+            @JsonProperty("password") String password,
+            @JsonProperty("stream") Boolean stream
+    ) {
+        public PullModelRequest {
+            if (stream != null && stream) {
+                logger.warn("Streaming when pulling models is not supported yet");
+            }
+            stream = false;
+        }
+
+        public PullModelRequest(String model) {
+            this(model, null, null, null, null);
+        }
+    }
+
+    @JsonInclude(Include.NON_NULL)
+    public record ProgressResponse(
+            @JsonProperty("status") String status,
+            @JsonProperty("digest") String digest,
+            @JsonProperty("total") Long total,
+            @JsonProperty("completed") Long completed
+    ) {}
+
+    /**
+     * Download a model from the Ollama library. Cancelled pulls are resumed from where they left off,
+     * and multiple calls will share the same download progress.
+     */
+    public ProgressResponse pullModel(PullModelRequest pullModelRequest) {
+        return this.restClient.post()
+            .uri("/api/pull")
+            .body(pullModelRequest)
+            .retrieve()
+            .onStatus(this.responseErrorHandler)
+            .body(ProgressResponse.class);
+    }
+
 }
 // @formatter:on
\ No newline at end of file
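Reviewer note: for orientation, a minimal sketch of how the new model-management endpoints compose. The base URL and model names below are illustrative assumptions, not part of this change; the types and methods are the ones introduced above.

    // Client pointed at a local Ollama instance (URL assumed).
    OllamaApi api = new OllamaApi("http://localhost:11434");

    // Pull a model; cancelled pulls resume where they left off.
    OllamaApi.ProgressResponse progress = api.pullModel(new OllamaApi.PullModelRequest("llama3.2"));

    // List what is available locally and inspect a single model.
    OllamaApi.ListModelResponse models = api.listModels();
    OllamaApi.ShowModelResponse info = api.showModel(new OllamaApi.ShowModelRequest("llama3.2"));

    // Duplicate a model under a new name, then delete the copy.
    api.copyModel(new OllamaApi.CopyModelRequest("llama3.2", "llama3.2-copy"));
    api.deleteModel(new OllamaApi.DeleteModelRequest("llama3.2-copy"));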
diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/BaseOllamaIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/BaseOllamaIT.java
index eecdfe4c91e..e78e970d040 100644
--- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/BaseOllamaIT.java
+++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/BaseOllamaIT.java
@@ -1,9 +1,17 @@
 package org.springframework.ai.ollama;

+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.ai.ollama.api.OllamaApi;
 import org.testcontainers.ollama.OllamaContainer;

 public class BaseOllamaIT {

+    private static final Logger logger = LoggerFactory.getLogger(BaseOllamaIT.class);
+
+    // Toggle for running tests locally on native Ollama for a faster feedback loop.
+    private static final boolean useTestcontainers = false;
+
     public static final OllamaContainer ollamaContainer;

     static {
@@ -13,14 +21,34 @@ public class BaseOllamaIT {

     /**
      * Change the return value to false in order to run multiple Ollama IT tests locally
-     * reusing the same container image Also add the entry
+     * reusing the same container image.
+     *
+     * Also, add the entry
      *
      * testcontainers.reuse.enable=true
      *
-     * to the file .testcontainers.properties located in your home directory
+     * to the file ".testcontainers.properties" located in your home directory
      */
     public static boolean isDisabled() {
-        return true;
+        return false;
     }
+
+    public static OllamaApi buildOllamaApiWithModel(String model) {
+        var baseUrl = "http://localhost:11434";
+        if (useTestcontainers) {
+            baseUrl = "http://" + ollamaContainer.getHost() + ":" + ollamaContainer.getMappedPort(11434);
+        }
+        var ollamaApi = new OllamaApi(baseUrl);
+
+        ensureModelIsPresent(ollamaApi, model);
+
+        return ollamaApi;
+    }
+
+    public static void ensureModelIsPresent(OllamaApi ollamaApi, String model) {
+        logger.info("Start pulling the '{}' model. The operation can take several minutes...", model);
+        ollamaApi.pullModel(new OllamaApi.PullModelRequest(model));
+        logger.info("Completed pulling the '{}' model", model);
+    }

 }
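Reviewer note: a sketch of how an integration test can lean on this shared helper; the test class and model choice are illustrative, while `buildOllamaApiWithModel` and `OllamaModel.LLAMA3_2` come from this patch.

    class MyModelIT extends BaseOllamaIT {

        static OllamaApi ollamaApi;

        @BeforeAll
        static void beforeAll() {
            // Pulls the model through the new /api/pull endpoint and returns a client
            // pointed at either the shared Testcontainers instance or a native local Ollama.
            ollamaApi = buildOllamaApiWithModel(OllamaModel.LLAMA3_2.getName());
        }
    }

Because the container is started once in a static initializer and shared across tests, subclasses only declare the model they need.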
diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java
index eee175c5497..91c8bbd3754 100644
--- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java
+++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java
@@ -15,7 +15,6 @@
  */
 package org.springframework.ai.ollama;

-import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.Disabled;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.condition.DisabledIf;
@@ -40,7 +39,6 @@
 import org.testcontainers.junit.jupiter.Testcontainers;
 import reactor.core.publisher.Flux;

-import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.stream.Collectors;
@@ -54,25 +52,13 @@ class OllamaChatModelFunctionCallingIT extends BaseOllamaIT {

     private static final Logger logger = LoggerFactory.getLogger(OllamaChatModelFunctionCallingIT.class);

-    private static final String MODEL = OllamaModel.MISTRAL.getName();
-
-    static String baseUrl = "http://localhost:11434";
-
-    @BeforeAll
-    public static void beforeAll() throws IOException, InterruptedException {
-        logger.info("Start pulling the '" + MODEL + " ' generative ... would take several minutes ...");
-        ollamaContainer.execInContainer("ollama", "pull", MODEL);
-        logger.info(MODEL + " pulling competed!");
-
-        baseUrl = "http://" + ollamaContainer.getHost() + ":" + ollamaContainer.getMappedPort(11434);
-    }
+    private static final String MODEL = OllamaModel.LLAMA3_2.getName();

     @Autowired
     ChatModel chatModel;

     @Test
     void functionCallTest() {
-
         UserMessage userMessage = new UserMessage(
                 "What's the weather like in San Francisco, Tokyo, and Paris? Return temperatures in Celsius.");
@@ -97,7 +83,6 @@ void functionCallTest() {
     @Disabled("Ollama API does not support streaming function calls yet")
     @Test
     void streamFunctionCallTest() {
-
         UserMessage userMessage = new UserMessage(
                 "What's the weather like in San Francisco, Tokyo, and Paris? Return temperatures in Celsius.");
@@ -132,7 +117,7 @@ static class Config {

     @Bean
     public OllamaApi ollamaApi() {
-        return new OllamaApi(baseUrl);
+        return buildOllamaApiWithModel(MODEL);
     }

     @Bean
diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java
index be403b35030..0a1f0590394 100644
--- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java
+++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java
@@ -15,9 +15,6 @@
  */
 package org.springframework.ai.ollama;

-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.condition.DisabledIf;
 import org.springframework.ai.chat.messages.AssistantMessage;
@@ -43,8 +40,6 @@
 import org.springframework.core.convert.support.DefaultConversionService;
 import org.testcontainers.junit.jupiter.Testcontainers;

-import java.io.IOException;
-import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
@@ -56,20 +51,7 @@
 @DisabledIf("isDisabled")
 class OllamaChatModelIT extends BaseOllamaIT {

-    private static final String MODEL = OllamaModel.MISTRAL.getName();
-
-    private static final Log logger = LogFactory.getLog(OllamaChatModelIT.class);
-
-    static String baseUrl = "http://localhost:11434";
-
-    @BeforeAll
-    public static void beforeAll() throws IOException, InterruptedException {
-        logger.info("Start pulling the '" + MODEL + " ' generative ... would take several minutes ...");
-        ollamaContainer.execInContainer("ollama", "pull", MODEL);
-        logger.info(MODEL + " pulling competed!");
-
-        baseUrl = "http://" + ollamaContainer.getHost() + ":" + ollamaContainer.getMappedPort(11434);
-    }
+    private static final String MODEL = OllamaModel.LLAMA3_2.getName();

     @Autowired
     private OllamaChatModel chatModel;
@@ -88,7 +70,7 @@ void roleTest() {
         // portable/generic options
         var portableOptions = ChatOptionsBuilder.builder().withTemperature(0.7).build();

-        Prompt prompt = new Prompt(List.of(userMessage, systemMessage), portableOptions);
+        Prompt prompt = new Prompt(List.of(systemMessage, userMessage), portableOptions);

         ChatResponse response = chatModel.call(prompt);
         assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard");
@@ -96,14 +78,12 @@ void roleTest() {
         // ollama specific options
         var ollamaOptions = new OllamaOptions().withLowVRAM(true);

-        response = chatModel.call(new Prompt(List.of(userMessage, systemMessage), ollamaOptions));
+        response = chatModel.call(new Prompt(List.of(systemMessage, userMessage), ollamaOptions));

         assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard");
-
     }

     @Test
     void testMessageHistory() {
-
         Message systemMessage = new SystemPromptTemplate("""
                 You are a helpful AI assistant. Your name is {name}.
                 You are an AI assistant that helps people find information.
@@ -114,16 +94,16 @@ void testMessageHistory() {
         UserMessage userMessage = new UserMessage(
                 "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did.");

-        Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
+        Prompt prompt = new Prompt(List.of(systemMessage, userMessage));

         ChatResponse response = chatModel.call(prompt);
-        assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Bonny");
+        assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard");

-        var promptWithMessageHistory = new Prompt(List.of(new UserMessage("Dummy"), response.getResult().getOutput(),
-                new UserMessage("Repeat the last assistant message.")));
+        var promptWithMessageHistory = new Prompt(List.of(new UserMessage("Hello"), response.getResult().getOutput(),
+                new UserMessage("Tell me just the names of those pirates.")));

         response = chatModel.call(promptWithMessageHistory);
-        assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Bonny");
+        assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard");
     }

     @Test
@@ -163,19 +143,20 @@ void mapOutputConvert() {
         String format = outputConverter.getFormat();

         String template = """
-                Remove Markdown code blocks from the output.
-                Provide me a List of {subject}
+                For each letter in the RGB color scheme, tell me what it stands for.
+                Example: R -> Red.
                 {format}
                 """;
-        PromptTemplate promptTemplate = new PromptTemplate(template,
-                Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format));
+        PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
         Prompt prompt = new Prompt(promptTemplate.createMessage());
         Generation generation = chatModel.call(prompt).getResult();

         Map<String, Object> result = outputConverter.convert(generation.getOutput().getContent());
-        assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
-
+        assertThat(result).isNotNull();
+        assertThat((String) result.get("R")).containsIgnoringCase("red");
+        assertThat((String) result.get("G")).containsIgnoringCase("green");
+        assertThat((String) result.get("B")).containsIgnoringCase("blue");
     }

     record ActorsFilmsRecord(String actor, List<String> movies) {
@@ -183,12 +164,11 @@ record ActorsFilmsRecord(String actor, List<String> movies) {

     @Test
     void beanOutputConverterRecords() {
-
         BeanOutputConverter<ActorsFilmsRecord> outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class);

         String format = outputConverter.getFormat();
         String template = """
-                Generate the filmography of 5 movies for Tom Hanks.
+                Consider the filmography of Tom Hanks and tell me 5 of his movies.
                 {format}
                 """;
         PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
@@ -202,14 +182,12 @@ void beanOutputConverterRecords() {

     @Test
     void beanStreamOutputConverterRecords() {
-
         BeanOutputConverter<ActorsFilmsRecord> outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class);

         String format = outputConverter.getFormat();
         String template = """
-                Generate the filmography of 5 movies for Tom Hanks.
+                Consider the filmography of Tom Hanks and tell me 5 of his movies.
                 {format}
-                Remove Markdown code blocks from the output.
"""; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); @@ -235,7 +213,7 @@ public static class TestConfiguration { @Bean public OllamaApi ollamaApi() { - return new OllamaApi(baseUrl); + return buildOllamaApiWithModel(MODEL); } @Bean diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java index 4dffc7d2fb8..8536d61a716 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java @@ -15,11 +15,10 @@ */ package org.springframework.ai.ollama; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; -import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.DisabledIf; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.Media; @@ -32,11 +31,8 @@ import org.springframework.context.annotation.Bean; import org.springframework.core.io.ClassPathResource; import org.springframework.util.MimeTypeUtils; -import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.ollama.OllamaContainer; -import java.io.IOException; import java.util.List; import static org.assertj.core.api.Assertions.assertThat; @@ -47,42 +43,25 @@ @DisabledIf("isDisabled") class OllamaChatModelMultimodalIT extends BaseOllamaIT { - private static final String MODEL = OllamaModel.MOONDREAM.getName(); - - private static final Log logger = LogFactory.getLog(OllamaChatModelIT.class); - - @Container - static OllamaContainer ollamaContainer = new OllamaContainer(OllamaImage.DEFAULT_IMAGE); - - static String baseUrl = "http://localhost:11434"; - - @BeforeAll - public static void beforeAll() throws IOException, InterruptedException { - logger.info("Start pulling the '" + MODEL + " ' generative ... 
-        ollamaContainer.execInContainer("ollama", "pull", MODEL);
-        logger.info(MODEL + " pulling competed!");
+    private static final Logger logger = LoggerFactory.getLogger(OllamaChatModelMultimodalIT.class);

-        baseUrl = "http://" + ollamaContainer.getHost() + ":" + ollamaContainer.getMappedPort(11434);
-    }
+    private static final String MODEL = OllamaModel.MOONDREAM.getName();

     @Autowired
     private OllamaChatModel chatModel;

     @Test
-    void unsupportedMediaType() throws IOException {
-
+    void unsupportedMediaType() {
         var imageData = new ClassPathResource("/norway.webp");

         var userMessage = new UserMessage("Explain what do you see on this picture?",
                 List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData)));

         assertThrows(RuntimeException.class, () -> chatModel.call(new Prompt(List.of(userMessage))));
-
     }

     @Test
-    void multiModalityTest() throws IOException {
-
+    void multiModalityTest() {
         var imageData = new ClassPathResource("/test.png");

         var userMessage = new UserMessage("Explain what do you see on this picture?",
@@ -91,7 +70,7 @@ void multiModalityTest() {
         var response = chatModel.call(new Prompt(List.of(userMessage)));

         logger.info(response.getResult().getOutput().getContent());
-        assertThat(response.getResult().getOutput().getContent()).contains("bananas", "apple", "basket");
+        assertThat(response.getResult().getOutput().getContent()).contains("bananas", "apple");
     }

     @SpringBootConfiguration
@@ -99,7 +78,7 @@ public static class TestConfiguration {

     @Bean
     public OllamaApi ollamaApi() {
-        return new OllamaApi(baseUrl);
+        return buildOllamaApiWithModel(MODEL);
     }

     @Bean
diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java
index 82c6031cf87..327418a4afe 100644
--- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java
+++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java
@@ -34,7 +34,6 @@
 import org.springframework.boot.SpringBootConfiguration;
 import org.springframework.boot.test.context.SpringBootTest;
 import org.springframework.context.annotation.Bean;
-import org.testcontainers.junit.jupiter.Testcontainers;
 import reactor.core.publisher.Flux;

 import java.util.List;
@@ -53,6 +52,8 @@
 @DisabledIf("isDisabled")
 public class OllamaChatModelObservationIT extends BaseOllamaIT {

+    private static final String MODEL = OllamaModel.LLAMA3_2.getName();
+
     @Autowired
     TestObservationRegistry observationRegistry;
@@ -67,7 +68,7 @@ void beforeEach() {
     @Test
     void observationForChatOperation() {
         var options = OllamaOptions.builder()
-            .withModel(OllamaModel.MISTRAL.getName())
+            .withModel(MODEL)
             .withFrequencyPenalty(0.0)
             .withNumPredict(2048)
             .withPresencePenalty(0.0)
@@ -91,7 +92,7 @@ void observationForStreamingChatOperation() {
         var options = OllamaOptions.builder()
-            .withModel(OllamaModel.MISTRAL.getName())
+            .withModel(MODEL)
             .withFrequencyPenalty(0.0)
             .withNumPredict(2048)
             .withPresencePenalty(0.0)
@@ -128,11 +129,11 @@ private void validate(ChatResponseMetadata responseMetadata) {
             .doesNotHaveAnyRemainingCurrentObservation()
             .hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME)
             .that()
-            .hasContextualNameEqualTo("chat " + OllamaModel.MISTRAL.getName())
+            .hasContextualNameEqualTo("chat " + MODEL)
             .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(), AiOperationType.CHAT.value())
             .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_PROVIDER.asString(), AiProvider.OLLAMA.value())
-            .hasLowCardinalityKeyValue(LowCardinalityKeyNames.REQUEST_MODEL.asString(), OllamaModel.MISTRAL.getName())
+            .hasLowCardinalityKeyValue(LowCardinalityKeyNames.REQUEST_MODEL.asString(), MODEL)
             .hasLowCardinalityKeyValue(LowCardinalityKeyNames.RESPONSE_MODEL.asString(), responseMetadata.getModel())
             .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_FREQUENCY_PENALTY.asString(), "0.0")
             .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_MAX_TOKENS.asString(), "2048")
@@ -164,7 +165,7 @@ public TestObservationRegistry observationRegistry() {

     @Bean
     public OllamaApi openAiApi() {
-        return new OllamaApi();
+        return buildOllamaApiWithModel(MODEL);
     }

     @Bean
diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelIT.java
index d311fe1aff7..49c56d43d99 100644
--- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelIT.java
+++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelIT.java
@@ -15,20 +15,12 @@
  */
 package org.springframework.ai.ollama;

-import static org.assertj.core.api.Assertions.assertThat;
-
-import java.io.IOException;
-import java.util.List;
-
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.condition.DisabledIf;
 import org.springframework.ai.embedding.EmbeddingRequest;
 import org.springframework.ai.embedding.EmbeddingResponse;
 import org.springframework.ai.ollama.api.OllamaApi;
-import org.springframework.ai.ollama.api.OllamaApiIT;
+import org.springframework.ai.ollama.api.OllamaModel;
 import org.springframework.ai.ollama.api.OllamaOptions;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.boot.SpringBootConfiguration;
@@ -36,25 +28,16 @@
 import org.springframework.context.annotation.Bean;
 import org.testcontainers.junit.jupiter.Testcontainers;

+import java.util.List;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
 @SpringBootTest
 @DisabledIf("isDisabled")
 @Testcontainers
 class OllamaEmbeddingModelIT extends BaseOllamaIT {

-    private static final String MODEL = "mxbai-embed-large";
-
-    private static final Log logger = LogFactory.getLog(OllamaApiIT.class);
-
-    static String baseUrl = "http://localhost:11434";
-
-    @BeforeAll
-    public static void beforeAll() throws IOException, InterruptedException {
-        logger.info("Start pulling the '" + MODEL + " ' generative ... would take several minutes ...");
-        ollamaContainer.execInContainer("ollama", "pull", MODEL);
-        logger.info(MODEL + " pulling competed!");
-
-        baseUrl = "http://" + ollamaContainer.getHost() + ":" + ollamaContainer.getMappedPort(11434);
-    }
+    private static final String MODEL = OllamaModel.NOMIC_EMBED_TEXT.getName();

     @Autowired
     private OllamaEmbeddingModel embeddingModel;
@@ -73,7 +56,7 @@ void embeddings() {
         assertThat(embeddingResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(4);
         assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(4);

-        assertThat(embeddingModel.dimensions()).isEqualTo(1024);
+        assertThat(embeddingModel.dimensions()).isEqualTo(768);
     }

     @SpringBootConfiguration
@@ -81,7 +64,7 @@ public static class TestConfiguration {

     @Bean
     public OllamaApi ollamaApi() {
-        return new OllamaApi(baseUrl);
+        return buildOllamaApiWithModel(MODEL);
     }

     @Bean
diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelObservationIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelObservationIT.java
index 581d5d74e07..6f43b2a17bb 100644
--- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelObservationIT.java
+++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelObservationIT.java
@@ -18,9 +18,6 @@
 import io.micrometer.observation.tck.TestObservationRegistry;
 import io.micrometer.observation.tck.TestObservationRegistryAssert;

-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.condition.DisabledIf;
 import org.springframework.ai.embedding.EmbeddingRequest;
@@ -32,7 +29,6 @@
 import org.springframework.ai.observation.conventions.AiOperationType;
 import org.springframework.ai.observation.conventions.AiProvider;
 import org.springframework.ai.ollama.api.OllamaApi;
-import org.springframework.ai.ollama.api.OllamaApiIT;
 import org.springframework.ai.ollama.api.OllamaModel;
 import org.springframework.ai.ollama.api.OllamaOptions;
 import org.springframework.beans.factory.annotation.Autowired;
@@ -40,7 +36,6 @@
 import org.springframework.boot.test.context.SpringBootTest;
 import org.springframework.context.annotation.Bean;

-import java.io.IOException;
 import java.util.List;

 import static org.assertj.core.api.Assertions.assertThat;
@@ -54,8 +49,6 @@
 @DisabledIf("isDisabled")
 public class OllamaEmbeddingModelObservationIT extends BaseOllamaIT {

-    private static final Log logger = LogFactory.getLog(OllamaApiIT.class);
-
     private static final String MODEL = OllamaModel.NOMIC_EMBED_TEXT.getName();

     @Autowired
@@ -64,17 +57,6 @@ public class OllamaEmbeddingModelObservationIT extends BaseOllamaIT {
     @Autowired
     OllamaEmbeddingModel embeddingModel;

-    static String baseUrl = "http://localhost:11434";
-
-    @BeforeAll
-    public static void beforeAll() throws IOException, InterruptedException {
-        logger.info("Start pulling the '" + MODEL + " ' generative ... would take several minutes ...");
-        ollamaContainer.execInContainer("ollama", "pull", MODEL);
-        logger.info(MODEL + " pulling competed!");
-
-        baseUrl = "http://" + ollamaContainer.getHost() + ":" + ollamaContainer.getMappedPort(11434);
-    }
-
     @Test
     void observationForEmbeddingOperation() {
         var options = OllamaOptions.builder().withModel(OllamaModel.NOMIC_EMBED_TEXT.getName()).build();
@@ -117,7 +99,7 @@ public TestObservationRegistry observationRegistry() {

     @Bean
     public OllamaApi openAiApi() {
-        return new OllamaApi(baseUrl);
+        return buildOllamaApiWithModel(MODEL);
     }

     @Bean
diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java
index 3a0f3859125..0b57cb16e0b 100644
--- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java
+++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java
@@ -22,6 +22,6 @@
  */
 public class OllamaImage {

-    public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ollama/ollama:0.3.9");
+    public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ollama/ollama:0.3.13");

 }
diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java
index 427ae8f913b..f9ed57882f3 100644
--- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java
+++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java
@@ -18,8 +18,6 @@
 import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.condition.DisabledIf;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 import org.springframework.ai.ollama.BaseOllamaIT;
 import org.springframework.ai.ollama.api.OllamaApi.ChatRequest;
 import org.springframework.ai.ollama.api.OllamaApi.ChatResponse;
@@ -46,27 +44,17 @@
 @DisabledIf("isDisabled")
 public class OllamaApiIT extends BaseOllamaIT {

-    private static final String MODEL = OllamaModel.ORCA_MINI.getName();
-
-    private static final Logger logger = LoggerFactory.getLogger(OllamaApiIT.class);
+    private static final String MODEL = OllamaModel.LLAMA3_2.getName();

     static OllamaApi ollamaApi;

-    static String baseUrl = "http://localhost:11434";
-
     @BeforeAll
     public static void beforeAll() throws IOException, InterruptedException {
-        logger.info("Start pulling the '" + MODEL + " ' generative ... would take several minutes ...");
-        ollamaContainer.execInContainer("ollama", "pull", MODEL);
-        logger.info(MODEL + " pulling competed!");
-
-        baseUrl = "http://" + ollamaContainer.getHost() + ":" + ollamaContainer.getMappedPort(11434);
-        ollamaApi = new OllamaApi(baseUrl);
+        ollamaApi = buildOllamaApiWithModel(MODEL);
     }

     @Test
     public void generation() {
-
         var request = GenerateRequest
             .builder("What is the capital of Bulgaria and what is the size? What it the national anthem?")
             .withModel(MODEL)
@@ -78,13 +66,12 @@ public void generation() {
         System.out.println(response);

         assertThat(response).isNotNull();
-        assertThat(response.model()).isEqualTo(response.model());
+        assertThat(response.model()).contains(MODEL);
         assertThat(response.response()).contains("Sofia");
     }

     @Test
     public void chat() {
-
         var request = ChatRequest.builder(MODEL)
             .withStream(false)
             .withMessages(List.of(
@@ -103,7 +90,7 @@ public void chat() {
         System.out.println(response);

         assertThat(response).isNotNull();
-        assertThat(response.model()).isEqualTo(response.model());
+        assertThat(response.model()).contains(MODEL);
         assertThat(response.done()).isTrue();
         assertThat(response.message().role()).isEqualTo(Role.ASSISTANT);
         assertThat(response.message().content()).contains("Sofia");
@@ -111,7 +98,6 @@ public void chat() {

     @Test
     public void streamingChat() {
-
         var request = ChatRequest.builder(MODEL)
             .withStream(true)
             .withMessages(List.of(Message.builder(Role.USER)
@@ -125,7 +111,7 @@ public void streamingChat() {
         List<ChatResponse> responses = response.collectList().block();
         System.out.println(responses);

-        assertThat(response).isNotNull();
+        assertThat(responses).isNotNull();
         assertThat(responses.stream()
             .filter(r -> r.message() != null)
             .map(r -> r.message().content())
@@ -138,19 +124,17 @@ public void streamingChat() {

     @Test
     public void embedText() {
-
         EmbeddingsRequest request = new EmbeddingsRequest(MODEL, "I like to eat apples");

         EmbeddingsResponse response = ollamaApi.embed(request);

         assertThat(response).isNotNull();
         assertThat(response.embeddings()).hasSize(1);
-        assertThat(response.embeddings().get(0)).hasSize(3200);
+        assertThat(response.embeddings().get(0)).hasSize(3072);
         assertThat(response.model()).isEqualTo(MODEL);
         assertThat(response.promptEvalCount()).isEqualTo(5);
         assertThat(response.loadDuration()).isGreaterThan(1);
         assertThat(response.totalDuration()).isGreaterThan(1);
-
     }

 }
\ No newline at end of file
diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiModelsIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiModelsIT.java
new file mode 100644
index 00000000000..91752bdb6dc
--- /dev/null
+++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiModelsIT.java
@@ -0,0 +1,94 @@
+/*
+ * Copyright 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.ollama.api;
+
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.DisabledIf;
+import org.springframework.ai.ollama.BaseOllamaIT;
+import org.springframework.http.HttpStatus;
+import org.testcontainers.junit.jupiter.Testcontainers;
+
+import java.io.IOException;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Integration tests for the Ollama APIs to manage models.
+ * + * @author Thomas Vitale + */ +@Testcontainers +@DisabledIf("isDisabled") +public class OllamaApiModelsIT extends BaseOllamaIT { + + private static final String MODEL = OllamaModel.NOMIC_EMBED_TEXT.getName(); + + static OllamaApi ollamaApi; + + @BeforeAll + public static void beforeAll() throws IOException, InterruptedException { + ollamaApi = buildOllamaApiWithModel(MODEL); + } + + @Test + public void listModels() { + var listModelResponse = ollamaApi.listModels(); + + assertThat(listModelResponse).isNotNull(); + assertThat(listModelResponse.models().size()).isGreaterThan(0); + assertThat(listModelResponse.models().stream().anyMatch(model -> model.name().contains(MODEL))).isTrue(); + } + + @Test + public void showModel() { + var showModelRequest = new OllamaApi.ShowModelRequest(MODEL); + var showModelResponse = ollamaApi.showModel(showModelRequest); + + assertThat(showModelResponse).isNotNull(); + assertThat(showModelResponse.details().family()).isEqualTo("nomic-bert"); + } + + @Test + public void copyAndDeleteModel() { + var customModel = "schrodinger"; + var copyModelRequest = new OllamaApi.CopyModelRequest(MODEL, customModel); + var copyModelResponse = ollamaApi.copyModel(copyModelRequest); + assertThat(copyModelResponse.getStatusCode()).isEqualTo(HttpStatus.OK); + + var deleteModelRequest = new OllamaApi.DeleteModelRequest(customModel); + var deleteModelResponse = ollamaApi.deleteModel(deleteModelRequest); + assertThat(deleteModelResponse.getStatusCode()).isEqualTo(HttpStatus.OK); + } + + @Test + public void pullModel() { + var deleteModelRequest = new OllamaApi.DeleteModelRequest(MODEL); + var deleteModelResponse = ollamaApi.deleteModel(deleteModelRequest); + assertThat(deleteModelResponse.getStatusCode()).isEqualTo(HttpStatus.OK); + + var listModelResponse = ollamaApi.listModels(); + assertThat(listModelResponse.models().stream().anyMatch(model -> model.name().contains(MODEL))).isFalse(); + + var pullModelRequest = new OllamaApi.PullModelRequest(MODEL); + var progressResponse = ollamaApi.pullModel(pullModelRequest); + assertThat(progressResponse.status()).contains("success"); + + listModelResponse = ollamaApi.listModels(); + assertThat(listModelResponse.models().stream().anyMatch(model -> model.name().contains(MODEL))).isTrue(); + } + +} \ No newline at end of file diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/tool/OllamaApiToolFunctionCallIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/tool/OllamaApiToolFunctionCallIT.java index 9355784ff71..77ee8f2ad15 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/tool/OllamaApiToolFunctionCallIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/tool/OllamaApiToolFunctionCallIT.java @@ -40,34 +40,28 @@ /** * @author Christian Tzolov + * @author Thomas Vitale */ @Testcontainers @DisabledIf("isDisabled") public class OllamaApiToolFunctionCallIT extends BaseOllamaIT { - private static final String MODEL = OllamaModel.MISTRAL.getName(); + private static final String MODEL = OllamaModel.LLAMA3_2.getName(); private static final Logger logger = LoggerFactory.getLogger(OllamaApiToolFunctionCallIT.class); MockWeatherService weatherService = new MockWeatherService(); - static String baseUrl = "http://localhost:11434"; + static OllamaApi ollamaApi; @BeforeAll public static void beforeAll() throws IOException, InterruptedException { - logger.info("Start pulling the '" + MODEL + " ' generative ... 
-        ollamaContainer.execInContainer("ollama", "pull", MODEL);
-        logger.info(MODEL + " pulling competed!");
-
-        baseUrl = "http://" + ollamaContainer.getHost() + ":" + ollamaContainer.getMappedPort(11434);
+        ollamaApi = buildOllamaApiWithModel(MODEL);
     }

     @SuppressWarnings("null")
     @Test
     public void toolFunctionCall() {
-
-        OllamaApi completionApi = new OllamaApi(baseUrl);
-
         // Step 1: send the conversation and available functions to the model
         var message = Message.builder(Role.USER) //
             .withContent("What's the weather like in San Francisco, Tokyo, and Paris?
@@ -100,7 +94,7 @@
             .withTools(List.of(functionTool))
             .build();

-        ChatResponse chatCompletion = completionApi.chat(chatCompletionRequest);
+        ChatResponse chatCompletion = ollamaApi.chat(chatCompletionRequest);

         assertThat(chatCompletion).isNotNull();
         assertThat(chatCompletion.message()).isNotNull();
@@ -134,7 +128,7 @@ public void toolFunctionCall() {

         var functionResponseRequest = OllamaApi.ChatRequest.builder(MODEL).withMessages(messages).build();

-        ChatResponse chatCompletion2 = completionApi.chat(functionResponseRequest);
+        ChatResponse chatCompletion2 = ollamaApi.chat(functionResponseRequest);

         logger.info("Final response: " + chatCompletion2);
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/BaseOllamaIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/BaseOllamaIT.java
new file mode 100644
index 00000000000..145d046d33e
--- /dev/null
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/BaseOllamaIT.java
@@ -0,0 +1,49 @@
+package org.springframework.ai.autoconfigure.ollama;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.testcontainers.ollama.OllamaContainer;
+
+import java.io.IOException;
+
+public class BaseOllamaIT {
+
+    private static final Logger logger = LoggerFactory.getLogger(BaseOllamaIT.class);
+
+    // Toggle for running tests locally on native Ollama for a faster feedback loop.
+    private static final boolean useTestcontainers = true;
+
+    public static final OllamaContainer ollamaContainer;
+
+    static {
+        ollamaContainer = new OllamaContainer(OllamaImage.IMAGE).withReuse(true);
+        ollamaContainer.start();
+    }
+
+    /**
+     * Change the return value to false in order to run multiple Ollama IT tests locally
+     * reusing the same container image.
+     *
+     * Also, add the entry
+     *
+     * testcontainers.reuse.enable=true
+     *
+     * to the file ".testcontainers.properties" located in your home directory
+     */
+    public static boolean isDisabled() {
+        return true;
+    }
+
+    public static String buildConnectionWithModel(String model) throws IOException, InterruptedException {
+        var baseUrl = "http://localhost:11434";
+        if (useTestcontainers) {
+            baseUrl = "http://" + ollamaContainer.getHost() + ":" + ollamaContainer.getMappedPort(11434);
+
+            logger.info("Start pulling the '{}' model. The operation can take several minutes...", model);
+            ollamaContainer.execInContainer("ollama", "pull", model);
+            logger.info("Completed pulling the '{}' model", model);
+        }
+        return baseUrl;
+    }
+
+}
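Reviewer note: a minimal sketch of how the autoconfiguration tests below consume this helper. The property key follows Spring AI's spring.ai.ollama.* convention used elsewhere in these tests, and the model choice is illustrative.

    static String baseUrl;

    @BeforeAll
    static void beforeAll() throws IOException, InterruptedException {
        // Resolves the URL of the container (or of a native local Ollama)
        // and pulls the model before any application context is built.
        baseUrl = buildConnectionWithModel(OllamaModel.LLAMA3_2.getName());
    }

    // The runner then points the autoconfiguration at that URL.
    private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
        .withPropertyValues("spring.ai.ollama.base-url=" + baseUrl)
        .withConfiguration(AutoConfigurations.of(OllamaAutoConfiguration.class));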
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationIT.java
index 2497001d935..8048bfc2487 100644
--- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationIT.java
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationIT.java
@@ -18,69 +18,42 @@
 import static org.assertj.core.api.Assertions.assertThat;

 import java.io.IOException;
-import java.util.Collections;
 import java.util.List;
-import java.util.Map;
 import java.util.stream.Collectors;

-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
 import org.junit.jupiter.api.BeforeAll;
-import org.junit.jupiter.api.Disabled;
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.DisabledIf;
 import org.springframework.ai.chat.messages.AssistantMessage;
-import org.springframework.ai.chat.messages.Message;
 import org.springframework.ai.chat.messages.UserMessage;
 import org.springframework.ai.chat.model.ChatResponse;
 import org.springframework.ai.chat.model.Generation;
 import org.springframework.ai.chat.prompt.Prompt;
-import org.springframework.ai.chat.prompt.SystemPromptTemplate;
 import org.springframework.ai.ollama.OllamaChatModel;
+import org.springframework.ai.ollama.api.OllamaModel;
 import org.springframework.boot.autoconfigure.AutoConfigurations;
 import org.springframework.boot.test.context.runner.ApplicationContextRunner;
-import org.testcontainers.DockerClientFactory;
-import org.testcontainers.containers.GenericContainer;
 import org.testcontainers.junit.jupiter.Testcontainers;
-import org.testcontainers.utility.DockerImageName;
-
-import com.github.dockerjava.api.DockerClient;
-import com.github.dockerjava.api.command.InspectContainerResponse;
-import com.github.dockerjava.api.model.Image;
 import reactor.core.publisher.Flux;

 /**
  * @author Christian Tzolov
  * @author EddĂș MelĂ©ndez
+ * @author Thomas Vitale
  * @since 0.8.0
  */
-@Disabled("For manual smoke testing only.")
 @Testcontainers
-public class OllamaChatAutoConfigurationIT {
-
-    private static final Log logger = LogFactory.getLog(OllamaChatAutoConfigurationIT.class);
-
-    private static final String MODEL_NAME = "mistral";
+@DisabledIf("isDisabled")
+public class OllamaChatAutoConfigurationIT extends BaseOllamaIT {

-    private static final String OLLAMA_WITH_MODEL = "%s-%s".formatted(MODEL_NAME, OllamaImage.IMAGE);
-
-    private static OllamaContainer ollamaContainer;
-
-    static {
-        ollamaContainer = new OllamaContainer(OllamaDockerImageName.image());
-        ollamaContainer.start();
-        createImage(ollamaContainer, OLLAMA_WITH_MODEL);
-    }
+    private static final String MODEL_NAME = OllamaModel.LLAMA3_2.getName();

-    static String baseUrl = "http://localhost:11434";
+    static String baseUrl;

     @BeforeAll
     public static void beforeAll() throws IOException, InterruptedException {
-        logger.info("Start pulling the '" + MODEL_NAME + " ' generative ... would take several minutes ...");
-        ollamaContainer.execInContainer("ollama", "pull", MODEL_NAME);
-        logger.info(MODEL_NAME + " pulling competed!");
-
-        baseUrl = "http://" + ollamaContainer.getHost() + ":" + ollamaContainer.getMappedPort(11434);
+        baseUrl = buildConnectionWithModel(MODEL_NAME);
     }

     private final ApplicationContextRunner contextRunner = new ApplicationContextRunner().withPropertyValues(
     // @formatter:on
         .withConfiguration(AutoConfigurations.of(OllamaAutoConfiguration.class));

@@ -92,22 +65,14 @@ public static void beforeAll() throws IOException, InterruptedException {
-    private final Message systemMessage = new SystemPromptTemplate("""
-            You are a helpful AI assistant. Your name is {name}.
-            You are an AI assistant that helps people find information.
-            Your name is {name}
-            You should reply to the user's request with your name and also in the style of a {voice}.
-            """).createMessage(Map.of("name", "Bob", "voice", "pirate"));
-
-    private final UserMessage userMessage = new UserMessage(
-            "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did.");
+    private final UserMessage userMessage = new UserMessage("What's the capital of Denmark?");

     @Test
     public void chatCompletion() {
         contextRunner.run(context -> {
             OllamaChatModel chatModel = context.getBean(OllamaChatModel.class);
-            ChatResponse response = chatModel.call(new Prompt(List.of(userMessage, systemMessage)));
-            assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard");
+            ChatResponse response = chatModel.call(new Prompt(userMessage));
+            assertThat(response.getResult().getOutput().getContent()).contains("Copenhagen");
         });
     }

@@ -117,7 +82,7 @@ public void chatCompletionStreaming() {

             OllamaChatModel chatModel = context.getBean(OllamaChatModel.class);

-            Flux<ChatResponse> response = chatModel.stream(new Prompt(List.of(userMessage, systemMessage)));
+            Flux<ChatResponse> response = chatModel.stream(new Prompt(userMessage));

             List<ChatResponse> responses = response.collectList().block();
             assertThat(responses.size()).isGreaterThan(1);
@@ -129,7 +94,7 @@ public void chatCompletionStreaming() {
                 .map(AssistantMessage::getContent)
                 .collect(Collectors.joining());

-            assertThat(stitchedResponseContent).contains("Blackbeard");
+            assertThat(stitchedResponseContent).contains("Copenhagen");
         });
     }

@@ -151,72 +116,4 @@ void chatActivation() {
         });
     }

-    static class OllamaContainer extends GenericContainer<OllamaContainer> {
-
-        private final DockerImageName dockerImageName;
-
-        OllamaContainer(DockerImageName image) {
-            super(image);
-            this.dockerImageName = image;
-            withExposedPorts(11434);
-            withImagePullPolicy(dockerImageName -> !dockerImageName.getUnversionedPart().startsWith(MODEL_NAME));
-        }
-
-        @Override
-        protected void containerIsStarted(InspectContainerResponse containerInfo) {
-            if (!this.dockerImageName.getVersionPart().endsWith(MODEL_NAME)) {
-                try {
-                    execInContainer("ollama", "pull", MODEL_NAME);
-                }
-                catch (IOException | InterruptedException e) {
-                    throw new RuntimeException("Error pulling orca-mini model", e);
-                }
-            }
-        }
-
-    }
-
-    public static void createImage(GenericContainer<?> container, String localImageName) {
-        DockerImageName dockerImageName = DockerImageName.parse(container.getDockerImageName());
-        if (!dockerImageName.equals(DockerImageName.parse(localImageName))) {
-            DockerClient dockerClient = DockerClientFactory.instance().client();
-            List<Image> images = dockerClient.listImagesCmd().withReferenceFilter(localImageName).exec();
-            if (images.isEmpty()) {
-                DockerImageName imageModel = DockerImageName.parse(localImageName);
-                dockerClient.commitCmd(container.getContainerId())
-                    .withRepository(imageModel.getUnversionedPart())
-                    .withLabels(Collections.singletonMap("org.testcontainers.sessionId", ""))
-                    .withTag(imageModel.getVersionPart())
-                    .exec();
-            }
-        }
-    }
-
-    public static class OllamaDockerImageName {
-
-        private final String baseImage;
-
-        private final String localImageName;
-
-        OllamaDockerImageName(String baseImage, String localImageName) {
-            this.baseImage = baseImage;
-            this.localImageName = localImageName;
-        }
-
-        public static DockerImageName image() {
-            return new OllamaDockerImageName(OllamaImage.IMAGE, OLLAMA_WITH_MODEL).resolve();
-        }
-
-        private DockerImageName resolve() {
-            var dockerImageName = DockerImageName.parse(this.baseImage);
-            var dockerClient = DockerClientFactory.instance().client();
-            var images = dockerClient.listImagesCmd().withReferenceFilter(this.localImageName).exec();
-            if (images.isEmpty()) {
-                return dockerImageName;
-            }
-            return DockerImageName.parse(this.localImageName);
-        }
-
-    }
-
 }
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationIT.java
index fd6106ec280..9e1bd8b8dec 100644
--- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationIT.java
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationIT.java
@@ -19,11 +19,9 @@
 import java.util.List;

 import org.junit.jupiter.api.BeforeAll;
-import org.junit.jupiter.api.Disabled;
 import org.junit.jupiter.api.Test;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import org.testcontainers.junit.jupiter.Container;
+import org.junit.jupiter.api.condition.DisabledIf;
+import org.springframework.ai.ollama.api.OllamaModel;
 import org.testcontainers.junit.jupiter.Testcontainers;

 import org.springframework.ai.embedding.EmbeddingResponse;
@@ -31,7 +29,6 @@
 import org.springframework.boot.autoconfigure.AutoConfigurations;
 import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
 import org.springframework.boot.test.context.runner.ApplicationContextRunner;
-import org.testcontainers.ollama.OllamaContainer;

 import static org.assertj.core.api.Assertions.assertThat;

@@ -40,26 +37,17 @@
 * @author Thomas Vitale
 * @since 0.8.0
 */
-@Disabled("For manual smoke testing only.")
 @Testcontainers
-public class OllamaEmbeddingAutoConfigurationIT {
+@DisabledIf("isDisabled")
+public class OllamaEmbeddingAutoConfigurationIT extends BaseOllamaIT {

-    private static final Logger logger = LoggerFactory.getLogger(OllamaEmbeddingAutoConfigurationIT.class);
+    private static final String MODEL_NAME = OllamaModel.NOMIC_EMBED_TEXT.getName();

-    private static final String MODEL_NAME = "orca-mini";
-
-    @Container
-    static OllamaContainer ollamaContainer = new OllamaContainer(OllamaImage.IMAGE);
-
-    static String baseUrl = "http://localhost:11434";
+    static String baseUrl;

     @BeforeAll
     public static void beforeAll() throws IOException, InterruptedException {
-        logger.info("Start pulling the '" + MODEL_NAME + " ' generative ... would take several minutes ...");
-        ollamaContainer.execInContainer("ollama", "pull", MODEL_NAME);
-        logger.info(MODEL_NAME + " pulling competed!");
-
-        baseUrl = "http://" + ollamaContainer.getHost() + ":" + ollamaContainer.getMappedPort(11434);
+        baseUrl = buildConnectionWithModel(MODEL_NAME);
     }

     private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
@@ -75,7 +63,7 @@ public void singleTextEmbedding() {
             EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World"));
             assertThat(embeddingResponse.getResults()).hasSize(1);
             assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
-            assertThat(embeddingModel.dimensions()).isEqualTo(3200);
+            assertThat(embeddingModel.dimensions()).isEqualTo(768);
         });
     }
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaImage.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaImage.java
index 873c3e2ef09..d4b621ee768 100644
--- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaImage.java
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaImage.java
@@ -17,6 +17,6 @@

 public class OllamaImage {

-    public static final String IMAGE = "ollama/ollama:0.3.9";
+    public static final String IMAGE = "ollama/ollama:0.3.13";

 }
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackInPromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackInPromptIT.java
index 41e2e7dd5d8..bc5604ed4e5 100644
--- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackInPromptIT.java
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackInPromptIT.java
@@ -24,10 +24,11 @@
 import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.Disabled;
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.DisabledIf;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import org.springframework.ai.autoconfigure.ollama.BaseOllamaIT;
 import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration;
-import org.springframework.ai.autoconfigure.ollama.OllamaImage;
 import org.springframework.ai.chat.messages.AssistantMessage;
 import org.springframework.ai.chat.messages.UserMessage;
 import org.springframework.ai.chat.model.ChatResponse;
@@ -35,35 +36,27 @@
 import org.springframework.ai.chat.prompt.Prompt;
 import org.springframework.ai.model.function.FunctionCallbackWrapper;
 import org.springframework.ai.ollama.OllamaChatModel;
+import org.springframework.ai.ollama.api.OllamaModel;
 import org.springframework.ai.ollama.api.OllamaOptions;
 import org.springframework.boot.autoconfigure.AutoConfigurations;
 import org.springframework.boot.test.context.runner.ApplicationContextRunner;
-import org.testcontainers.junit.jupiter.Container;
 import org.testcontainers.junit.jupiter.Testcontainers;
-import org.testcontainers.ollama.OllamaContainer;
 import reactor.core.publisher.Flux;

-@Disabled("For manual smoke testing only.")
 @Testcontainers
-public class FunctionCallbackInPromptIT {
+@DisabledIf("isDisabled")
+public class FunctionCallbackInPromptIT extends BaseOllamaIT {

     private static final Logger logger = LoggerFactory.getLogger(FunctionCallbackInPromptIT.class);

-    private static String MODEL_NAME = "mistral";
+    private static final String MODEL_NAME = OllamaModel.LLAMA3_2.getName();

-    @Container
-    static OllamaContainer ollamaContainer = new OllamaContainer(OllamaImage.IMAGE);
-
-    static String baseUrl = "http://localhost:11434";
+    static String baseUrl;

     @BeforeAll
     public static void beforeAll() throws IOException, InterruptedException {
-        logger.info("Start pulling the '" + MODEL_NAME + " ' generative ... would take several minutes ...");
-        ollamaContainer.execInContainer("ollama", "pull", MODEL_NAME);
-        logger.info(MODEL_NAME + " pulling competed!");
-
-        baseUrl = "http://" + ollamaContainer.getHost() + ":" + ollamaContainer.getMappedPort(11434);
+        baseUrl = buildConnectionWithModel(MODEL_NAME);
     }

     private final ApplicationContextRunner contextRunner = new ApplicationContextRunner().withPropertyValues(
@@ -103,7 +96,6 @@ void functionCallTest() {
     @Disabled("Ollama API does not support streaming function calls yet")
     @Test
     void streamingFunctionCallTest() {
-
         contextRunner.run(context -> {

             OllamaChatModel chatModel = context.getBean(OllamaChatModel.class);
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackWrapperIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackWrapperIT.java
index 0064514e093..9c887d47ce0 100644
--- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackWrapperIT.java
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackWrapperIT.java
@@ -21,13 +21,14 @@
 import java.util.List;
 import java.util.stream.Collectors;

-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
 import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.Disabled;
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.DisabledIf;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.ai.autoconfigure.ollama.BaseOllamaIT;
 import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration;
-import org.springframework.ai.autoconfigure.ollama.OllamaImage;
 import org.springframework.ai.chat.messages.AssistantMessage;
 import org.springframework.ai.chat.messages.UserMessage;
 import org.springframework.ai.chat.model.ChatResponse;
@@ -38,37 +39,29 @@
 import org.springframework.ai.model.function.FunctionCallingOptions;
 import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions;
 import org.springframework.ai.ollama.OllamaChatModel;
+import org.springframework.ai.ollama.api.OllamaModel;
 import org.springframework.ai.ollama.api.OllamaOptions;
 import org.springframework.boot.autoconfigure.AutoConfigurations;
 import org.springframework.boot.test.context.runner.ApplicationContextRunner;
 import org.springframework.context.annotation.Bean;
 import org.springframework.context.annotation.Configuration;
-import org.testcontainers.junit.jupiter.Container;
 import org.testcontainers.junit.jupiter.Testcontainers;
-import org.testcontainers.ollama.OllamaContainer;
 import reactor.core.publisher.Flux;

-@Disabled("For manual smoke testing only.")
 @Testcontainers
-public class FunctionCallbackWrapperIT {
+@DisabledIf("isDisabled")
+public class FunctionCallbackWrapperIT extends BaseOllamaIT {

-    private static final Log logger = LogFactory.getLog(FunctionCallbackWrapperIT.class);
+    private static final Logger logger = LoggerFactory.getLogger(FunctionCallbackWrapperIT.class);

-    private static String MODEL_NAME = "mistral";
+    private static final String MODEL_NAME = OllamaModel.LLAMA3_2.getName();

-    @Container
-    static OllamaContainer ollamaContainer = new OllamaContainer(OllamaImage.IMAGE);
-
-    static String baseUrl = "http://localhost:11434";
+    static String baseUrl;

     @BeforeAll
     public static void beforeAll() throws IOException, InterruptedException {
-        logger.info("Start pulling the '" + MODEL_NAME + " ' generative ... would take several minutes ...");
-        ollamaContainer.execInContainer("ollama", "pull", MODEL_NAME);
-        logger.info(MODEL_NAME + " pulling competed!");
-
-        baseUrl = "http://" + ollamaContainer.getHost() + ":" + ollamaContainer.getMappedPort(11434);
+        baseUrl = buildConnectionWithModel(MODEL_NAME);
     }

     private final ApplicationContextRunner contextRunner = new ApplicationContextRunner().withPropertyValues(
@@ -96,7 +89,6 @@ void functionCallTest() {
             logger.info("Response: " + response);

             assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15");
-
         });
     }
diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaContainerConnectionDetailsFactoryTest.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaContainerConnectionDetailsFactoryTest.java
index d0382005b85..6e139b74796 100644
--- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaContainerConnectionDetailsFactoryTest.java
+++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaContainerConnectionDetailsFactoryTest.java
@@ -15,11 +15,11 @@
  */
 package org.springframework.ai.testcontainers.service.connection.ollama;

-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
 import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.Disabled;
 import org.junit.jupiter.api.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration;
 import org.springframework.ai.embedding.EmbeddingResponse;
 import org.springframework.ai.ollama.OllamaEmbeddingModel;
@@ -44,15 +44,15 @@
 * @author Thomas Vitale
 */
@SpringJUnitConfig
-@Disabled("requires more memory than is often available on dev machines")
+@Disabled("Slow on CPU. Only run manually.")
@Testcontainers
@TestPropertySource(properties = "spring.ai.ollama.embedding.options.model="
        + OllamaContainerConnectionDetailsFactoryTest.MODEL_NAME)
class OllamaContainerConnectionDetailsFactoryTest {

-    private static final Log logger = LogFactory.getLog(OllamaContainerConnectionDetailsFactoryTest.class);
+    private static final Logger logger = LoggerFactory.getLogger(OllamaContainerConnectionDetailsFactoryTest.class);

-    static final String MODEL_NAME = "orca-mini";
+    static final String MODEL_NAME = "nomic-embed-text";

     @Container
     @ServiceConnection
@@ -63,9 +63,9 @@ class OllamaContainerConnectionDetailsFactoryTest {

     @BeforeAll
     public static void beforeAll() throws IOException, InterruptedException {
-        logger.info("Start pulling the '" + MODEL_NAME + " ' generative ... would take several minutes ...");
+        logger.info("Start pulling the '{}' model. The operation can take several minutes...", MODEL_NAME);
         ollama.execInContainer("ollama", "pull", MODEL_NAME);
-        logger.info(MODEL_NAME + " pulling competed!");
+        logger.info("Completed pulling the '{}' model", MODEL_NAME);
     }

     @Test
@@ -73,7 +73,7 @@ public void singleTextEmbedding() {
         EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of("Hello World"));
         assertThat(embeddingResponse.getResults()).hasSize(1);
         assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
-        assertThat(this.embeddingModel.dimensions()).isEqualTo(3200);
+        assertThat(this.embeddingModel.dimensions()).isEqualTo(768);
     }

     @Configuration(proxyBeanMethods = false)
diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaImage.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaImage.java
index e5d2ebcd75d..ed9a1aaab8a 100644
--- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaImage.java
+++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaImage.java
@@ -22,6 +22,6 @@
  */
 public class OllamaImage {

-    public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ollama/ollama:0.3.9");
+    public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ollama/ollama:0.3.13");

 }
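Reviewer note: for completeness, the container-reuse opt-in that the BaseOllamaIT Javadoc refers to is a one-line entry in the Testcontainers configuration file in your home directory; the path shown assumes a Unix-like setup.

    # ~/.testcontainers.properties
    testcontainers.reuse.enable=true

With this entry in place and withReuse(true) on the container, consecutive local test runs reuse the same Ollama container and skip repeated model pulls.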