Skip to content

Commit 4cb1f7a

Browse files
committed
Add tests
- Verify the mappingJson field is correctly set - verify the override works fine - Add integration tests with Ollama embedding model
1 parent c66dfbf commit 4cb1f7a

File tree

3 files changed

+204
-0
lines changed

3 files changed

+204
-0
lines changed

spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfigurationIT.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,16 @@ public void addAndSearchTest() {
8989
this.contextRunner.run(context -> {
9090
OpenSearchVectorStore vectorStore = context.getBean(OpenSearchVectorStore.class);
9191
TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class);
92+
assertThat(vectorStore).isNotNull();
93+
assertThat(vectorStore).hasFieldOrPropertyWithValue("mappingJson", """
94+
{
95+
"properties":{
96+
"embedding":{
97+
"type":"knn_vector",
98+
"dimension":384
99+
}
100+
}
101+
}""");
92102

93103
vectorStore.add(this.documents);
94104

vector-stores/spring-ai-opensearch-store/pom.xml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,13 @@
6868
<scope>test</scope>
6969
</dependency>
7070

71+
<dependency>
72+
<groupId>org.springframework.ai</groupId>
73+
<artifactId>spring-ai-ollama</artifactId>
74+
<version>${parent.version}</version>
75+
<scope>test</scope>
76+
</dependency>
77+
7178

7279
<dependency>
7380
<groupId>org.springframework.ai</groupId>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.vectorstore;
18+
19+
import java.io.IOException;
20+
import java.net.URISyntaxException;
21+
import java.nio.charset.StandardCharsets;
22+
import java.time.Duration;
23+
import java.util.List;
24+
import java.util.Map;
25+
import java.util.concurrent.TimeUnit;
26+
27+
import org.apache.hc.core5.http.HttpHost;
28+
import org.awaitility.Awaitility;
29+
import org.junit.jupiter.api.BeforeAll;
30+
import org.junit.jupiter.api.BeforeEach;
31+
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
32+
import org.junit.jupiter.params.ParameterizedTest;
33+
import org.junit.jupiter.params.provider.ValueSource;
34+
import org.opensearch.client.opensearch.OpenSearchClient;
35+
import org.opensearch.client.transport.httpclient5.ApacheHttpClient5TransportBuilder;
36+
import org.opensearch.testcontainers.OpensearchContainer;
37+
import org.testcontainers.junit.jupiter.Container;
38+
import org.testcontainers.junit.jupiter.Testcontainers;
39+
40+
import org.springframework.ai.document.Document;
41+
import org.springframework.ai.embedding.EmbeddingModel;
42+
import org.springframework.ai.ollama.OllamaEmbeddingModel;
43+
import org.springframework.ai.ollama.api.OllamaApi;
44+
import org.springframework.ai.ollama.api.OllamaModel;
45+
import org.springframework.ai.ollama.api.OllamaOptions;
46+
import org.springframework.beans.factory.annotation.Qualifier;
47+
import org.springframework.boot.SpringBootConfiguration;
48+
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
49+
import org.springframework.context.annotation.Bean;
50+
import org.springframework.core.io.DefaultResourceLoader;
51+
52+
import static org.assertj.core.api.Assertions.assertThat;
53+
import static org.hamcrest.Matchers.hasSize;
54+
55+
@Testcontainers
56+
@EnabledIfEnvironmentVariable(named = "OLLAMA_TESTS_ENABLED", matches = "true")
57+
class OpenSearchVectorStoreWithOllamaIT {
58+
59+
@Container
60+
private static final OpensearchContainer<?> opensearchContainer = new OpensearchContainer<>(
61+
OpenSearchImage.DEFAULT_IMAGE);
62+
63+
private static final String DEFAULT = "cosinesimil";
64+
65+
private List<Document> documents = List.of(
66+
new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")),
67+
new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()),
68+
new Document("3", getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2")));
69+
70+
@BeforeAll
71+
public static void beforeAll() {
72+
Awaitility.setDefaultPollInterval(2, TimeUnit.SECONDS);
73+
Awaitility.setDefaultPollDelay(Duration.ZERO);
74+
Awaitility.setDefaultTimeout(Duration.ofMinutes(1));
75+
}
76+
77+
private String getText(String uri) {
78+
var resource = new DefaultResourceLoader().getResource(uri);
79+
try {
80+
return resource.getContentAsString(StandardCharsets.UTF_8);
81+
}
82+
catch (IOException e) {
83+
throw new RuntimeException(e);
84+
}
85+
}
86+
87+
private ApplicationContextRunner getContextRunner() {
88+
return new ApplicationContextRunner().withUserConfiguration(TestApplication.class);
89+
}
90+
91+
@BeforeEach
92+
void cleanDatabase() {
93+
getContextRunner().run(context -> {
94+
VectorStore vectorStore = context.getBean("vectorStore", OpenSearchVectorStore.class);
95+
vectorStore.delete(List.of("_all"));
96+
97+
VectorStore anotherVectorStore = context.getBean("anotherVectorStore", OpenSearchVectorStore.class);
98+
anotherVectorStore.delete(List.of("_all"));
99+
});
100+
}
101+
102+
@ParameterizedTest(name = "{0} : {displayName} ")
103+
@ValueSource(strings = { DEFAULT, "l1", "l2", "linf" })
104+
public void addAndSearchTest(String similarityFunction) {
105+
106+
getContextRunner().run(context -> {
107+
OpenSearchVectorStore vectorStore = context.getBean("vectorStore", OpenSearchVectorStore.class);
108+
109+
if (!DEFAULT.equals(similarityFunction)) {
110+
vectorStore.withSimilarityFunction(similarityFunction);
111+
}
112+
113+
vectorStore.add(this.documents);
114+
115+
Awaitility.await()
116+
.until(() -> vectorStore
117+
.similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0)),
118+
hasSize(1));
119+
120+
List<Document> results = vectorStore
121+
.similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0));
122+
123+
assertThat(results).hasSize(1);
124+
Document resultDoc = results.get(0);
125+
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId());
126+
assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock");
127+
assertThat(resultDoc.getMetadata()).hasSize(2);
128+
assertThat(resultDoc.getMetadata()).containsKey("meta2");
129+
assertThat(resultDoc.getMetadata()).containsKey("distance");
130+
131+
// Remove all documents from the store
132+
vectorStore.delete(this.documents.stream().map(Document::getId).toList());
133+
134+
Awaitility.await()
135+
.until(() -> vectorStore
136+
.similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0)),
137+
hasSize(0));
138+
});
139+
}
140+
141+
@SpringBootConfiguration
142+
public static class TestApplication {
143+
144+
@Bean
145+
@Qualifier("vectorStore")
146+
public OpenSearchVectorStore vectorStore(EmbeddingModel embeddingModel) {
147+
try {
148+
return new OpenSearchVectorStore(new OpenSearchClient(ApacheHttpClient5TransportBuilder
149+
.builder(HttpHost.create(opensearchContainer.getHttpHostAddress()))
150+
.build()), embeddingModel, true);
151+
}
152+
catch (URISyntaxException e) {
153+
throw new RuntimeException(e);
154+
}
155+
}
156+
157+
@Bean
158+
@Qualifier("anotherVectorStore")
159+
public OpenSearchVectorStore anotherVectorStore(EmbeddingModel embeddingModel) {
160+
try {
161+
return new OpenSearchVectorStore("another_index",
162+
new OpenSearchClient(ApacheHttpClient5TransportBuilder
163+
.builder(HttpHost.create(opensearchContainer.getHttpHostAddress()))
164+
.build()),
165+
embeddingModel, OpenSearchVectorStore.DEFAULT_MAPPING_EMBEDDING_TYPE_KNN_VECTOR_DIMENSION,
166+
true);
167+
}
168+
catch (URISyntaxException e) {
169+
throw new RuntimeException(e);
170+
}
171+
}
172+
173+
@Bean
174+
public EmbeddingModel embeddingModel() {
175+
return OllamaEmbeddingModel.builder()
176+
.withOllamaApi(new OllamaApi())
177+
.withDefaultOptions(OllamaOptions.create()
178+
.withModel(OllamaModel.MXBAI_EMBED_LARGE)
179+
.withMainGPU(11)
180+
.withUseMMap(true)
181+
.withNumGPU(1))
182+
.build();
183+
}
184+
185+
}
186+
187+
}

0 commit comments

Comments
 (0)