Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.test.vectorstore;

import org.junit.jupiter.api.Test;

import static org.assertj.core.api.Assertions.assertThat;
import static org.awaitility.Awaitility.await;

import java.time.Duration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.filter.Filter;

/**
* Base test class for VectorStore implementations. Provides common test scenarios for
* delete operations.
*
* @author Soby Chacko
*/
public abstract class BaseVectorStoreTests {

/**
* Execute a test function with a configured VectorStore instance. This method is
* responsible for providing a properly initialized VectorStore within the appropriate
* Spring application context for testing.
* @param testFunction the consumer that executes test operations on the VectorStore
*/
protected abstract void executeTest(Consumer<VectorStore> testFunction);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A javadoc here on what to expect for the implementations would be helpful


protected Document createDocument(String country, Integer year) {
Map<String, Object> metadata = new HashMap<>();
metadata.put("country", country);
if (year != null) {
metadata.put("year", year);
}
return new Document("The World is Big and Salvation Lurks Around the Corner", metadata);
}

protected List<Document> setupTestDocuments(VectorStore vectorStore) {
var doc1 = createDocument("BG", 2020);
var doc2 = createDocument("NL", null);
var doc3 = createDocument("BG", 2023);

List<Document> documents = List.of(doc1, doc2, doc3);
vectorStore.add(documents);

return documents;
}

private String normalizeValue(Object value) {
return value.toString().replaceAll("^\"|\"$", "").trim();
}

private void verifyDocumentsExist(VectorStore vectorStore, List<Document> documents) {
await().atMost(5, TimeUnit.SECONDS).pollInterval(Duration.ofMillis(500)).untilAsserted(() -> {
List<Document> results = vectorStore.similaritySearch(
SearchRequest.builder().query("The World").topK(documents.size()).similarityThresholdAll().build());
assertThat(results).hasSize(documents.size());
});
}

private void verifyDocumentsDeleted(VectorStore vectorStore, List<String> deletedIds) {
await().atMost(5, TimeUnit.SECONDS).pollInterval(Duration.ofMillis(500)).untilAsserted(() -> {
List<Document> results = vectorStore
.similaritySearch(SearchRequest.builder().query("The World").topK(10).similarityThresholdAll().build());

List<String> foundIds = results.stream().map(Document::getId).collect(Collectors.toList());

assertThat(foundIds).doesNotContainAnyElementsOf(deletedIds);
});
}

@Test
protected void deleteById() {
executeTest(vectorStore -> {
List<Document> documents = setupTestDocuments(vectorStore);
verifyDocumentsExist(vectorStore, documents);

List<String> idsToDelete = List.of(documents.get(0).getId(), documents.get(1).getId());
vectorStore.delete(idsToDelete);
verifyDocumentsDeleted(vectorStore, idsToDelete);

List<Document> results = vectorStore
.similaritySearch(SearchRequest.builder().query("The World").topK(5).similarityThresholdAll().build());

assertThat(results).hasSize(1);
assertThat(results.get(0).getId()).isEqualTo(documents.get(2).getId());
Map<String, Object> metadata = results.get(0).getMetadata();
assertThat(normalizeValue(metadata.get("country"))).isEqualTo("BG");
assertThat(normalizeValue(metadata.get("year"))).isEqualTo("2023");

vectorStore.delete(List.of(documents.get(2).getId()));
});
}

@Test
protected void deleteWithStringFilterExpression() {
executeTest(vectorStore -> {
List<Document> documents = setupTestDocuments(vectorStore);
verifyDocumentsExist(vectorStore, documents);

List<String> bgDocIds = documents.stream()
.filter(d -> "BG".equals(d.getMetadata().get("country")))
.map(Document::getId)
.collect(Collectors.toList());

vectorStore.delete("country == 'BG'");
verifyDocumentsDeleted(vectorStore, bgDocIds);

List<Document> results = vectorStore
.similaritySearch(SearchRequest.builder().query("The World").topK(5).similarityThresholdAll().build());

assertThat(results).hasSize(1);
assertThat(normalizeValue(results.get(0).getMetadata().get("country"))).isEqualTo("NL");

vectorStore.delete(List.of(documents.get(1).getId()));
});
}

@Test
protected void deleteByFilter() {
executeTest(vectorStore -> {
List<Document> documents = setupTestDocuments(vectorStore);
verifyDocumentsExist(vectorStore, documents);

List<String> bgDocIds = documents.stream()
.filter(d -> "BG".equals(d.getMetadata().get("country")))
.map(Document::getId)
.collect(Collectors.toList());

Filter.Expression filterExpression = new Filter.Expression(Filter.ExpressionType.EQ,
new Filter.Key("country"), new Filter.Value("BG"));

vectorStore.delete(filterExpression);
verifyDocumentsDeleted(vectorStore, bgDocIds);

List<Document> results = vectorStore
.similaritySearch(SearchRequest.builder().query("The World").topK(5).similarityThresholdAll().build());

assertThat(results).hasSize(1);
assertThat(normalizeValue(results.get(0).getMetadata().get("country"))).isEqualTo("NL");

vectorStore.delete(List.of(documents.get(1).getId()));
});
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import com.datastax.oss.driver.api.core.CqlSession;
Expand All @@ -40,8 +42,10 @@
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.test.vectorstore.BaseVectorStoreTests;
import org.springframework.ai.transformers.TransformersEmbeddingModel;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.cassandra.CassandraVectorStore.SchemaColumn;
import org.springframework.ai.vectorstore.cassandra.CassandraVectorStore.SchemaColumnTags;
import org.springframework.ai.vectorstore.filter.Filter;
Expand All @@ -64,7 +68,7 @@
* @since 1.0.0
*/
@Testcontainers
class CassandraVectorStoreIT {
class CassandraVectorStoreIT extends BaseVectorStoreTests {

@Container
static CassandraContainer<?> cassandraContainer = new CassandraContainer<>(CassandraImage.DEFAULT_IMAGE);
Expand Down Expand Up @@ -110,6 +114,24 @@ private static CassandraVectorStore createTestStore(ApplicationContext context,
return store;
}

@Override
protected void executeTest(Consumer<VectorStore> testFunction) {
contextRunner.run(context -> {
VectorStore vectorStore = context.getBean(VectorStore.class);
testFunction.accept(vectorStore);
});
}

@Override
protected Document createDocument(String country, Integer year) {
Map<String, Object> metadata = new HashMap<>();
metadata.put("country", country);
if (year != null) {
metadata.put("year", year.shortValue());
}
return new Document("The World is Big and Salvation Lurks Around the Corner", metadata);
}

@Test
void ensureBeanGetsCreated() {
this.contextRunner.run(context -> {
Expand Down Expand Up @@ -422,7 +444,7 @@ void searchWithThreshold() {
}

@Test
void deleteByFilter() {
protected void deleteByFilter() {
this.contextRunner.run(context -> {
try (CassandraVectorStore store = createTestStore(context,
new SchemaColumn("country", DataTypes.TEXT, SchemaColumnTags.INDEXED),
Expand Down Expand Up @@ -458,7 +480,7 @@ void deleteByFilter() {
}

@Test
void deleteWithStringFilterExpression() {
protected void deleteWithStringFilterExpression() {
this.contextRunner.run(context -> {
try (CassandraVectorStore store = createTestStore(context,
new SchemaColumn("country", DataTypes.TEXT, SchemaColumnTags.INDEXED),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.function.Consumer;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
Expand All @@ -33,6 +34,7 @@
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.openai.OpenAiEmbeddingModel;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.test.vectorstore.BaseVectorStoreTests;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.filter.Filter;
Expand All @@ -51,7 +53,7 @@
*/
@Testcontainers
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
public class ChromaVectorStoreIT {
public class ChromaVectorStoreIT extends BaseVectorStoreTests {

@Container
static ChromaDBContainer chromaContainer = new ChromaDBContainer(ChromaImage.DEFAULT_IMAGE);
Expand All @@ -68,6 +70,14 @@ public class ChromaVectorStoreIT {
"Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression",
Collections.singletonMap("meta2", "meta2")));

@Override
protected void executeTest(Consumer<VectorStore> testFunction) {
contextRunner.run(context -> {
VectorStore vectorStore = context.getBean(VectorStore.class);
testFunction.accept(vectorStore);
});
}

@Test
public void addAndSearch() {
this.contextRunner.run(context -> {
Expand Down Expand Up @@ -168,69 +178,6 @@ public void addAndSearchWithFilters() {
});
}

@Test
public void deleteWithFilterExpression() {
this.contextRunner.run(context -> {
VectorStore vectorStore = context.getBean(VectorStore.class);

// Create test documents with different metadata
var bgDocument = new Document("The World is Big and Salvation Lurks Around the Corner",
Map.of("country", "Bulgaria"));
var nlDocument = new Document("The World is Big and Salvation Lurks Around the Corner",
Map.of("country", "Netherlands"));

// Add documents to the store
vectorStore.add(List.of(bgDocument, nlDocument));

// Verify initial state
var request = SearchRequest.builder().query("The World").topK(5).build();
List<Document> results = vectorStore.similaritySearch(request);
assertThat(results).hasSize(2);

// Delete document with country = Bulgaria
Filter.Expression filterExpression = new Filter.Expression(Filter.ExpressionType.EQ,
new Filter.Key("country"), new Filter.Value("Bulgaria"));

vectorStore.delete(filterExpression);

// Verify Bulgaria document was deleted
results = vectorStore
.similaritySearch(SearchRequest.from(request).filterExpression("country == 'Bulgaria'").build());
assertThat(results).isEmpty();

// Verify Netherlands document still exists
results = vectorStore
.similaritySearch(SearchRequest.from(request).filterExpression("country == 'Netherlands'").build());
assertThat(results).hasSize(1);
assertThat(results.get(0).getMetadata().get("country")).isEqualTo("Netherlands");

// Clean up
vectorStore.delete(List.of(nlDocument.getId()));
});
}

@Test
public void deleteWithStringFilterExpression() {
this.contextRunner.run(context -> {
VectorStore vectorStore = context.getBean(VectorStore.class);

var bgDocument = new Document("The World is Big", Map.of("country", "Bulgaria"));
var nlDocument = new Document("The World is Big", Map.of("country", "Netherlands"));
vectorStore.add(List.of(bgDocument, nlDocument));

var request = SearchRequest.builder().query("World").topK(5).build();
assertThat(vectorStore.similaritySearch(request)).hasSize(2);

vectorStore.delete("country == 'Bulgaria'");

var results = vectorStore.similaritySearch(request);
assertThat(results).hasSize(1);
assertThat(results.get(0).getMetadata().get("country")).isEqualTo("Netherlands");

vectorStore.delete(List.of(nlDocument.getId()));
});
}

@Test
public void documentUpdateTest() {

Expand Down
Loading