Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
* @author Soby Chacko
* @author Christian Tzolov
* @author Thomas Vitale
* @author inpink
* @since 1.0.0
*/
public class OpenSearchVectorStore extends AbstractObservationVectorStore implements InitializingBean {
Expand Down Expand Up @@ -178,6 +179,7 @@ public List<Document> similaritySearch(float[] embedding, int topK, double simil
Filter.Expression filterExpression) {
return similaritySearch(new org.opensearch.client.opensearch.core.SearchRequest.Builder()
.query(getOpenSearchSimilarityQuery(embedding, filterExpression))
.index(this.index)
.sort(sortOptionsBuilder -> sortOptionsBuilder
.score(scoreSortBuilder -> scoreSortBuilder.order(SortOrder.Desc)))
.size(topK)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.awaitility.Awaitility;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
Expand All @@ -30,6 +31,7 @@
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.openai.OpenAiEmbeddingModel;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
Expand Down Expand Up @@ -58,6 +60,7 @@
/**
* @author Jemin Huh
* @author Soby Chacko
* @author inpink
* @since 1.0.0
*/
@Testcontainers
Expand Down Expand Up @@ -99,8 +102,11 @@ private ApplicationContextRunner getContextRunner() {
@BeforeEach
void cleanDatabase() {
getContextRunner().run(context -> {
VectorStore vectorStore = context.getBean(VectorStore.class);
VectorStore vectorStore = context.getBean("vectorStore", OpenSearchVectorStore.class);
vectorStore.delete(List.of("_all"));

VectorStore anotherVectorStore = context.getBean("anotherVectorStore", OpenSearchVectorStore.class);
anotherVectorStore.delete(List.of("_all"));
});
}

Expand All @@ -109,7 +115,7 @@ void cleanDatabase() {
public void addAndSearchTest(String similarityFunction) {

getContextRunner().run(context -> {
OpenSearchVectorStore vectorStore = context.getBean(OpenSearchVectorStore.class);
OpenSearchVectorStore vectorStore = context.getBean("vectorStore", OpenSearchVectorStore.class);

if (!DEFAULT.equals(similarityFunction)) {
vectorStore.withSimilarityFunction(similarityFunction);
Expand Down Expand Up @@ -148,7 +154,7 @@ public void addAndSearchTest(String similarityFunction) {
public void searchWithFilters(String similarityFunction) {

getContextRunner().run(context -> {
OpenSearchVectorStore vectorStore = context.getBean(OpenSearchVectorStore.class);
OpenSearchVectorStore vectorStore = context.getBean("vectorStore", OpenSearchVectorStore.class);

if (!DEFAULT.equals(similarityFunction)) {
vectorStore.withSimilarityFunction(similarityFunction);
Expand Down Expand Up @@ -246,7 +252,7 @@ public void searchWithFilters(String similarityFunction) {
public void documentUpdateTest(String similarityFunction) {

getContextRunner().run(context -> {
OpenSearchVectorStore vectorStore = context.getBean(OpenSearchVectorStore.class);
OpenSearchVectorStore vectorStore = context.getBean("vectorStore", OpenSearchVectorStore.class);
if (!DEFAULT.equals(similarityFunction)) {
vectorStore.withSimilarityFunction(similarityFunction);
}
Expand Down Expand Up @@ -302,7 +308,7 @@ public void documentUpdateTest(String similarityFunction) {
public void searchThresholdTest(String similarityFunction) {

getContextRunner().run(context -> {
OpenSearchVectorStore vectorStore = context.getBean(OpenSearchVectorStore.class);
OpenSearchVectorStore vectorStore = context.getBean("vectorStore", OpenSearchVectorStore.class);
if (!DEFAULT.equals(similarityFunction)) {
vectorStore.withSimilarityFunction(similarityFunction);
}
Expand Down Expand Up @@ -343,11 +349,41 @@ public void searchThresholdTest(String similarityFunction) {
});
}

@Test
public void searchDocumentsInTwoIndicesTest() {
getContextRunner().run(context -> {
// given
OpenSearchVectorStore vectorStore1 = context.getBean("vectorStore", OpenSearchVectorStore.class);
OpenSearchVectorStore vectorStore2 = context.getBean("anotherVectorStore", OpenSearchVectorStore.class);

Document docInIndex1 = new Document("1", "Document in index 1", Map.of("meta", "index1"));
Document docInIndex2 = new Document("2", "Document in index 2", Map.of("meta", "index2"));

// when
vectorStore1.add(List.of(docInIndex1));
vectorStore2.add(List.of(docInIndex2));

List<Document> resultInIndex1 = vectorStore1
.similaritySearch(SearchRequest.query("Document in index 1").withTopK(1).withSimilarityThreshold(0));

List<Document> resultInIndex2 = vectorStore2
.similaritySearch(SearchRequest.query("Document in index 2").withTopK(1).withSimilarityThreshold(0));

// then
assertThat(resultInIndex1).hasSize(1);
assertThat(resultInIndex1.get(0).getId()).isEqualTo(docInIndex1.getId());

assertThat(resultInIndex2).hasSize(1);
assertThat(resultInIndex2.get(0).getId()).isEqualTo(docInIndex2.getId());
});
}

@SpringBootConfiguration
@EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class })
public static class TestApplication {

@Bean
@Qualifier("vectorStore")
public OpenSearchVectorStore vectorStore(EmbeddingModel embeddingModel) {
try {
return new OpenSearchVectorStore(new OpenSearchClient(ApacheHttpClient5TransportBuilder
Expand All @@ -359,6 +395,22 @@ public OpenSearchVectorStore vectorStore(EmbeddingModel embeddingModel) {
}
}

@Bean
@Qualifier("anotherVectorStore")
public OpenSearchVectorStore anotherVectorStore(EmbeddingModel embeddingModel) {
try {
return new OpenSearchVectorStore("another_index",
new OpenSearchClient(ApacheHttpClient5TransportBuilder
.builder(HttpHost.create(opensearchContainer.getHttpHostAddress()))
.build()),
embeddingModel, OpenSearchVectorStore.DEFAULT_MAPPING_EMBEDDING_TYPE_KNN_VECTOR_DIMENSION_1536,
true);
}
catch (URISyntaxException e) {
throw new RuntimeException(e);
}
}

@Bean
public EmbeddingModel embeddingModel() {
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
Expand Down