Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,66 +18,88 @@

import java.time.Duration;

import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.ollama.OllamaContainer;

import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.management.ModelManagementOptions;
import org.springframework.ai.ollama.management.OllamaModelManager;
import org.springframework.ai.ollama.management.PullModelStrategy;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

public class BaseOllamaIT {
@Testcontainers
@EnabledIfEnvironmentVariable(named = "OLLAMA_TESTS_ENABLED", matches = "true")
public abstract class BaseOllamaIT {

// Toggle for running tests locally on native Ollama for a faster feedback loop.
private static final boolean useTestcontainers = true;
private static final String OLLAMA_LOCAL_URL = "http://localhost:11434";

public static OllamaContainer ollamaContainer;
private static final Duration DEFAULT_TIMEOUT = Duration.ofMinutes(10);

static {
if (useTestcontainers) {
private static final int DEFAULT_MAX_RETRIES = 2;

// Environment variable to control whether to create a new container or use existing
// Ollama instance
private static final boolean SKIP_CONTAINER_CREATION = Boolean
.parseBoolean(System.getenv().getOrDefault("OLLAMA_SKIP_CONTAINER", "false"));

private static OllamaContainer ollamaContainer;

private static final ThreadLocal<OllamaApi> ollamaApi = new ThreadLocal<>();

/**
* Initialize the Ollama container and API with the specified model. This method
* should be called from @BeforeAll in subclasses.
* @param model the Ollama model to initialize (must not be null or empty)
* @return configured OllamaApi instance
* @throws IllegalArgumentException if model is null or empty
*/
protected static OllamaApi initializeOllama(final String model) {
Assert.hasText(model, "Model name must be provided");

if (!SKIP_CONTAINER_CREATION) {
ollamaContainer = new OllamaContainer(OllamaImage.DEFAULT_IMAGE).withReuse(true);
ollamaContainer.start();
}

final OllamaApi api = buildOllamaApiWithModel(model);
ollamaApi.set(api);
return api;
}

/**
* Change the return value to false in order to run multiple Ollama IT tests locally
* reusing the same container image.
*
* Also, add the entry
*
* testcontainers.reuse.enable=true
*
* to the file ".testcontainers.properties" located in your home directory
* Get the initialized OllamaApi instance.
* @return the OllamaApi instance
* @throws IllegalStateException if called before initialization
*/
public static boolean isDisabled() {
return true;
}

public static OllamaApi buildOllamaApi() {
return buildOllamaApiWithModel(null);
protected static OllamaApi getOllamaApi() {
OllamaApi api = ollamaApi.get();
Assert.state(api != null, "OllamaApi not initialized. Call initializeOllama first.");
return api;
}

public static OllamaApi buildOllamaApiWithModel(String model) {
var baseUrl = "http://localhost:11434";
if (useTestcontainers) {
baseUrl = ollamaContainer.getEndpoint();
}
var ollamaApi = new OllamaApi(baseUrl);

if (StringUtils.hasText(model)) {
ensureModelIsPresent(ollamaApi, model);
@AfterAll
public static void tearDown() {
if (ollamaContainer != null) {
ollamaContainer.stop();
}
}

return ollamaApi;
private static OllamaApi buildOllamaApiWithModel(final String model) {
final String baseUrl = SKIP_CONTAINER_CREATION ? OLLAMA_LOCAL_URL : ollamaContainer.getEndpoint();
final OllamaApi api = new OllamaApi(baseUrl);
ensureModelIsPresent(api, model);
return api;
}

public static void ensureModelIsPresent(OllamaApi ollamaApi, String model) {
var modelManagementOptions = ModelManagementOptions.builder()
.withMaxRetries(2)
.withTimeout(Duration.ofMinutes(10))
private static void ensureModelIsPresent(final OllamaApi ollamaApi, final String model) {
final var modelManagementOptions = ModelManagementOptions.builder()
.withMaxRetries(DEFAULT_MAX_RETRIES)
.withTimeout(DEFAULT_TIMEOUT)
.build();
var ollamaModelManager = new OllamaModelManager(ollamaApi, modelManagementOptions);
final var ollamaModelManager = new OllamaModelManager(ollamaApi, modelManagementOptions);
ollamaModelManager.pullModel(model, PullModelStrategy.WHEN_MISSING);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@

import static org.assertj.core.api.Assertions.assertThat;

@Testcontainers
@SpringBootTest(classes = OllamaChatModelFunctionCallingIT.Config.class)
@DisabledIf("isDisabled")
class OllamaChatModelFunctionCallingIT extends BaseOllamaIT {

private static final Logger logger = LoggerFactory.getLogger(OllamaChatModelFunctionCallingIT.class);
Expand Down Expand Up @@ -120,7 +118,7 @@ static class Config {

@Bean
public OllamaApi ollamaApi() {
return buildOllamaApiWithModel(MODEL);
return initializeOllama(MODEL);
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@
import static org.assertj.core.api.Assertions.assertThat;

@SpringBootTest
@Testcontainers
@DisabledIf("isDisabled")
class OllamaChatModelIT extends BaseOllamaIT {

private static final String MODEL = OllamaModel.LLAMA3_2.getName();
Expand Down Expand Up @@ -241,7 +239,7 @@ public static class TestConfiguration {

@Bean
public OllamaApi ollamaApi() {
return buildOllamaApiWithModel(MODEL);
return initializeOllama(MODEL);
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;

@SpringBootTest
@Testcontainers
@DisabledIf("isDisabled")
class OllamaChatModelMultimodalIT extends BaseOllamaIT {

private static final Logger logger = LoggerFactory.getLogger(OllamaChatModelMultimodalIT.class);
Expand Down Expand Up @@ -80,7 +78,7 @@ public static class TestConfiguration {

@Bean
public OllamaApi ollamaApi() {
return buildOllamaApiWithModel(MODEL);
return initializeOllama(MODEL);
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
* @author Thomas Vitale
*/
@SpringBootTest(classes = OllamaChatModelObservationIT.Config.class)
@DisabledIf("isDisabled")
public class OllamaChatModelObservationIT extends BaseOllamaIT {

private static final String MODEL = OllamaModel.LLAMA3_2.getName();
Expand Down Expand Up @@ -166,7 +165,7 @@ public TestObservationRegistry observationRegistry() {

@Bean
public OllamaApi openAiApi() {
return buildOllamaApiWithModel(MODEL);
return initializeOllama(MODEL);
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@
import static org.assertj.core.api.Assertions.assertThat;

@SpringBootTest
@DisabledIf("isDisabled")
@Testcontainers
class OllamaEmbeddingModelIT extends BaseOllamaIT {

private static final String MODEL = OllamaModel.NOMIC_EMBED_TEXT.getName();
Expand Down Expand Up @@ -100,7 +98,7 @@ public static class TestConfiguration {

@Bean
public OllamaApi ollamaApi() {
return buildOllamaApiWithModel(MODEL);
return initializeOllama(MODEL);
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
* @author Thomas Vitale
*/
@SpringBootTest(classes = OllamaEmbeddingModelObservationIT.Config.class)
@DisabledIf("isDisabled")
public class OllamaEmbeddingModelObservationIT extends BaseOllamaIT {

private static final String MODEL = OllamaModel.NOMIC_EMBED_TEXT.getName();
Expand Down Expand Up @@ -100,7 +99,7 @@ public TestObservationRegistry observationRegistry() {

@Bean
public OllamaApi openAiApi() {
return buildOllamaApiWithModel(MODEL);
return initializeOllama(MODEL);
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,13 @@
* @author Christian Tzolov
* @author Thomas Vitale
*/
@Testcontainers
@DisabledIf("isDisabled")
public class OllamaApiIT extends BaseOllamaIT {

private static final String MODEL = OllamaModel.LLAMA3_2.getName();

static OllamaApi ollamaApi;

@BeforeAll
public static void beforeAll() throws IOException, InterruptedException {
ollamaApi = buildOllamaApiWithModel(MODEL);
initializeOllama(MODEL);
}

@Test
Expand All @@ -63,7 +59,7 @@ public void generation() {
.withStream(false)
.build();

GenerateResponse response = ollamaApi.generate(request);
GenerateResponse response = getOllamaApi().generate(request);

System.out.println(response);

Expand All @@ -87,7 +83,7 @@ public void chat() {
.withOptions(OllamaOptions.create().withTemperature(0.9))
.build();

ChatResponse response = ollamaApi.chat(request);
ChatResponse response = getOllamaApi().chat(request);

System.out.println(response);

Expand All @@ -108,7 +104,7 @@ public void streamingChat() {
.withOptions(OllamaOptions.create().withTemperature(0.9).toMap())
.build();

Flux<ChatResponse> response = ollamaApi.streamingChat(request);
Flux<ChatResponse> response = getOllamaApi().streamingChat(request);

List<ChatResponse> responses = response.collectList().block();
System.out.println(responses);
Expand All @@ -128,7 +124,7 @@ public void streamingChat() {
public void embedText() {
EmbeddingsRequest request = new EmbeddingsRequest(MODEL, "I like to eat apples");

EmbeddingsResponse response = ollamaApi.embed(request);
EmbeddingsResponse response = getOllamaApi().embed(request);

assertThat(response).isNotNull();
assertThat(response.embeddings()).hasSize(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@
*
* @author Thomas Vitale
*/
@Testcontainers
@DisabledIf("isDisabled")
public class OllamaApiModelsIT extends BaseOllamaIT {

private static final String MODEL = "all-minilm";
Expand All @@ -44,7 +42,7 @@ public class OllamaApiModelsIT extends BaseOllamaIT {

@BeforeAll
public static void beforeAll() throws IOException, InterruptedException {
ollamaApi = buildOllamaApiWithModel(MODEL);
ollamaApi = initializeOllama(MODEL);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@
* @author Christian Tzolov
* @author Thomas Vitale
*/
@Testcontainers
@DisabledIf("isDisabled")
public class OllamaApiToolFunctionCallIT extends BaseOllamaIT {

private static final String MODEL = "qwen2.5:3b";
Expand All @@ -56,7 +54,7 @@ public class OllamaApiToolFunctionCallIT extends BaseOllamaIT {

@BeforeAll
public static void beforeAll() throws IOException, InterruptedException {
ollamaApi = buildOllamaApiWithModel(MODEL);
ollamaApi = initializeOllama(MODEL);
}

@SuppressWarnings("null")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@
*
* @author Thomas Vitale
*/
@Testcontainers
@DisabledIf("isDisabled")
class OllamaModelManagerIT extends BaseOllamaIT {

private static final String MODEL = OllamaModel.NOMIC_EMBED_TEXT.getName();
Expand All @@ -45,7 +43,7 @@ class OllamaModelManagerIT extends BaseOllamaIT {

@BeforeAll
public static void beforeAll() throws IOException, InterruptedException {
var ollamaApi = buildOllamaApiWithModel(MODEL);
var ollamaApi = initializeOllama(MODEL);
modelManager = new OllamaModelManager(ollamaApi);
}

Expand Down Expand Up @@ -144,7 +142,7 @@ public void pullAdditionalModels() {
var isModelAvailable = modelManager.isModelAvailable(model);
assertThat(isModelAvailable).isFalse();

new OllamaModelManager(buildOllamaApi(),
new OllamaModelManager(getOllamaApi(),
new ModelManagementOptions(PullModelStrategy.WHEN_MISSING, List.of(model), Duration.ofMinutes(5), 0));

isModelAvailable = modelManager.isModelAvailable(model);
Expand Down
Loading