Skip to content

Commit 21fc653

Browse files
markpollacksobychacko
authored andcommitted
Improve Ollama test container management
This change simplifies how we manage Ollama containers in tests by moving from manual toggles to environment variables for better control. Instead of scattered container configuration, we now have: - OLLAMA_WITH_REUSE: Toggle reuse of existing containers between tests - OLLAMA_TESTS_ENABLED: Control test execution globally The motivation is to make tests more reliable and easier to maintain. Previously, developers had to modify code to run tests locally vs CI. Now they can control this via environment variables. We also introduce thread-safe API access and consistent default settings across all test classes, removing duplicated configuration and potential resource leaks. This makes the test infrastructure more maintainable and provides clearer separation between local development and CI environments. Checkstyle fixes. Make buildOllamaApiWithModel in BaseOllamaIT public and the related test changes.
1 parent c3c95a8 commit 21fc653

19 files changed

+144
-155
lines changed

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/BaseOllamaIT.java

Lines changed: 57 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -18,66 +18,87 @@
1818

1919
import java.time.Duration;
2020

21+
import org.junit.jupiter.api.AfterAll;
22+
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
23+
import org.testcontainers.junit.jupiter.Testcontainers;
2124
import org.testcontainers.ollama.OllamaContainer;
2225

2326
import org.springframework.ai.ollama.api.OllamaApi;
2427
import org.springframework.ai.ollama.management.ModelManagementOptions;
2528
import org.springframework.ai.ollama.management.OllamaModelManager;
2629
import org.springframework.ai.ollama.management.PullModelStrategy;
27-
import org.springframework.util.StringUtils;
30+
import org.springframework.util.Assert;
2831

29-
public class BaseOllamaIT {
32+
@Testcontainers
33+
@EnabledIfEnvironmentVariable(named = "OLLAMA_TESTS_ENABLED", matches = "true")
34+
public abstract class BaseOllamaIT {
3035

31-
// Toggle for running tests locally on native Ollama for a faster feedback loop.
32-
private static final boolean useTestcontainers = true;
36+
private static final String OLLAMA_LOCAL_URL = "http://localhost:11434";
3337

34-
public static OllamaContainer ollamaContainer;
38+
private static final Duration DEFAULT_TIMEOUT = Duration.ofMinutes(10);
3539

36-
static {
37-
if (useTestcontainers) {
40+
private static final int DEFAULT_MAX_RETRIES = 2;
41+
42+
// Environment variable to control whether to create a new container or use existing
43+
// Ollama instance
44+
private static final boolean SKIP_CONTAINER_CREATION = Boolean
45+
.parseBoolean(System.getenv().getOrDefault("OLLAMA_SKIP_CONTAINER", "false"));
46+
47+
private static OllamaContainer ollamaContainer;
48+
49+
private static final ThreadLocal<OllamaApi> ollamaApi = new ThreadLocal<>();
50+
51+
/**
52+
* Initialize the Ollama container and API with the specified model. This method
53+
* should be called from @BeforeAll in subclasses.
54+
* @param model the Ollama model to initialize (must not be null or empty)
55+
* @return configured OllamaApi instance
56+
* @throws IllegalArgumentException if model is null or empty
57+
*/
58+
protected static OllamaApi initializeOllama(final String model) {
59+
Assert.hasText(model, "Model name must be provided");
60+
61+
if (!SKIP_CONTAINER_CREATION) {
3862
ollamaContainer = new OllamaContainer(OllamaImage.DEFAULT_IMAGE).withReuse(true);
3963
ollamaContainer.start();
4064
}
65+
66+
final OllamaApi api = buildOllamaApiWithModel(model);
67+
ollamaApi.set(api);
68+
return api;
4169
}
4270

4371
/**
44-
* Change the return value to false in order to run multiple Ollama IT tests locally
45-
* reusing the same container image.
46-
*
47-
* Also, add the entry
48-
*
49-
* testcontainers.reuse.enable=true
50-
*
51-
* to the file ".testcontainers.properties" located in your home directory
72+
* Get the initialized OllamaApi instance.
73+
* @return the OllamaApi instance
74+
* @throws IllegalStateException if called before initialization
5275
*/
53-
public static boolean isDisabled() {
54-
return true;
55-
}
56-
57-
public static OllamaApi buildOllamaApi() {
58-
return buildOllamaApiWithModel(null);
76+
protected static OllamaApi getOllamaApi() {
77+
OllamaApi api = ollamaApi.get();
78+
Assert.state(api != null, "OllamaApi not initialized. Call initializeOllama first.");
79+
return api;
5980
}
6081

61-
public static OllamaApi buildOllamaApiWithModel(String model) {
62-
var baseUrl = "http://localhost:11434";
63-
if (useTestcontainers) {
64-
baseUrl = ollamaContainer.getEndpoint();
65-
}
66-
var ollamaApi = new OllamaApi(baseUrl);
67-
68-
if (StringUtils.hasText(model)) {
69-
ensureModelIsPresent(ollamaApi, model);
82+
@AfterAll
83+
public static void tearDown() {
84+
if (ollamaContainer != null) {
85+
ollamaContainer.stop();
7086
}
87+
}
7188

72-
return ollamaApi;
89+
private static OllamaApi buildOllamaApiWithModel(final String model) {
90+
final String baseUrl = SKIP_CONTAINER_CREATION ? OLLAMA_LOCAL_URL : ollamaContainer.getEndpoint();
91+
final OllamaApi api = new OllamaApi(baseUrl);
92+
ensureModelIsPresent(api, model);
93+
return api;
7394
}
7495

75-
public static void ensureModelIsPresent(OllamaApi ollamaApi, String model) {
76-
var modelManagementOptions = ModelManagementOptions.builder()
77-
.withMaxRetries(2)
78-
.withTimeout(Duration.ofMinutes(10))
96+
private static void ensureModelIsPresent(final OllamaApi ollamaApi, final String model) {
97+
final var modelManagementOptions = ModelManagementOptions.builder()
98+
.withMaxRetries(DEFAULT_MAX_RETRIES)
99+
.withTimeout(DEFAULT_TIMEOUT)
79100
.build();
80-
var ollamaModelManager = new OllamaModelManager(ollamaApi, modelManagementOptions);
101+
final var ollamaModelManager = new OllamaModelManager(ollamaApi, modelManagementOptions);
81102
ollamaModelManager.pullModel(model, PullModelStrategy.WHEN_MISSING);
82103
}
83104

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,8 @@
2222

2323
import org.junit.jupiter.api.Disabled;
2424
import org.junit.jupiter.api.Test;
25-
import org.junit.jupiter.api.condition.DisabledIf;
2625
import org.slf4j.Logger;
2726
import org.slf4j.LoggerFactory;
28-
import org.testcontainers.junit.jupiter.Testcontainers;
2927
import reactor.core.publisher.Flux;
3028

3129
import org.springframework.ai.chat.messages.AssistantMessage;
@@ -46,9 +44,7 @@
4644

4745
import static org.assertj.core.api.Assertions.assertThat;
4846

49-
@Testcontainers
5047
@SpringBootTest(classes = OllamaChatModelFunctionCallingIT.Config.class)
51-
@DisabledIf("isDisabled")
5248
class OllamaChatModelFunctionCallingIT extends BaseOllamaIT {
5349

5450
private static final Logger logger = LoggerFactory.getLogger(OllamaChatModelFunctionCallingIT.class);
@@ -120,7 +116,7 @@ static class Config {
120116

121117
@Bean
122118
public OllamaApi ollamaApi() {
123-
return buildOllamaApiWithModel(MODEL);
119+
return initializeOllama(MODEL);
124120
}
125121

126122
@Bean

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121
import java.util.stream.Collectors;
2222

2323
import org.junit.jupiter.api.Test;
24-
import org.junit.jupiter.api.condition.DisabledIf;
25-
import org.testcontainers.junit.jupiter.Testcontainers;
2624

2725
import org.springframework.ai.chat.client.ChatClient;
2826
import org.springframework.ai.chat.messages.AssistantMessage;
@@ -53,8 +51,6 @@
5351
import static org.assertj.core.api.Assertions.assertThat;
5452

5553
@SpringBootTest
56-
@Testcontainers
57-
@DisabledIf("isDisabled")
5854
class OllamaChatModelIT extends BaseOllamaIT {
5955

6056
private static final String MODEL = OllamaModel.LLAMA3_2.getName();
@@ -241,7 +237,7 @@ public static class TestConfiguration {
241237

242238
@Bean
243239
public OllamaApi ollamaApi() {
244-
return buildOllamaApiWithModel(MODEL);
240+
return initializeOllama(MODEL);
245241
}
246242

247243
@Bean

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,8 @@
1919
import java.util.List;
2020

2121
import org.junit.jupiter.api.Test;
22-
import org.junit.jupiter.api.condition.DisabledIf;
2322
import org.slf4j.Logger;
2423
import org.slf4j.LoggerFactory;
25-
import org.testcontainers.junit.jupiter.Testcontainers;
2624

2725
import org.springframework.ai.chat.messages.UserMessage;
2826
import org.springframework.ai.chat.prompt.Prompt;
@@ -40,8 +38,6 @@
4038
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
4139

4240
@SpringBootTest
43-
@Testcontainers
44-
@DisabledIf("isDisabled")
4541
class OllamaChatModelMultimodalIT extends BaseOllamaIT {
4642

4743
private static final Logger logger = LoggerFactory.getLogger(OllamaChatModelMultimodalIT.class);
@@ -80,7 +76,7 @@ public static class TestConfiguration {
8076

8177
@Bean
8278
public OllamaApi ollamaApi() {
83-
return buildOllamaApiWithModel(MODEL);
79+
return initializeOllama(MODEL);
8480
}
8581

8682
@Bean

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import io.micrometer.observation.tck.TestObservationRegistryAssert;
2424
import org.junit.jupiter.api.BeforeEach;
2525
import org.junit.jupiter.api.Test;
26-
import org.junit.jupiter.api.condition.DisabledIf;
2726
import reactor.core.publisher.Flux;
2827

2928
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
@@ -50,7 +49,6 @@
5049
* @author Thomas Vitale
5150
*/
5251
@SpringBootTest(classes = OllamaChatModelObservationIT.Config.class)
53-
@DisabledIf("isDisabled")
5452
public class OllamaChatModelObservationIT extends BaseOllamaIT {
5553

5654
private static final String MODEL = OllamaModel.LLAMA3_2.getName();
@@ -166,7 +164,7 @@ public TestObservationRegistry observationRegistry() {
166164

167165
@Bean
168166
public OllamaApi openAiApi() {
169-
return buildOllamaApiWithModel(MODEL);
167+
return initializeOllama(MODEL);
170168
}
171169

172170
@Bean

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelIT.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919
import java.util.List;
2020

2121
import org.junit.jupiter.api.Test;
22-
import org.junit.jupiter.api.condition.DisabledIf;
23-
import org.testcontainers.junit.jupiter.Testcontainers;
2422

2523
import org.springframework.ai.embedding.EmbeddingRequest;
2624
import org.springframework.ai.embedding.EmbeddingResponse;
@@ -38,8 +36,6 @@
3836
import static org.assertj.core.api.Assertions.assertThat;
3937

4038
@SpringBootTest
41-
@DisabledIf("isDisabled")
42-
@Testcontainers
4339
class OllamaEmbeddingModelIT extends BaseOllamaIT {
4440

4541
private static final String MODEL = OllamaModel.NOMIC_EMBED_TEXT.getName();
@@ -100,7 +96,7 @@ public static class TestConfiguration {
10096

10197
@Bean
10298
public OllamaApi ollamaApi() {
103-
return buildOllamaApiWithModel(MODEL);
99+
return initializeOllama(MODEL);
104100
}
105101

106102
@Bean

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelObservationIT.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import io.micrometer.observation.tck.TestObservationRegistry;
2222
import io.micrometer.observation.tck.TestObservationRegistryAssert;
2323
import org.junit.jupiter.api.Test;
24-
import org.junit.jupiter.api.condition.DisabledIf;
2524

2625
import org.springframework.ai.embedding.EmbeddingRequest;
2726
import org.springframework.ai.embedding.EmbeddingResponse;
@@ -47,7 +46,6 @@
4746
* @author Thomas Vitale
4847
*/
4948
@SpringBootTest(classes = OllamaEmbeddingModelObservationIT.Config.class)
50-
@DisabledIf("isDisabled")
5149
public class OllamaEmbeddingModelObservationIT extends BaseOllamaIT {
5250

5351
private static final String MODEL = OllamaModel.NOMIC_EMBED_TEXT.getName();
@@ -100,7 +98,7 @@ public TestObservationRegistry observationRegistry() {
10098

10199
@Bean
102100
public OllamaApi openAiApi() {
103-
return buildOllamaApiWithModel(MODEL);
101+
return initializeOllama(MODEL);
104102
}
105103

106104
@Bean

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@
2222

2323
import org.junit.jupiter.api.BeforeAll;
2424
import org.junit.jupiter.api.Test;
25-
import org.junit.jupiter.api.condition.DisabledIf;
26-
import org.testcontainers.junit.jupiter.Testcontainers;
2725
import reactor.core.publisher.Flux;
2826

2927
import org.springframework.ai.ollama.BaseOllamaIT;
@@ -42,17 +40,13 @@
4240
* @author Christian Tzolov
4341
* @author Thomas Vitale
4442
*/
45-
@Testcontainers
46-
@DisabledIf("isDisabled")
4743
public class OllamaApiIT extends BaseOllamaIT {
4844

4945
private static final String MODEL = OllamaModel.LLAMA3_2.getName();
5046

51-
static OllamaApi ollamaApi;
52-
5347
@BeforeAll
5448
public static void beforeAll() throws IOException, InterruptedException {
55-
ollamaApi = buildOllamaApiWithModel(MODEL);
49+
initializeOllama(MODEL);
5650
}
5751

5852
@Test
@@ -63,7 +57,7 @@ public void generation() {
6357
.withStream(false)
6458
.build();
6559

66-
GenerateResponse response = ollamaApi.generate(request);
60+
GenerateResponse response = getOllamaApi().generate(request);
6761

6862
System.out.println(response);
6963

@@ -87,7 +81,7 @@ public void chat() {
8781
.withOptions(OllamaOptions.create().withTemperature(0.9))
8882
.build();
8983

90-
ChatResponse response = ollamaApi.chat(request);
84+
ChatResponse response = getOllamaApi().chat(request);
9185

9286
System.out.println(response);
9387

@@ -108,7 +102,7 @@ public void streamingChat() {
108102
.withOptions(OllamaOptions.create().withTemperature(0.9).toMap())
109103
.build();
110104

111-
Flux<ChatResponse> response = ollamaApi.streamingChat(request);
105+
Flux<ChatResponse> response = getOllamaApi().streamingChat(request);
112106

113107
List<ChatResponse> responses = response.collectList().block();
114108
System.out.println(responses);
@@ -128,7 +122,7 @@ public void streamingChat() {
128122
public void embedText() {
129123
EmbeddingsRequest request = new EmbeddingsRequest(MODEL, "I like to eat apples");
130124

131-
EmbeddingsResponse response = ollamaApi.embed(request);
125+
EmbeddingsResponse response = getOllamaApi().embed(request);
132126

133127
assertThat(response).isNotNull();
134128
assertThat(response.embeddings()).hasSize(1);

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiModelsIT.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121

2222
import org.junit.jupiter.api.BeforeAll;
2323
import org.junit.jupiter.api.Test;
24-
import org.junit.jupiter.api.condition.DisabledIf;
25-
import org.testcontainers.junit.jupiter.Testcontainers;
2624

2725
import org.springframework.ai.ollama.BaseOllamaIT;
2826
import org.springframework.http.HttpStatus;
@@ -34,8 +32,6 @@
3432
*
3533
* @author Thomas Vitale
3634
*/
37-
@Testcontainers
38-
@DisabledIf("isDisabled")
3935
public class OllamaApiModelsIT extends BaseOllamaIT {
4036

4137
private static final String MODEL = "all-minilm";
@@ -44,7 +40,7 @@ public class OllamaApiModelsIT extends BaseOllamaIT {
4440

4541
@BeforeAll
4642
public static void beforeAll() throws IOException, InterruptedException {
47-
ollamaApi = buildOllamaApiWithModel(MODEL);
43+
ollamaApi = initializeOllama(MODEL);
4844
}
4945

5046
@Test

0 commit comments

Comments
 (0)