From d262c4cd52db08f4d03e1b26eb6fb60039c1283b Mon Sep 17 00:00:00 2001 From: Thomas Vitale Date: Tue, 26 Nov 2024 08:03:40 +0100 Subject: [PATCH 1/2] Support similarity scores in Document API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Document * Introduced “score” attribute in Document API. It stores the similarity score. * Consolidate “distance” metadata for Documents. It stores the distance measurement. * Adopted prefix-less naming convention in Document.Builder and deprecated old methods. * Deprecated the many overloaded Document constructors in favour of Document.Builder. Vector Stores * Every vector store implementation now configures a “score” attribute with the similarity score of the Document embedding. It also includes the “distance” metadata with the distance measurement. * Fixed error in Elasticsearch where distance and similarity were mixed up. * Added missing integration tests for SimpleVectorStore. * The Azure Vector Store and HanaDB Vector Store do not include those measurements because the product documentation do not include information about how the similarity score is returned, and without access to the cloud products I could not verify that via debugging. * Improved tests to actually assert the result of the similarity search based on the returned score. Signed-off-by: Thomas Vitale --- .../springframework/ai/document/Document.java | 192 ++++++++++++------ .../ai/document/DocumentMetadata.java | 50 +++++ .../ai/document/package-info.java | 22 ++ .../ai/vectorstore/SearchRequest.java | 8 +- .../ai/vectorstore/SimpleVectorStore.java | 34 +--- .../vectorstore/SimpleVectorStoreContent.java | 8 + .../ai/vectorstore/package-info.java | 22 ++ .../ai/document/DocumentBuilderTests.java | 84 ++++---- .../SimpleVectorStoreSimilarityTests.java | 4 +- spring-ai-integration-tests/pom.xml | 8 +- .../ai/integration/tests/TestApplication.java | 8 + .../vectorstore/SimpleVectorStoreIT.java | 122 +++++++++++ .../PineconeVectorStoreProperties.java | 4 +- .../PineconeVectorStorePropertiesTests.java | 4 +- .../ai/vectorstore/CosmosDBVectorStore.java | 3 +- .../ai/vectorstore/CosmosDBVectorStoreIT.java | 1 + .../ai/vectorstore/CosmosDbImage.java | 35 ++++ .../vectorstore/azure/AzureVectorStore.java | 17 +- .../vectorstore/azure/AzureVectorStoreIT.java | 21 +- .../ai/vectorstore/CassandraVectorStore.java | 12 +- .../CassandraRichSchemaVectorStoreIT.java | 24 +-- .../vectorstore/CassandraVectorStoreIT.java | 22 +- .../ai/vectorstore/ChromaVectorStore.java | 18 +- .../ai/vectorstore/ChromaVectorStoreIT.java | 18 +- .../ChromaVectorStoreObservationIT.java | 2 - .../ai/vectorstore/CoherenceVectorStore.java | 11 +- .../vectorstore/CoherenceVectorStoreIT.java | 21 +- .../vectorstore/ElasticsearchVectorStore.java | 10 +- .../ElasticsearchVectorStoreIT.java | 18 +- .../ai/vectorstore/GemFireVectorStore.java | 7 +- .../ai/vectorstore/GemFireVectorStoreIT.java | 20 +- .../ai/vectorstore/MilvusVectorStore.java | 12 +- .../ai/vectorstore/MilvusVectorStoreIT.java | 23 +-- .../vectorstore/MongoDBAtlasVectorStore.java | 16 +- .../MongoDBAtlasVectorStoreIT.java | 53 +++++ .../ai/vectorstore/Neo4jVectorStore.java | 11 +- .../ai/vectorstore/Neo4jVectorStoreIT.java | 21 +- .../ai/vectorstore/OpenSearchVectorStore.java | 6 +- .../vectorstore/OpenSearchVectorStoreIT.java | 18 +- .../ai/vectorstore/OracleVectorStore.java | 14 +- .../ai/vectorstore/OracleVectorStoreIT.java | 30 ++- .../ai/vectorstore/PgVectorStore.java | 16 +- .../ai/vectorstore/PgVectorStoreIT.java | 32 +-- .../ai/vectorstore/PineconeVectorStore.java | 15 +- .../ai/vectorstore/PineconeVectorStoreIT.java | 21 +- .../vectorstore/qdrant/QdrantVectorStore.java | 17 +- .../qdrant/QdrantVectorStoreIT.java | 18 +- .../ai/vectorstore/RedisVectorStore.java | 14 +- .../ai/vectorstore/RedisVectorStoreIT.java | 26 ++- .../ai/vectorstore/TypesenseVectorStore.java | 11 +- .../vectorstore/TypesenseVectorStoreIT.java | 18 +- .../ai/vectorstore/WeaviateVectorStore.java | 17 +- .../ai/vectorstore/WeaviateVectorStoreIT.java | 22 +- 53 files changed, 866 insertions(+), 395 deletions(-) create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/document/DocumentMetadata.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/document/package-info.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/vectorstore/package-info.java create mode 100644 spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/vectorstore/SimpleVectorStoreIT.java create mode 100644 vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/CosmosDbImage.java diff --git a/spring-ai-core/src/main/java/org/springframework/ai/document/Document.java b/spring-ai-core/src/main/java/org/springframework/ai/document/Document.java index 4d19c87ba5e..45c2a5602fc 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/document/Document.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/document/Document.java @@ -21,6 +21,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonIgnore; @@ -31,6 +32,7 @@ import org.springframework.ai.document.id.RandomIdGenerator; import org.springframework.ai.model.Media; import org.springframework.ai.model.MediaContent; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.StringUtils; @@ -61,7 +63,15 @@ public class Document implements MediaContent { * Metadata for the document. It should not be nested and values should be restricted * to string, int, float, boolean for simple use with Vector Dbs. */ - private Map metadata; + private final Map metadata; + + /** + * Measure of similarity between the document embedding and the query vector. The + * higher the score, the more they are similar. It's the opposite of the distance + * measure. + */ + @Nullable + private Double score; /** * Embedding of the document. Note: ephemeral field. @@ -80,31 +90,61 @@ public Document(@JsonProperty("content") String content) { this(content, new HashMap<>()); } + /** + * @deprecated Use builder instead: {@link Document#builder()}. + */ + @Deprecated(since = "1.0.0-M5", forRemoval = true) public Document(String content, Map metadata) { this(content, metadata, new RandomIdGenerator()); } + /** + * @deprecated Use builder instead: {@link Document#builder()}. + */ + @Deprecated(since = "1.0.0-M5", forRemoval = true) public Document(String content, Collection media, Map metadata) { this(new RandomIdGenerator().generateId(content, metadata), content, media, metadata); } + /** + * @deprecated Use builder instead: {@link Document#builder()}. + */ + @Deprecated(since = "1.0.0-M5", forRemoval = true) public Document(String content, Map metadata, IdGenerator idGenerator) { this(idGenerator.generateId(content, metadata), content, metadata); } + /** + * @deprecated Use builder instead: {@link Document#builder()}. + */ + @Deprecated(since = "1.0.0-M5", forRemoval = true) public Document(String id, String content, Map metadata) { this(id, content, List.of(), metadata); } + /** + * @deprecated Use builder instead: {@link Document#builder()}. + */ + @Deprecated(since = "1.0.0-M5", forRemoval = true) public Document(String id, String content, Collection media, Map metadata) { - Assert.hasText(id, "id must not be null or empty"); - Assert.notNull(content, "content must not be null"); - Assert.notNull(metadata, "metadata must not be null"); + this(id, content, media, metadata, null); + } + + public Document(String id, String content, @Nullable Collection media, + @Nullable Map metadata, @Nullable Double score) { + Assert.hasText(id, "id cannot be null or empty"); + Assert.notNull(content, "content cannot be null"); + Assert.notNull(media, "media cannot be null"); + Assert.noNullElements(media, "media cannot have null elements"); + Assert.notNull(metadata, "metadata cannot be null"); + Assert.noNullElements(metadata.keySet(), "metadata cannot have null keys"); + Assert.noNullElements(metadata.values(), "metadata cannot have null values"); this.id = id; this.content = content; - this.media = media; - this.metadata = metadata; + this.media = media != null ? media : List.of(); + this.metadata = metadata != null ? metadata : new HashMap<>(); + this.score = score; } public static Builder builder() { @@ -149,6 +189,15 @@ public Map getMetadata() { return this.metadata; } + @Nullable + public Double getScore() { + return this.score; + } + + public void setScore(@Nullable Double score) { + this.score = score; + } + /** * Return the embedding that were calculated. * @deprecated We are considering getting rid of this, please comment on @@ -186,57 +235,24 @@ public void setContentFormatter(ContentFormatter contentFormatter) { @Override public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + ((this.id == null) ? 0 : this.id.hashCode()); - result = prime * result + ((this.metadata == null) ? 0 : this.metadata.hashCode()); - result = prime * result + ((this.content == null) ? 0 : this.content.hashCode()); - return result; + return Objects.hash(id, content, media, metadata); } @Override - public boolean equals(Object obj) { - if (this == obj) { + public boolean equals(Object o) { + if (this == o) return true; - } - if (obj == null) { - return false; - } - if (getClass() != obj.getClass()) { - return false; - } - Document other = (Document) obj; - if (this.id == null) { - if (other.id != null) { - return false; - } - } - else if (!this.id.equals(other.id)) { - return false; - } - if (this.metadata == null) { - if (other.metadata != null) { - return false; - } - } - else if (!this.metadata.equals(other.metadata)) { + if (o == null || getClass() != o.getClass()) return false; - } - if (this.content == null) { - if (other.content != null) { - return false; - } - } - else if (!this.content.equals(other.content)) { - return false; - } - return true; + Document document = (Document) o; + return Objects.equals(id, document.id) && Objects.equals(content, document.content) + && Objects.equals(media, document.media) && Objects.equals(metadata, document.metadata); } @Override public String toString() { - return "Document{" + "id='" + this.id + '\'' + ", metadata=" + this.metadata + ", content='" + this.content - + '\'' + ", media=" + this.media + '}'; + return "Document{" + "id='" + id + '\'' + ", content='" + content + '\'' + ", media=" + media + ", metadata=" + + metadata + ", score=" + score + '}'; } public static class Builder { @@ -249,56 +265,102 @@ public static class Builder { private Map metadata = new HashMap<>(); + private float[] embedding = new float[0]; + + private Double score; + private IdGenerator idGenerator = new RandomIdGenerator(); - public Builder withIdGenerator(IdGenerator idGenerator) { - Assert.notNull(idGenerator, "idGenerator must not be null"); + public Builder idGenerator(IdGenerator idGenerator) { + Assert.notNull(idGenerator, "idGenerator cannot be null"); this.idGenerator = idGenerator; return this; } - public Builder withId(String id) { - Assert.hasText(id, "id must not be null or empty"); + public Builder id(String id) { + Assert.hasText(id, "id cannot be null or empty"); this.id = id; return this; } - public Builder withContent(String content) { - Assert.notNull(content, "content must not be null"); + public Builder content(String content) { this.content = content; return this; } - public Builder withMedia(List media) { - Assert.notNull(media, "media must not be null"); + public Builder media(List media) { this.media = media; return this; } - public Builder withMedia(Media media) { - Assert.notNull(media, "media must not be null"); - this.media.add(media); + public Builder media(Media... media) { + Assert.noNullElements(media, "media cannot contain null elements"); + this.media.addAll(List.of(media)); return this; } - public Builder withMetadata(Map metadata) { - Assert.notNull(metadata, "metadata must not be null"); + public Builder metadata(Map metadata) { this.metadata = metadata; return this; } - public Builder withMetadata(String key, Object value) { - Assert.notNull(key, "key must not be null"); - Assert.notNull(value, "value must not be null"); + public Builder metadata(String key, Object value) { this.metadata.put(key, value); return this; } + public Builder embedding(float[] embedding) { + this.embedding = embedding; + return this; + } + + public Builder score(Double score) { + this.score = score; + return this; + } + + @Deprecated(since = "1.0.0-M5", forRemoval = true) + public Builder withIdGenerator(IdGenerator idGenerator) { + return idGenerator(idGenerator); + } + + @Deprecated(since = "1.0.0-M5", forRemoval = true) + public Builder withId(String id) { + return id(id); + } + + @Deprecated(since = "1.0.0-M5", forRemoval = true) + public Builder withContent(String content) { + return content(content); + } + + @Deprecated(since = "1.0.0-M5", forRemoval = true) + public Builder withMedia(List media) { + return media(media); + } + + @Deprecated(since = "1.0.0-M5", forRemoval = true) + public Builder withMedia(Media media) { + return media(media); + } + + @Deprecated(since = "1.0.0-M5", forRemoval = true) + public Builder withMetadata(Map metadata) { + return metadata(metadata); + } + + @Deprecated(since = "1.0.0-M5", forRemoval = true) + public Builder withMetadata(String key, Object value) { + return metadata(key, value); + } + public Document build() { if (!StringUtils.hasText(this.id)) { this.id = this.idGenerator.generateId(this.content, this.metadata); } - return new Document(this.id, this.content, this.media, this.metadata); + var document = new Document(this.id, this.content, this.media, this.metadata, this.score); + document.setEmbedding(this.embedding); + return document; } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentMetadata.java new file mode 100644 index 00000000000..0d2ee895554 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentMetadata.java @@ -0,0 +1,50 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.document; + +import org.springframework.ai.vectorstore.VectorStore; + +/** + * Common set of metadata keys used in {@link Document}s by {@link DocumentReader}s and + * {@link VectorStore}s. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public enum DocumentMetadata { + +// @formatter:off + + /** + * Measure of distance between the document embedding and the query vector. + * The lower the distance, the more they are similar. + * It's the opposite of the similarity score. + */ + DISTANCE("distance"); + + private final String value; + + DocumentMetadata(String value) { + this.value = value; + } + public String value() { + return this.value; + } + +// @formatter:on + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/document/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/document/package-info.java new file mode 100644 index 00000000000..fdd93626479 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/document/package-info.java @@ -0,0 +1,22 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@NonNullApi +@NonNullFields +package org.springframework.ai.document; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SearchRequest.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SearchRequest.java index a009e56f78b..fcf72b0e639 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SearchRequest.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SearchRequest.java @@ -22,6 +22,7 @@ import org.springframework.ai.vectorstore.filter.Filter; import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder; import org.springframework.ai.vectorstore.filter.FilterExpressionTextParser; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; /** @@ -30,6 +31,7 @@ * instance and then apply the 'with' methods to alter the default values. * * @author Christian Tzolov + * @author Thomas Vitale */ public final class SearchRequest { @@ -51,6 +53,7 @@ public final class SearchRequest { private double similarityThreshold = SIMILARITY_THRESHOLD_ACCEPT_ALL; + @Nullable private Filter.Expression filterExpression; private SearchRequest(String query) { @@ -186,7 +189,7 @@ public SearchRequest withSimilarityThresholdAll() { * filter criteria. The 'null' value stands for no expression filters. * @return this builder. */ - public SearchRequest withFilterExpression(Filter.Expression expression) { + public SearchRequest withFilterExpression(@Nullable Filter.Expression expression) { this.filterExpression = expression; return this; } @@ -225,7 +228,7 @@ public SearchRequest withFilterExpression(Filter.Expression expression) { * 'null' value stands for no expression filters. * @return this.builder */ - public SearchRequest withFilterExpression(String textExpression) { + public SearchRequest withFilterExpression(@Nullable String textExpression) { this.filterExpression = (textExpression != null) ? new FilterExpressionTextParser().parse(textExpression) : null; return this; @@ -243,6 +246,7 @@ public double getSimilarityThreshold() { return this.similarityThreshold; } + @Nullable public Filter.Expression getFilterExpression() { return this.filterExpression; } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java index c9414392086..28cbfc2eb5c 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java @@ -43,6 +43,7 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; @@ -68,6 +69,7 @@ * @author Christian Tzolov * @author Sebastien Deleuze * @author Ilayaperumal Gopinathan + * @author Thomas Vitale */ public class SimpleVectorStore extends AbstractObservationVectorStore { @@ -127,12 +129,11 @@ public List doSimilaritySearch(SearchRequest request) { float[] userQueryEmbedding = getUserQueryEmbedding(request.getQuery()); return this.store.values() .stream() - .map(entry -> new Similarity(entry, - EmbeddingMath.cosineSimilarity(userQueryEmbedding, entry.getEmbedding()))) - .filter(s -> s.score >= request.getSimilarityThreshold()) - .sorted(Comparator.comparingDouble(s -> s.score).reversed()) + .map(content -> content + .toDocument(EmbeddingMath.cosineSimilarity(userQueryEmbedding, content.getEmbedding()))) + .filter(document -> document.getScore() >= request.getSimilarityThreshold()) + .sorted(Comparator.comparing(Document::getScore).reversed()) .limit(request.getTopK()) - .map(s -> s.getDocument()) .toList(); } @@ -235,28 +236,7 @@ public VectorStoreObservationContext.Builder createObservationContextBuilder(Str .withSimilarityMetric(VectorStoreSimilarityMetric.COSINE.value()); } - public static class Similarity { - - private SimpleVectorStoreContent content; - - private double score; - - public Similarity(SimpleVectorStoreContent content, double score) { - this.content = content; - this.score = score; - } - - Document getDocument() { - return Document.builder() - .withId(this.content.getId()) - .withContent(this.content.getContent()) - .withMetadata(this.content.getMetadata()) - .build(); - } - - } - - public final class EmbeddingMath { + public static final class EmbeddingMath { private EmbeddingMath() { throw new UnsupportedOperationException("This is a utility class and cannot be instantiated"); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStoreContent.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStoreContent.java index 53058624eed..2d1590f681d 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStoreContent.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStoreContent.java @@ -25,6 +25,8 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.document.id.IdGenerator; import org.springframework.ai.document.id.RandomIdGenerator; import org.springframework.ai.model.Content; @@ -135,6 +137,12 @@ public float[] getEmbedding() { return Arrays.copyOf(this.embedding, this.embedding.length); } + public Document toDocument(Double score) { + var metadata = new HashMap<>(this.metadata); + metadata.put(DocumentMetadata.DISTANCE.value(), 1.0 - score); + return Document.builder().id(this.id).content(this.content).metadata(metadata).score(score).build(); + } + @Override public boolean equals(Object o) { if (this == o) { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/package-info.java new file mode 100644 index 00000000000..3edee23fc81 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/package-info.java @@ -0,0 +1,22 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@NonNullApi +@NonNullFields +package org.springframework.ai.vectorstore; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/test/java/org/springframework/ai/document/DocumentBuilderTests.java b/spring-ai-core/src/test/java/org/springframework/ai/document/DocumentBuilderTests.java index ebaeef38905..2408aebf8b9 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/document/DocumentBuilderTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/document/DocumentBuilderTests.java @@ -42,13 +42,11 @@ private static List getMediaList() { URL mediaUrl2 = new URL("http://type2"); Media media1 = new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl1); Media media2 = new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl2); - List mediaList = List.of(media1, media2); - return mediaList; + return List.of(media1, media2); } catch (MalformedURLException e) { throw new RuntimeException(e); } - } @BeforeEach @@ -58,32 +56,26 @@ void setUp() { @Test void testWithIdGenerator() { - IdGenerator mockGenerator = new IdGenerator() { - - @Override - public String generateId(Object... contents) { - return "mockedId"; - } - }; + IdGenerator mockGenerator = contents -> "mockedId"; - Document.Builder result = this.builder.withIdGenerator(mockGenerator); + Document.Builder result = this.builder.idGenerator(mockGenerator); assertThat(result).isSameAs(this.builder); - Document document = result.withContent("Test content").withMetadata("key", "value").build(); + Document document = result.content("Test content").metadata("key", "value").build(); assertThat(document.getId()).isEqualTo("mockedId"); } @Test void testWithIdGeneratorNull() { - assertThatThrownBy(() -> this.builder.withIdGenerator(null)).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("idGenerator must not be null"); + assertThatThrownBy(() -> this.builder.idGenerator(null).build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("idGenerator cannot be null"); } @Test void testWithId() { - Document.Builder result = this.builder.withId("testId"); + Document.Builder result = this.builder.id("testId"); assertThat(result).isSameAs(this.builder); assertThat(result.build().getId()).isEqualTo("testId"); @@ -91,16 +83,16 @@ void testWithId() { @Test void testWithIdNullOrEmpty() { - assertThatThrownBy(() -> this.builder.withId(null)).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("id must not be null or empty"); + assertThatThrownBy(() -> this.builder.id(null).build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("id cannot be null or empty"); - assertThatThrownBy(() -> this.builder.withId("")).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("id must not be null or empty"); + assertThatThrownBy(() -> this.builder.id("").build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("id cannot be null or empty"); } @Test void testWithContent() { - Document.Builder result = this.builder.withContent("Test content"); + Document.Builder result = this.builder.content("Test content"); assertThat(result).isSameAs(this.builder); assertThat(result.build().getContent()).isEqualTo("Test content"); @@ -108,14 +100,14 @@ void testWithContent() { @Test void testWithContentNull() { - assertThatThrownBy(() -> this.builder.withContent(null)).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("content must not be null"); + assertThatThrownBy(() -> this.builder.content(null).build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("content cannot be null"); } @Test void testWithMediaList() { List mediaList = getMediaList(); - Document.Builder result = this.builder.withMedia(mediaList); + Document.Builder result = this.builder.media(mediaList); assertThat(result).isSameAs(this.builder); assertThat(result.build().getMedia()).isEqualTo(mediaList); @@ -123,9 +115,9 @@ void testWithMediaList() { @Test void testWithMediaListNull() { - assertThatThrownBy(() -> this.builder.withMedia((List) null)) + assertThatThrownBy(() -> this.builder.media((List) null).build()) .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("media must not be null"); + .hasMessageContaining("media cannot be null"); } @Test @@ -133,7 +125,7 @@ void testWithMediaSingle() throws MalformedURLException { URL mediaUrl = new URL("http://test"); Media media = new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl); - Document.Builder result = this.builder.withMedia(media); + Document.Builder result = this.builder.media(media); assertThat(result).isSameAs(this.builder); assertThat(result.build().getMedia()).contains(media); @@ -141,8 +133,8 @@ void testWithMediaSingle() throws MalformedURLException { @Test void testWithMediaSingleNull() { - assertThatThrownBy(() -> this.builder.withMedia((Media) null)).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("media must not be null"); + assertThatThrownBy(() -> this.builder.media((Media) null).build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("media cannot contain null elements"); } @Test @@ -150,7 +142,7 @@ void testWithMetadataMap() { Map metadata = new HashMap<>(); metadata.put("key1", "value1"); metadata.put("key2", 2); - Document.Builder result = this.builder.withMetadata(metadata); + Document.Builder result = this.builder.metadata(metadata); assertThat(result).isSameAs(this.builder); assertThat(result.build().getMetadata()).isEqualTo(metadata); @@ -158,47 +150,51 @@ void testWithMetadataMap() { @Test void testWithMetadataMapNull() { - assertThatThrownBy(() -> this.builder.withMetadata((Map) null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("metadata must not be null"); + assertThatThrownBy(() -> this.builder.metadata(null).build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("metadata cannot be null"); } @Test void testWithMetadataKeyValue() { - Document.Builder result = this.builder.withMetadata("key", "value"); + Document.Builder result = this.builder.metadata("key", "value"); assertThat(result).isSameAs(this.builder); assertThat(result.build().getMetadata()).containsEntry("key", "value"); } @Test - void testWithMetadataKeyValueNull() { - assertThatThrownBy(() -> this.builder.withMetadata(null, "value")).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("key must not be null"); + void testWithMetadataKeyNull() { + assertThatThrownBy(() -> this.builder.metadata(null, "value").build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("metadata cannot have null keys"); + } - assertThatThrownBy(() -> this.builder.withMetadata("key", null)).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("value must not be null"); + @Test + void testWithMetadataValueNull() { + assertThatThrownBy(() -> this.builder.metadata("key", null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("metadata cannot have null values"); } @Test void testBuildWithoutId() { - Document document = this.builder.withContent("Test content").build(); + Document document = this.builder.content("Test content").build(); assertThat(document.getId()).isNotNull().isNotEmpty(); assertThat(document.getContent()).isEqualTo("Test content"); } @Test - void testBuildWithAllProperties() throws MalformedURLException { + void testBuildWithAllProperties() { List mediaList = getMediaList(); Map metadata = new HashMap<>(); metadata.put("key", "value"); - Document document = this.builder.withId("customId") - .withContent("Test content") - .withMedia(mediaList) - .withMetadata(metadata) + Document document = this.builder.id("customId") + .content("Test content") + .media(mediaList) + .metadata(metadata) .build(); assertThat(document.getId()).isEqualTo("customId"); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/SimpleVectorStoreSimilarityTests.java b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/SimpleVectorStoreSimilarityTests.java index 3447e5f7cce..499d0269af1 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/SimpleVectorStoreSimilarityTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/SimpleVectorStoreSimilarityTests.java @@ -27,6 +27,7 @@ /** * @author Ilayaperumal Gopinathan + * @author Thomas Vitale */ public class SimpleVectorStoreSimilarityTests { @@ -38,8 +39,7 @@ public void testSimilarity() { SimpleVectorStoreContent storeContent = new SimpleVectorStoreContent("1", "hello, how are you?", metadata, testEmbedding); - SimpleVectorStore.Similarity similarity = new SimpleVectorStore.Similarity(storeContent, 0.6d); - Document document = similarity.getDocument(); + Document document = storeContent.toDocument(0.6); assertThat(document).isNotNull(); assertThat(document.getId()).isEqualTo("1"); assertThat(document.getContent()).isEqualTo("hello, how are you?"); diff --git a/spring-ai-integration-tests/pom.xml b/spring-ai-integration-tests/pom.xml index 3ca4c4ace4d..b3cbb4cde7d 100644 --- a/spring-ai-integration-tests/pom.xml +++ b/spring-ai-integration-tests/pom.xml @@ -54,7 +54,6 @@ test - org.springframework.ai spring-ai-openai-spring-boot-starter @@ -76,6 +75,13 @@ test + + org.springframework.ai + spring-ai-test + ${project.parent.version} + test + + org.springframework.boot spring-boot-testcontainers diff --git a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/TestApplication.java b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/TestApplication.java index 7b3f01292b0..df17af1fd44 100644 --- a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/TestApplication.java +++ b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/TestApplication.java @@ -16,7 +16,10 @@ package org.springframework.ai.integration.tests; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.vectorstore.SimpleVectorStore; import org.springframework.boot.autoconfigure.SpringBootApplication; +import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; /** @@ -28,4 +31,9 @@ @Import(TestcontainersConfiguration.class) public class TestApplication { + @Bean + SimpleVectorStore simpleVectorStore(EmbeddingModel embeddingModel) { + return new SimpleVectorStore(embeddingModel); + } + } diff --git a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/vectorstore/SimpleVectorStoreIT.java b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/vectorstore/SimpleVectorStoreIT.java new file mode 100644 index 00000000000..c6bb774db2f --- /dev/null +++ b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/vectorstore/SimpleVectorStoreIT.java @@ -0,0 +1,122 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.integration.tests.vectorstore; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; +import org.springframework.ai.integration.tests.TestApplication; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.SimpleVectorStore; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.core.io.DefaultResourceLoader; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link SimpleVectorStore}. + * + * @author Thomas Vitale + */ +@SpringBootTest(classes = TestApplication.class) +@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*") +public class SimpleVectorStoreIT { + + @Autowired + private SimpleVectorStore vectorStore; + + List documents = List.of( + Document.builder() + .id("471a8c78-549a-4b2c-bce5-ef3ae6579be3") + .content(getText("classpath:/test/data/spring.ai.txt")) + .metadata(Map.of("meta1", "meta1")) + .build(), + Document.builder() + .id("bc51d7f7-627b-4ba6-adf4-f0bcd1998f8f") + .content(getText("classpath:/test/data/time.shelter.txt")) + .metadata(Map.of()) + .build(), + Document.builder() + .id("d0237682-1150-44ff-b4d2-1be9b1731ee5") + .content(getText("classpath:/test/data/great.depression.txt")) + .metadata(Map.of("meta2", "meta2")) + .build()); + + public static String getText(String uri) { + var resource = new DefaultResourceLoader().getResource(uri); + try { + return resource.getContentAsString(StandardCharsets.UTF_8); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + @AfterEach + void setUp() { + vectorStore.delete(this.documents.stream().map(Document::getId).toList()); + } + + @Test + public void searchWithThreshold() { + Document document = Document.builder() + .id(UUID.randomUUID().toString()) + .content("Spring AI rocks!!") + .metadata("meta1", "meta1") + .build(); + + vectorStore.add(List.of(document)); + + List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); + + assertThat(results).hasSize(1); + Document resultDoc = results.get(0); + assertThat(resultDoc.getId()).isEqualTo(document.getId()); + assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); + assertThat(resultDoc.getMetadata()).containsKey("meta1"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); + + Document sameIdDocument = Document.builder() + .id(document.getId()) + .content("The World is Big and Salvation Lurks Around the Corner") + .metadata("meta2", "meta2") + .build(); + + vectorStore.add(List.of(sameIdDocument)); + + results = vectorStore.similaritySearch(SearchRequest.query("FooBar").withTopK(5)); + + assertThat(results).hasSize(1); + resultDoc = results.get(0); + assertThat(resultDoc.getId()).isEqualTo(document.getId()); + assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); + assertThat(resultDoc.getMetadata()).containsKey("meta2"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); + + vectorStore.delete(List.of(document.getId())); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreProperties.java index a40804a6500..bf9f2e9cfd0 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreProperties.java @@ -18,6 +18,7 @@ import java.time.Duration; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.vectorstore.PineconeVectorStore; import org.springframework.boot.context.properties.ConfigurationProperties; @@ -25,6 +26,7 @@ * Configuration properties for Pinecone Vector Store. * * @author Christian Tzolov + * @author Thomas Vitale */ @ConfigurationProperties(PineconeVectorStoreProperties.CONFIG_PREFIX) public class PineconeVectorStoreProperties { @@ -43,7 +45,7 @@ public class PineconeVectorStoreProperties { private String contentFieldName = PineconeVectorStore.CONTENT_FIELD_NAME; - private String distanceMetadataFieldName = PineconeVectorStore.DISTANCE_METADATA_FIELD_NAME; + private String distanceMetadataFieldName = DocumentMetadata.DISTANCE.value(); private Duration serverSideTimeout = Duration.ofSeconds(20); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStorePropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStorePropertiesTests.java index ce450bac029..3083442f87c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStorePropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStorePropertiesTests.java @@ -20,12 +20,14 @@ import org.junit.jupiter.api.Test; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.vectorstore.PineconeVectorStore; import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov + * @author Thomas Vitale */ public class PineconeVectorStorePropertiesTests { @@ -39,7 +41,7 @@ public void defaultValues() { assertThat(props.getIndexName()).isNull(); assertThat(props.getServerSideTimeout()).isEqualTo(Duration.ofSeconds(20)); assertThat(props.getContentFieldName()).isEqualTo(PineconeVectorStore.CONTENT_FIELD_NAME); - assertThat(props.getDistanceMetadataFieldName()).isEqualTo(PineconeVectorStore.DISTANCE_METADATA_FIELD_NAME); + assertThat(props.getDistanceMetadataFieldName()).isEqualTo(DocumentMetadata.DISTANCE.value()); } @Test diff --git a/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBVectorStore.java b/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBVectorStore.java index c0ad68e3600..11dc7006731 100644 --- a/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBVectorStore.java +++ b/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBVectorStore.java @@ -74,6 +74,7 @@ * * @author Theo van Kraay * @author Soby Chacko + * @author Thomas Vitale * @since 1.0.0 */ public class CosmosDBVectorStore extends AbstractObservationVectorStore implements AutoCloseable { @@ -338,7 +339,7 @@ public List doSimilaritySearch(SearchRequest request) { .block(); // Convert JsonNode to Document List docs = documents.stream() - .map(doc -> new Document(doc.get("id").asText(), doc.get("content").asText(), new HashMap<>())) + .map(doc -> Document.builder().id(doc.get("id").asText()).content(doc.get("content").asText()).build()) .collect(Collectors.toList()); return docs != null ? docs : List.of(); diff --git a/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/CosmosDBVectorStoreIT.java b/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/CosmosDBVectorStoreIT.java index 97ac8081c74..fd8b044c0f5 100644 --- a/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/CosmosDBVectorStoreIT.java +++ b/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/CosmosDBVectorStoreIT.java @@ -42,6 +42,7 @@ /** * @author Theo van Kraay + * @author Thomas Vitale * @since 1.0.0 */ @EnabledIfEnvironmentVariable(named = "AZURE_COSMOSDB_ENDPOINT", matches = ".+") diff --git a/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/CosmosDbImage.java b/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/CosmosDbImage.java new file mode 100644 index 00000000000..6bbbdc71cd4 --- /dev/null +++ b/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/CosmosDbImage.java @@ -0,0 +1,35 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vectorstore; + +import org.testcontainers.utility.DockerImageName; + +/** + * @author Thomas Vitale + */ +public final class CosmosDbImage { + + // It must always be "latest" or else Azure locks the image after a while. See: + // https://github.com/Azure/azure-cosmos-db-emulator-docker/issues/60 + public static final DockerImageName DEFAULT_IMAGE = DockerImageName + .parse("mcr.microsoft.com/cosmosdb/linux/azure-cosmos-emulator:latest"); + + private CosmosDbImage() { + + } + +} diff --git a/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java b/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java index 6200738b928..f1893bd20f8 100644 --- a/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java +++ b/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java @@ -47,6 +47,7 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; @@ -96,8 +97,6 @@ public class AzureVectorStore extends AbstractObservationVectorStore implements private static final String METADATA_FIELD_NAME = "metadata"; - private static final String DISTANCE_METADATA_FIELD_NAME = "distance"; - private static final int DEFAULT_TOP_K = 4; private static final Double DEFAULT_SIMILARITY_THRESHOLD = 0.0; @@ -321,13 +320,15 @@ public List doSimilaritySearch(SearchRequest request) { }) : Map.of(); - metadata.put(DISTANCE_METADATA_FIELD_NAME, 1 - (float) result.getScore()); - - final Document doc = new Document(entry.id(), entry.content(), metadata); - doc.setEmbedding(EmbeddingUtils.toPrimitive(entry.embedding())); - - return doc; + metadata.put(DocumentMetadata.DISTANCE.value(), 1.0 - result.getScore()); + return Document.builder() + .id(entry.id()) + .content(entry.content) + .metadata(metadata) + .score(result.getScore()) + .embedding(EmbeddingUtils.toPrimitive(entry.embedding)) + .build(); }) .collect(Collectors.toList()); } diff --git a/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java b/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java index dc87de60fee..50911ad10ce 100644 --- a/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java +++ b/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java @@ -35,6 +35,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.transformers.TransformersEmbeddingModel; import org.springframework.ai.vectorstore.SearchRequest; @@ -52,6 +53,7 @@ /** * @author Christian Tzolov + * @author Thomas Vitale */ @EnabledIfEnvironmentVariable(named = "AZURE_AI_SEARCH_API_KEY", matches = ".+") @EnabledIfEnvironmentVariable(named = "AZURE_AI_SEARCH_ENDPOINT", matches = ".+") @@ -103,7 +105,7 @@ public void addAndSearchTest() { assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); @@ -224,7 +226,7 @@ public void documentUpdateTest() { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); assertThat(resultDoc.getMetadata()).containsKey("meta1"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); Document sameIdDocument = new Document(document.getId(), "The World is Big and Salvation Lurks Around the Corner", @@ -245,7 +247,7 @@ public void documentUpdateTest() { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); // Remove all documents from the store vectorStore.delete(List.of(document.getId())); @@ -271,21 +273,22 @@ public void searchThresholdTest() { List fullResult = vectorStore .similaritySearch(SearchRequest.query("Depression").withTopK(5).withSimilarityThresholdAll()); - List distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList(); + List scores = fullResult.stream().map(Document::getScore).toList(); - assertThat(distances).hasSize(3); + assertThat(scores).hasSize(3); - float threshold = (distances.get(0) + distances.get(1)) / 2; + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2; - List results = vectorStore - .similaritySearch(SearchRequest.query("Depression").withTopK(5).withSimilarityThreshold(1 - threshold)); + List results = vectorStore.similaritySearch( + SearchRequest.query("Depression").withTopK(5).withSimilarityThreshold(similarityThreshold)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); diff --git a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java index 3ad7f5a916d..af105efc035 100644 --- a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java +++ b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java @@ -45,6 +45,7 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; @@ -106,8 +107,6 @@ */ public class CassandraVectorStore extends AbstractObservationVectorStore implements AutoCloseable { - public static final String SIMILARITY_FIELD_NAME = "similarity_score"; - public static final String DRIVER_PROFILE_UPDATES = "spring-ai-updates"; public static final String DRIVER_PROFILE_SEARCH = "spring-ai-search"; @@ -252,14 +251,19 @@ public List doSimilaritySearch(SearchRequest request) { break; } Map docFields = new HashMap<>(); - docFields.put(SIMILARITY_FIELD_NAME, score); + docFields.put(DocumentMetadata.DISTANCE.value(), 1 - score); for (var metadata : this.conf.schema.metadataColumns()) { var value = row.get(metadata.name(), metadata.javaType()); if (null != value) { docFields.put(metadata.name(), value); } } - Document doc = new Document(getDocumentId(row), row.getString(this.conf.schema.content()), docFields); + Document doc = Document.builder() + .id(getDocumentId(row)) + .content(row.getString(this.conf.schema.content())) + .metadata(docFields) + .score((double) score) + .build(); if (this.conf.returnEmbeddings) { doc.setEmbedding(EmbeddingUtils diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraRichSchemaVectorStoreIT.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraRichSchemaVectorStoreIT.java index e3a3ee65feb..4a31c8dbe13 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraRichSchemaVectorStoreIT.java +++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraRichSchemaVectorStoreIT.java @@ -37,6 +37,7 @@ import org.junit.jupiter.api.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.document.DocumentMetadata; import org.testcontainers.containers.CassandraContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -226,8 +227,7 @@ void addAndSearch() { assertThat(resultDoc.getMetadata()).hasSize(3); - assertThat(resultDoc.getMetadata()).containsKeys("id", "revision", - CassandraVectorStore.SIMILARITY_FIELD_NAME); + assertThat(resultDoc.getMetadata()).containsKeys("id", "revision", DocumentMetadata.DISTANCE.value()); // Remove all documents from the createStore store.delete(documents.stream().map(doc -> doc.getId()).toList()); @@ -494,8 +494,7 @@ void documentUpdate() { assertThat(resultDoc.getId()).isNotEqualTo(sameIdDocument.getId()); assertThat(resultDoc.getContent()).doesNotContain(newContent); - assertThat(resultDoc.getMetadata()).containsKeys("id", "revision", - CassandraVectorStore.SIMILARITY_FIELD_NAME); + assertThat(resultDoc.getMetadata()).containsKeys("id", "revision", DocumentMetadata.DISTANCE.value()); } }); } @@ -509,16 +508,15 @@ void searchWithThreshold() { List fullResult = store .similaritySearch(SearchRequest.query(URANUS_ORBIT_QUERY).withTopK(5).withSimilarityThresholdAll()); - List distances = fullResult.stream() - .map(doc -> (Float) doc.getMetadata().get(CassandraVectorStore.SIMILARITY_FIELD_NAME)) - .toList(); + List scores = fullResult.stream().map(Document::getScore).toList(); - assertThat(distances).hasSize(3); + assertThat(scores).hasSize(3); - float threshold = (distances.get(0) + distances.get(1)) / 2; + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2; - List results = store.similaritySearch( - SearchRequest.query(URANUS_ORBIT_QUERY).withTopK(5).withSimilarityThreshold(threshold)); + List results = store.similaritySearch(SearchRequest.query(URANUS_ORBIT_QUERY) + .withTopK(5) + .withSimilarityThreshold(similarityThreshold)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); @@ -526,8 +524,8 @@ void searchWithThreshold() { assertThat(resultDoc.getContent()).contains(URANUS_ORBIT_QUERY); - assertThat(resultDoc.getMetadata()).containsKeys("id", "revision", - CassandraVectorStore.SIMILARITY_FIELD_NAME); + assertThat(resultDoc.getMetadata()).containsKeys("id", "revision", DocumentMetadata.DISTANCE.value()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); } }); } diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreIT.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreIT.java index e17091d5ea0..13f7e7d7265 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreIT.java +++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreIT.java @@ -30,6 +30,7 @@ import com.datastax.oss.driver.api.core.type.DataTypes; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import org.springframework.ai.document.DocumentMetadata; import org.testcontainers.containers.CassandraContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -138,7 +139,7 @@ void addAndSearch() { "Spring AI provides abstractions that serve as the foundation for developing AI applications."); assertThat(resultDoc.getMetadata()).hasSize(2); - assertThat(resultDoc.getMetadata()).containsKeys("meta1", CassandraVectorStore.SIMILARITY_FIELD_NAME); + assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value()); // Remove all documents from the store store.delete(documents().stream().map(doc -> doc.getId()).toList()); @@ -174,7 +175,7 @@ void addAndSearchReturnEmbeddings() { "Spring AI provides abstractions that serve as the foundation for developing AI applications."); assertThat(resultDoc.getMetadata()).hasSize(1); - assertThat(resultDoc.getMetadata()).containsKey(CassandraVectorStore.SIMILARITY_FIELD_NAME); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); // Remove all documents from the store store.delete(documents().stream().map(doc -> doc.getId()).toList()); @@ -359,7 +360,7 @@ void documentUpdate() { resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); - assertThat(resultDoc.getMetadata()).containsKeys("meta2", CassandraVectorStore.SIMILARITY_FIELD_NAME); + assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value()); store.delete(List.of(document.getId())); } @@ -375,16 +376,14 @@ void searchWithThreshold() { List fullResult = store .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll()); - List distances = fullResult.stream() - .map(doc -> (Float) doc.getMetadata().get(CassandraVectorStore.SIMILARITY_FIELD_NAME)) - .toList(); + List scores = fullResult.stream().map(Document::getScore).toList(); - assertThat(distances).hasSize(3); + assertThat(scores).hasSize(3); - float threshold = (distances.get(0) + distances.get(1)) / 2; + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2; - List results = store - .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(threshold)); + List results = store.similaritySearch( + SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(similarityThreshold)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); @@ -393,7 +392,8 @@ void searchWithThreshold() { assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); - assertThat(resultDoc.getMetadata()).containsKeys("meta1", CassandraVectorStore.SIMILARITY_FIELD_NAME); + assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); } }); } diff --git a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java index 49deb147d65..eafe1e37ad4 100644 --- a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java +++ b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java @@ -32,6 +32,7 @@ import org.springframework.ai.chroma.ChromaApi.DeleteEmbeddingsRequest; import org.springframework.ai.chroma.ChromaApi.Embedding; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; @@ -58,11 +59,10 @@ * @author Fu Cheng * @author Sebastien Deleuze * @author Soby Chacko + * @author Thomas Vitale */ public class ChromaVectorStore extends AbstractObservationVectorStore implements InitializingBean { - public static final String DISTANCE_FIELD_NAME = "distance"; - public static final String DEFAULT_COLLECTION_NAME = "SpringAiCollection"; private final EmbeddingModel embeddingModel; @@ -192,9 +192,14 @@ public Optional doDelete(@NonNull List idList) { if (metadata == null) { metadata = new HashMap<>(); } - metadata.put(DISTANCE_FIELD_NAME, distance); - Document document = new Document(id, content, metadata); - document.setEmbedding(chromaEmbedding.embedding()); + metadata.put(DocumentMetadata.DISTANCE.value(), distance); + Document document = Document.builder() + .id(id) + .content(content) + .metadata(metadata) + .embedding(chromaEmbedding.embedding()) + .score(1.0 - distance) + .build(); responseDocuments.add(document); } } @@ -244,8 +249,7 @@ public void afterPropertiesSet() throws Exception { @NonNull String operationName) { return VectorStoreObservationContext.builder(VectorStoreProvider.CHROMA.value(), operationName) .withDimensions(this.embeddingModel.dimensions()) - .withCollectionName(this.collectionName + ":" + this.collectionId) - .withFieldName(this.initializeSchema ? DISTANCE_FIELD_NAME : null); + .withCollectionName(this.collectionName + ":" + this.collectionId); } public static class Builder { diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreIT.java index a2f6f000093..c3082028483 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreIT.java @@ -24,6 +24,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.document.DocumentMetadata; import org.testcontainers.chromadb.ChromaDBContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -81,7 +82,7 @@ public void addAndSearch() { assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).isEqualTo( "Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression"); - assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance"); + assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value()); // Remove all documents from the store assertThat(vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList())) @@ -179,7 +180,7 @@ public void documentUpdateTest() { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); assertThat(resultDoc.getMetadata()).containsKey("meta1"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); Document sameIdDocument = new Document(document.getId(), "The World is Big and Salvation Lurks Around the Corner", @@ -194,7 +195,7 @@ public void documentUpdateTest() { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); // Remove all documents from the store vectorStore.delete(List.of(document.getId())); @@ -213,13 +214,13 @@ public void searchThresholdTest() { var request = SearchRequest.query("Great").withTopK(5); List fullResult = vectorStore.similaritySearch(request.withSimilarityThresholdAll()); - List distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList(); + List scores = fullResult.stream().map(Document::getScore).toList(); - assertThat(distances).hasSize(3); + assertThat(scores).hasSize(3); - float threshold = (distances.get(0) + distances.get(1)) / 2; + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2; - List results = vectorStore.similaritySearch(request.withSimilarityThreshold(1 - threshold)); + List results = vectorStore.similaritySearch(request.withSimilarityThreshold(similarityThreshold)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); @@ -227,7 +228,8 @@ public void searchThresholdTest() { assertThat(resultDoc.getContent()).isEqualTo( "Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreObservationIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreObservationIT.java index cd9e973b543..72fc9238e02 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreObservationIT.java @@ -107,7 +107,6 @@ void observationVectorStoreAddAndQueryOperations() { .hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_COLLECTION_NAME.asString(), "TestCollection:" + vectorStore.getCollectionId()) .doesNotHaveHighCardinalityKeyValueWithKey(HighCardinalityKeyNames.DB_NAMESPACE.asString()) - .hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_VECTOR_FIELD_NAME.asString(), "distance") .doesNotHaveHighCardinalityKeyValueWithKey( HighCardinalityKeyNames.DB_SEARCH_SIMILARITY_METRIC.asString()) .doesNotHaveHighCardinalityKeyValueWithKey(HighCardinalityKeyNames.DB_VECTOR_QUERY_TOP_K.asString()) @@ -141,7 +140,6 @@ void observationVectorStoreAddAndQueryOperations() { .hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_COLLECTION_NAME.asString(), "TestCollection:" + vectorStore.getCollectionId()) .doesNotHaveHighCardinalityKeyValueWithKey(HighCardinalityKeyNames.DB_NAMESPACE.asString()) - .hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_VECTOR_FIELD_NAME.asString(), "distance") .doesNotHaveHighCardinalityKeyValueWithKey( HighCardinalityKeyNames.DB_SEARCH_SIMILARITY_METRIC.asString()) .hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_VECTOR_QUERY_TOP_K.asString(), "1") diff --git a/vector-stores/spring-ai-coherence-store/src/main/java/org/springframework/ai/vectorstore/CoherenceVectorStore.java b/vector-stores/spring-ai-coherence-store/src/main/java/org/springframework/ai/vectorstore/CoherenceVectorStore.java index e9c64e79461..29ee2bfdff2 100644 --- a/vector-stores/spring-ai-coherence-store/src/main/java/org/springframework/ai/vectorstore/CoherenceVectorStore.java +++ b/vector-stores/spring-ai-coherence-store/src/main/java/org/springframework/ai/vectorstore/CoherenceVectorStore.java @@ -37,6 +37,7 @@ import com.tangosol.util.Filter; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.vectorstore.filter.Filter.Expression; import org.springframework.beans.factory.InitializingBean; @@ -62,6 +63,7 @@ * * * @author Aleks Seovic + * @author Thomas Vitale * @since 1.0.0 */ public class CoherenceVectorStore implements VectorStore, InitializingBean { @@ -211,8 +213,13 @@ public List similaritySearch(SearchRequest request) { if (this.distanceType != DistanceType.COSINE || (1 - r.getDistance()) >= request.getSimilarityThreshold()) { DocumentChunk.Id id = r.getKey(); DocumentChunk chunk = r.getValue(); - chunk.metadata().put("distance", r.getDistance()); - documents.add(new Document(id.docId(), chunk.text(), chunk.metadata())); + chunk.metadata().put(DocumentMetadata.DISTANCE.value(), r.getDistance()); + documents.add(Document.builder() + .id(id.docId()) + .content(chunk.text()) + .metadata(chunk.metadata()) + .score(1 - r.getDistance()) + .build()); } } return documents; diff --git a/vector-stores/spring-ai-coherence-store/src/test/java/org/springframework/ai/vectorstore/CoherenceVectorStoreIT.java b/vector-stores/spring-ai-coherence-store/src/test/java/org/springframework/ai/vectorstore/CoherenceVectorStoreIT.java index 8b1a0c52bf6..776d449b518 100644 --- a/vector-stores/spring-ai-coherence-store/src/test/java/org/springframework/ai/vectorstore/CoherenceVectorStoreIT.java +++ b/vector-stores/spring-ai-coherence-store/src/test/java/org/springframework/ai/vectorstore/CoherenceVectorStoreIT.java @@ -47,6 +47,7 @@ import org.junit.jupiter.params.provider.MethodSource; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.transformers.TransformersEmbeddingModel; import org.springframework.ai.vectorstore.filter.FilterExpressionTextParser; @@ -125,7 +126,7 @@ public void addAndSearch(CoherenceVectorStore.DistanceType distanceType, Coheren assertThat(results).hasSize(1); Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); - assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance"); + assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value()); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); @@ -223,7 +224,7 @@ public void documentUpdate() { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); - assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); + assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value()); Document sameIdDocument = new Document(document.getId(), "The World is Big and Salvation Lurks Around the Corner", @@ -236,7 +237,7 @@ public void documentUpdate() { resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); - assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance"); + assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value()); truncateMap(context, ((CoherenceVectorStore) vectorStore).getMapName()); }); @@ -257,18 +258,18 @@ public void searchWithThreshold() { assertThat(isSortedByDistance(fullResult)).isTrue(); - List distances = fullResult.stream() - .map(doc -> (Double) doc.getMetadata().get("distance")) - .toList(); + List scores = fullResult.stream().map(Document::getScore).toList(); - double threshold = 1d - (distances.get(0) + distances.get(1)) / 2f; + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2; - List results = vectorStore - .similaritySearch(SearchRequest.query("Time Shelter").withTopK(5).withSimilarityThreshold(threshold)); + List results = vectorStore.similaritySearch( + SearchRequest.query("Time Shelter").withTopK(5).withSimilarityThreshold(similarityThreshold)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(this.documents.get(1).getId()); + assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); truncateMap(context, ((CoherenceVectorStore) vectorStore).getMapName()); }); @@ -276,7 +277,7 @@ public void searchWithThreshold() { private static boolean isSortedByDistance(final List documents) { final List distances = documents.stream() - .map(doc -> (Double) doc.getMetadata().get("distance")) + .map(doc -> (Double) doc.getMetadata().get(DocumentMetadata.DISTANCE.value())) .toList(); if (CollectionUtils.isEmpty(distances) || distances.size() == 1) { diff --git a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java index 9c058936b44..895c1f924c4 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java @@ -40,6 +40,7 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; @@ -210,20 +211,23 @@ private String getElasticsearchQueryString(Filter.Expression filterExpression) { private Document toDocument(Hit hit) { Document document = hit.source(); - document.getMetadata().put("distance", calculateDistance(hit.score().floatValue())); + if (hit.score() != null) { + document.getMetadata().put(DocumentMetadata.DISTANCE.value(), 1 - normalizeSimilarityScore(hit.score())); + document.setScore(normalizeSimilarityScore(hit.score())); + } return document; } // more info on score/distance calculation // https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html#knn-similarity-search - private float calculateDistance(Float score) { + private double normalizeSimilarityScore(double score) { switch (this.options.getSimilarity()) { case l2_norm: // the returned value of l2_norm is the opposite of the other functions // (closest to zero means more accurate), so to make it consistent // with the other functions the reverse is returned applying a "1-" // to the standard transformation - return (float) (1 - (java.lang.Math.sqrt((1 / score) - 1))); + return (1 - (java.lang.Math.sqrt((1 / score) - 1))); // cosine and dot_product default: return (2 * score) - 1; diff --git a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java index c262c9a4af8..2372a5d9c33 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java @@ -42,6 +42,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import org.springframework.ai.document.DocumentMetadata; import org.testcontainers.elasticsearch.ElasticsearchContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -165,7 +166,7 @@ public void addAndSearchTest(String similarityFunction) { assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(Document::getId).toList()); @@ -299,7 +300,7 @@ public void documentUpdateTest(String similarityFunction) { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); assertThat(resultDoc.getMetadata()).containsKey("meta1"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); Document sameIdDocument = new Document(document.getId(), "The World is Big and Salvation Lurks Around the Corner", Map.of("meta2", "meta2")); @@ -318,7 +319,7 @@ public void documentUpdateTest(String similarityFunction) { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); // Remove all documents from the store vectorStore.delete(List.of(document.getId())); @@ -343,21 +344,22 @@ public void searchThresholdTest(String similarityFunction) { List fullResult = vectorStore.similaritySearch(query); - List distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList(); + List scores = fullResult.stream().map(Document::getScore).toList(); - assertThat(distances).hasSize(3); + assertThat(scores).hasSize(3); - float thresholdResult = (distances.get(0) + distances.get(1)) / 2; + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2; List results = vectorStore.similaritySearch( - SearchRequest.query("Great Depression").withTopK(50).withSimilarityThreshold(thresholdResult)); + SearchRequest.query("Great Depression").withTopK(50).withSimilarityThreshold(similarityThreshold)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(Document::getId).toList()); diff --git a/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java b/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java index 753eb1e7bdb..24e47aedf33 100644 --- a/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java +++ b/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java @@ -30,6 +30,7 @@ import io.micrometer.observation.ObservationRegistry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.document.DocumentMetadata; import reactor.util.annotation.NonNull; import org.springframework.ai.document.Document; @@ -172,8 +173,6 @@ public String[] getFields() { // Query Defaults private static final String QUERY = "/query"; - private static final String DISTANCE_METADATA_FIELD_NAME = "distance"; - /** * Initializes the GemFireVectorStore after properties are set. This method is called * after all bean properties have been set and allows the bean to perform any @@ -271,9 +270,9 @@ public List doSimilaritySearch(SearchRequest request) { metadata = new HashMap<>(); metadata.put(DOCUMENT_FIELD, "--Deleted--"); } - metadata.put(DISTANCE_METADATA_FIELD_NAME, 1 - r.score); + metadata.put(DocumentMetadata.DISTANCE.value(), 1 - r.score); String content = (String) metadata.remove(DOCUMENT_FIELD); - return new Document(r.key, content, metadata); + return Document.builder().id(r.key).content(content).metadata(metadata).score((double) r.score).build(); }) .collectList() .onErrorMap(WebClientException.class, this::handleHttpClientException) diff --git a/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreIT.java b/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreIT.java index fd93afe0582..1430520fe2c 100644 --- a/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreIT.java +++ b/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreIT.java @@ -34,6 +34,7 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.transformers.TransformersEmbeddingModel; import org.springframework.boot.SpringBootConfiguration; @@ -134,7 +135,7 @@ public void addAndSearchTest() { assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939)" + " was an economic shock"); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); }); } @@ -156,7 +157,7 @@ public void documentUpdateTest() { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); assertThat(resultDoc.getMetadata()).containsKey("meta1"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); Document sameIdDocument = new Document(document.getId(), "The World is Big and Salvation Lurks " + "Around the Corner", @@ -171,7 +172,7 @@ public void documentUpdateTest() { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation" + " Lurks Around the Corner"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); }); } @@ -191,12 +192,12 @@ public void searchThresholdTest() { List fullResult = vectorStore .similaritySearch(SearchRequest.query("Depression").withTopK(5).withSimilarityThresholdAll()); - List distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList(); - assertThat(distances).hasSize(3); + List scores = fullResult.stream().map(Document::getScore).toList(); + assertThat(scores).hasSize(3); - float threshold = (distances.get(0) + distances.get(1)) / 2; - List results = vectorStore - .similaritySearch(SearchRequest.query("Depression").withTopK(5).withSimilarityThreshold(1 - threshold)); + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2; + List results = vectorStore.similaritySearch( + SearchRequest.query("Depression").withTopK(5).withSimilarityThreshold(similarityThreshold)); assertThat(results).hasSize(1); @@ -204,7 +205,8 @@ public void searchThresholdTest() { assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression " + "(1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); }); } diff --git a/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusVectorStore.java b/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusVectorStore.java index 9ed8d8b0157..18c5f26ea02 100644 --- a/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusVectorStore.java +++ b/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusVectorStore.java @@ -54,6 +54,7 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; @@ -97,7 +98,7 @@ public class MilvusVectorStore extends AbstractObservationVectorStore implements public static final String EMBEDDING_FIELD_NAME = "embedding"; // Metadata, automatically assigned by Milvus. - public static final String DISTANCE_FIELD_NAME = "distance"; + private static final String DISTANCE_FIELD_NAME = "distance"; private static final Logger logger = LoggerFactory.getLogger(MilvusVectorStore.class); @@ -258,13 +259,18 @@ public List doSimilaritySearch(SearchRequest request) { try { metadata = (JSONObject) rowRecord.get(this.config.metadataFieldName); // inject the distance into the metadata. - metadata.put(DISTANCE_FIELD_NAME, 1 - getResultSimilarity(rowRecord)); + metadata.put(DocumentMetadata.DISTANCE.value(), 1 - getResultSimilarity(rowRecord)); } catch (ParamException e) { // skip the ParamException if metadata doesn't exist for the custom // collection } - return new Document(docId, content, (metadata != null) ? metadata.getInnerMap() : Map.of()); + return Document.builder() + .id(docId) + .content(content) + .metadata((metadata != null) ? metadata.getInnerMap() : Map.of()) + .score((double) getResultSimilarity(rowRecord)) + .build(); }) .toList(); } diff --git a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusVectorStoreIT.java b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusVectorStoreIT.java index 98a88e7b7d1..757b5c1b154 100644 --- a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusVectorStoreIT.java +++ b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusVectorStoreIT.java @@ -30,6 +30,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import org.springframework.ai.document.DocumentMetadata; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; import org.testcontainers.milvus.MilvusContainer; @@ -106,7 +107,7 @@ public void addAndSearch(String metricType) { assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); assertThat(resultDoc.getMetadata()).hasSize(2); - assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); + assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value()); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); @@ -200,7 +201,7 @@ public void documentUpdate(String metricType) { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); assertThat(resultDoc.getMetadata()).containsKey("meta1"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); Document sameIdDocument = new Document(document.getId(), "The World is Big and Salvation Lurks Around the Corner", @@ -215,7 +216,7 @@ public void documentUpdate(String metricType) { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); vectorStore.delete(List.of(document.getId())); @@ -238,24 +239,22 @@ public void searchWithThreshold(String metricType) { List fullResult = vectorStore .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll()); - List distances = fullResult.stream() - .map(doc -> (Float) doc.getMetadata().get("distance")) - .toList(); + List scores = fullResult.stream().map(Document::getScore).toList(); - assertThat(distances).hasSize(3); + assertThat(scores).hasSize(3); - float threshold = (distances.get(0) + distances.get(1)) / 2; + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2; - List results = vectorStore - .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(1 - threshold)); + List results = vectorStore.similaritySearch( + SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(similarityThreshold)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); - assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); - + assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); }); } diff --git a/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStore.java b/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStore.java index 2e40db851c6..3ab60f4cc67 100644 --- a/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStore.java +++ b/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStore.java @@ -26,6 +26,7 @@ import io.micrometer.observation.ObservationRegistry; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; @@ -172,12 +173,17 @@ private org.bson.Document createSearchIndexDefinition() { private Document mapMongoDocument(org.bson.Document mongoDocument, float[] queryEmbedding) { String id = mongoDocument.getString(ID_FIELD_NAME); String content = mongoDocument.getString(CONTENT_FIELD_NAME); + double score = mongoDocument.getDouble(SCORE_FIELD_NAME); Map metadata = mongoDocument.get(METADATA_FIELD_NAME, org.bson.Document.class); - - Document document = new Document(id, content, metadata); - document.setEmbedding(queryEmbedding); - - return document; + metadata.put(DocumentMetadata.DISTANCE.value(), 1 - score); + + return Document.builder() + .id(id) + .content(content) + .metadata(metadata) + .score(score) + .embedding(queryEmbedding) + .build(); } @Override diff --git a/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStoreIT.java b/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStoreIT.java index d908e843648..4eced845ed8 100644 --- a/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStoreIT.java +++ b/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStoreIT.java @@ -16,6 +16,8 @@ package org.springframework.ai.vectorstore; +import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -27,6 +29,8 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.document.DocumentMetadata; +import org.springframework.core.io.DefaultResourceLoader; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; import org.testcontainers.mongodb.MongoDBAtlasLocalContainer; @@ -198,6 +202,55 @@ void searchWithFilters() { }); } + @Test + public void searchWithThreshold() { + this.contextRunner.run(context -> { + VectorStore vectorStore = context.getBean(VectorStore.class); + + var documents = List.of( + new Document("471a8c78-549a-4b2c-bce5-ef3ae6579be3", getText("classpath:/test/data/spring.ai.txt"), + Map.of("meta1", "meta1")), + new Document("bc51d7f7-627b-4ba6-adf4-f0bcd1998f8f", + getText("classpath:/test/data/time.shelter.txt"), Map.of()), + new Document("d0237682-1150-44ff-b4d2-1be9b1731ee5", + getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2"))); + vectorStore.add(documents); + Thread.sleep(5000); // Await a second for the document to be indexed + + List fullResult = vectorStore + .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll()); + assertThat(fullResult).hasSize(3); + + List scores = fullResult.stream().map(Document::getScore).toList(); + + assertThat(scores).hasSize(3); + + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2; + + List results = vectorStore.similaritySearch( + SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(similarityThreshold)); + + assertThat(results).hasSize(1); + Document resultDoc = results.get(0); + assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getContent()).contains( + "Spring AI provides abstractions that serve as the foundation for developing AI applications."); + assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); + + }); + } + + public static String getText(String uri) { + var resource = new DefaultResourceLoader().getResource(uri); + try { + return resource.getContentAsString(StandardCharsets.UTF_8); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + @SpringBootConfiguration @EnableAutoConfiguration public static class TestApplication { diff --git a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/Neo4jVectorStore.java b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/Neo4jVectorStore.java index 0493d893ca0..97a74a175a7 100644 --- a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/Neo4jVectorStore.java +++ b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/Neo4jVectorStore.java @@ -29,6 +29,7 @@ import org.neo4j.driver.Values; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; @@ -224,15 +225,19 @@ private Document recordToDocument(org.neo4j.driver.Record neoRecord) { var node = neoRecord.get("node").asNode(); var score = neoRecord.get("score").asFloat(); var metaData = new HashMap(); - metaData.put("distance", 1 - score); + metaData.put(DocumentMetadata.DISTANCE.value(), 1 - score); node.keys().forEach(key -> { if (key.startsWith("metadata.")) { metaData.put(key.substring(key.indexOf(".") + 1), node.get(key).asObject()); } }); - return new Document(node.get(this.config.idProperty).asString(), node.get("text").asString(), - Map.copyOf(metaData)); + return Document.builder() + .id(node.get(this.config.idProperty).asString()) + .content(node.get("text").asString()) + .metadata(Map.copyOf(metaData)) + .score((double) score) + .build(); } @Override diff --git a/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jVectorStoreIT.java b/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jVectorStoreIT.java index a707a30622a..b5bc62ef9f5 100644 --- a/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jVectorStoreIT.java +++ b/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jVectorStoreIT.java @@ -28,6 +28,7 @@ import org.neo4j.driver.AuthTokens; import org.neo4j.driver.Driver; import org.neo4j.driver.GraphDatabase; +import org.springframework.ai.document.DocumentMetadata; import org.testcontainers.containers.Neo4jContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -91,7 +92,7 @@ void addAndSearchTest() { assertThat(resultDoc.getContent()).isEqualTo( "Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(Document::getId).toList()); @@ -203,7 +204,7 @@ void documentUpdateTest() { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); assertThat(resultDoc.getMetadata()).containsKey("meta1"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); Document sameIdDocument = new Document(document.getId(), "The World is Big and Salvation Lurks Around the Corner", @@ -218,7 +219,7 @@ void documentUpdateTest() { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); }); } @@ -235,14 +236,14 @@ void searchThresholdTest() { List fullResult = vectorStore .similaritySearch(SearchRequest.query("Great").withTopK(5).withSimilarityThresholdAll()); - List distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList(); + List scores = fullResult.stream().map(Document::getScore).toList(); - assertThat(distances).hasSize(3); + assertThat(scores).hasSize(3); - float threshold = (distances.get(0) + distances.get(1)) / 2; + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2; - List results = vectorStore - .similaritySearch(SearchRequest.query("Great").withTopK(5).withSimilarityThreshold(1 - threshold)); + List results = vectorStore.similaritySearch( + SearchRequest.query("Great").withTopK(5).withSimilarityThreshold(similarityThreshold)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); @@ -250,8 +251,8 @@ void searchThresholdTest() { assertThat(resultDoc.getContent()).isEqualTo( "Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); - + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); }); } diff --git a/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchVectorStore.java b/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchVectorStore.java index 653a78df78a..6199040f85b 100644 --- a/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchVectorStore.java +++ b/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchVectorStore.java @@ -40,6 +40,7 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; @@ -231,7 +232,10 @@ private List similaritySearch(org.opensearch.client.opensearch.core.Se private Document toDocument(Hit hit) { Document document = hit.source(); - document.getMetadata().put("distance", 1 - hit.score().floatValue()); + if (hit.score() != null) { + document.setScore(hit.score()); + document.getMetadata().put(DocumentMetadata.DISTANCE.value(), 1 - hit.score().floatValue()); + } return document; } diff --git a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreIT.java b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreIT.java index 6652ecd7882..5a4a5b20ac9 100644 --- a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreIT.java +++ b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreIT.java @@ -39,6 +39,7 @@ import org.opensearch.client.opensearch.OpenSearchClient; import org.opensearch.client.transport.httpclient5.ApacheHttpClient5TransportBuilder; import org.opensearch.testcontainers.OpensearchContainer; +import org.springframework.ai.document.DocumentMetadata; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -144,7 +145,7 @@ public void addAndSearchTest(String similarityFunction) { assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(Document::getId).toList()); @@ -281,7 +282,7 @@ public void documentUpdateTest(String similarityFunction) { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); assertThat(resultDoc.getMetadata()).containsKey("meta1"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); Document sameIdDocument = new Document(document.getId(), "The World is Big and Salvation Lurks Around the Corner", Map.of("meta2", "meta2")); @@ -300,7 +301,7 @@ public void documentUpdateTest(String similarityFunction) { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); // Remove all documents from the store vectorStore.delete(List.of(document.getId())); @@ -330,21 +331,22 @@ public void searchThresholdTest(String similarityFunction) { List fullResult = vectorStore.similaritySearch(query); - List distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList(); + List scores = fullResult.stream().map(Document::getScore).toList(); - assertThat(distances).hasSize(3); + assertThat(scores).hasSize(3); - float threshold = (distances.get(0) + distances.get(1)) / 2; + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2; List results = vectorStore.similaritySearch( - SearchRequest.query("Great Depression").withTopK(50).withSimilarityThreshold(1 - threshold)); + SearchRequest.query("Great Depression").withTopK(50).withSimilarityThreshold(similarityThreshold)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(Document::getId).toList()); diff --git a/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/OracleVectorStore.java b/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/OracleVectorStore.java index a9e3b63eb79..f519024dccc 100644 --- a/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/OracleVectorStore.java +++ b/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/OracleVectorStore.java @@ -39,6 +39,7 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; @@ -80,6 +81,7 @@ * @author Loïc Lefèvre * @author Christian Tzolov * @author Soby Chacko + * @author Thomas Vitale */ public class OracleVectorStore extends AbstractObservationVectorStore implements InitializingBean { @@ -649,12 +651,16 @@ private static class DocumentRowMapper implements RowMapper { @Override public Document mapRow(ResultSet rs, int rowNum) throws SQLException { final Map metadata = getMap(rs.getObject(3, OracleJsonValue.class)); - metadata.put("distance", rs.getDouble(5)); + metadata.put(DocumentMetadata.DISTANCE.value(), rs.getDouble(5)); - final Document document = new Document(rs.getString(1), rs.getString(2), metadata); final float[] embedding = rs.getObject(4, float[].class); - document.setEmbedding(embedding); - return document; + return Document.builder() + .id(rs.getString(1)) + .content(rs.getString(2)) + .metadata(metadata) + .score(1 - rs.getDouble(5)) + .embedding(embedding) + .build(); } private Map getMap(OracleJsonValue value) { diff --git a/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleVectorStoreIT.java b/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleVectorStoreIT.java index 35727822d18..2c3fcee85c3 100644 --- a/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleVectorStoreIT.java +++ b/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleVectorStoreIT.java @@ -32,6 +32,7 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; import org.junit.jupiter.params.provider.ValueSource; +import org.springframework.ai.document.DocumentMetadata; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; import org.testcontainers.oracle.OracleContainer; @@ -94,21 +95,19 @@ private static void dropTable(ApplicationContext context, String tableName) { jdbcTemplate.execute("DROP TABLE IF EXISTS " + tableName + " PURGE"); } - private static boolean isSortedByDistance(final List documents) { - final List distances = documents.stream() - .map(doc -> (Double) doc.getMetadata().get("distance")) - .toList(); + private static boolean isSortedBySimilarity(final List documents) { + final List scores = documents.stream().map(Document::getScore).toList(); - if (CollectionUtils.isEmpty(distances) || distances.size() == 1) { + if (CollectionUtils.isEmpty(scores) || scores.size() == 1) { return true; } - Iterator iter = distances.iterator(); + Iterator iter = scores.iterator(); Double current; Double previous = iter.next(); while (iter.hasNext()) { current = iter.next(); - if (previous > current) { + if (previous < current) { return false; } previous = current; @@ -134,7 +133,7 @@ public void addAndSearch(String distanceType) { assertThat(results).hasSize(1); Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); - assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance"); + assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value()); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); @@ -243,7 +242,7 @@ public void documentUpdate(String distanceType) { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); - assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); + assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value()); Document sameIdDocument = new Document(document.getId(), "The World is Big and Salvation Lurks Around the Corner", @@ -256,7 +255,7 @@ public void documentUpdate(String distanceType) { resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); - assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance"); + assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value()); dropTable(context, ((OracleVectorStore) vectorStore).getTableName()); }); @@ -279,20 +278,19 @@ public void searchWithThreshold(String distanceType) { assertThat(fullResult).hasSize(3); - assertThat(isSortedByDistance(fullResult)).isTrue(); + assertThat(isSortedBySimilarity(fullResult)).isTrue(); - List distances = fullResult.stream() - .map(doc -> (Double) doc.getMetadata().get("distance")) - .toList(); + List scores = fullResult.stream().map(Document::getScore).toList(); - double threshold = (distances.get(0) + distances.get(1)) / 2d; + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2d; List results = vectorStore.similaritySearch( - SearchRequest.query("Time Shelter").withTopK(5).withSimilarityThreshold(1d - threshold)); + SearchRequest.query("Time Shelter").withTopK(5).withSimilarityThreshold(similarityThreshold)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(this.documents.get(1).getId()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); dropTable(context, ((OracleVectorStore) vectorStore).getTableName()); }); diff --git a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java index 69cd79b41c4..f74e3fe94bb 100644 --- a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java +++ b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java @@ -35,6 +35,7 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; @@ -502,12 +503,15 @@ public Document mapRow(ResultSet rs, int rowNum) throws SQLException { Float distance = rs.getFloat(COLUMN_DISTANCE); Map metadata = toMap(pgMetadata); - metadata.put(COLUMN_DISTANCE, distance); - - Document document = new Document(id, content, metadata); - document.setEmbedding(toFloatArray(embedding)); - - return document; + metadata.put(DocumentMetadata.DISTANCE.value(), distance); + + return Document.builder() + .id(id) + .content(content) + .metadata(metadata) + .score(1.0 - distance) + .embedding(toFloatArray(embedding)) + .build(); } private float[] toFloatArray(PGobject embedding) throws SQLException { diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreIT.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreIT.java index 3366c88296b..83c2e49e782 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreIT.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreIT.java @@ -34,6 +34,7 @@ import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; +import org.springframework.ai.document.DocumentMetadata; import org.testcontainers.containers.PostgreSQLContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -113,20 +114,20 @@ static Stream provideFilters() { ); } - private static boolean isSortedByDistance(List docs) { + private static boolean isSortedBySimilarity(List docs) { - List distances = docs.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList(); + List scores = docs.stream().map(Document::getScore).toList(); - if (CollectionUtils.isEmpty(distances) || distances.size() == 1) { + if (CollectionUtils.isEmpty(scores) || scores.size() == 1) { return true; } - Iterator iter = distances.iterator(); - Float current; - Float previous = iter.next(); + Iterator iter = scores.iterator(); + Double current; + Double previous = iter.next(); while (iter.hasNext()) { current = iter.next(); - if (previous > current) { + if (previous < current) { return false; } previous = current; @@ -150,7 +151,7 @@ public void addAndSearch(String distanceType) { assertThat(results).hasSize(1); Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); - assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance"); + assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value()); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); @@ -289,7 +290,7 @@ public void documentUpdate(String distanceType) { Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); - assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); + assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value()); Document sameIdDocument = new Document(document.getId(), "The World is Big and Salvation Lurks Around the Corner", @@ -303,7 +304,7 @@ public void documentUpdate(String distanceType) { resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); - assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance"); + assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value()); dropTable(context); }); @@ -326,20 +327,19 @@ public void searchWithThreshold(String distanceType) { assertThat(fullResult).hasSize(3); - assertThat(isSortedByDistance(fullResult)).isTrue(); + assertThat(isSortedBySimilarity(fullResult)).isTrue(); - List distances = fullResult.stream() - .map(doc -> (Float) doc.getMetadata().get("distance")) - .toList(); + List scores = fullResult.stream().map(Document::getScore).toList(); - float threshold = (distances.get(0) + distances.get(1)) / 2; + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2; List results = vectorStore.similaritySearch( - SearchRequest.query("Time Shelter").withTopK(5).withSimilarityThreshold(1 - threshold)); + SearchRequest.query("Time Shelter").withTopK(5).withSimilarityThreshold(similarityThreshold)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(this.documents.get(1).getId()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); dropTable(context); }); diff --git a/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStore.java b/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStore.java index 1093656370b..d9ff49721a9 100644 --- a/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStore.java +++ b/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStore.java @@ -38,6 +38,7 @@ import io.pinecone.proto.Vector; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; @@ -60,13 +61,12 @@ * @author Christian Tzolov * @author Adam Bchouti * @author Soby Chacko + * @author Thomas Vitale */ public class PineconeVectorStore extends AbstractObservationVectorStore { public static final String CONTENT_FIELD_NAME = "document_content"; - public static final String DISTANCE_METADATA_FIELD_NAME = "distance"; - public final FilterExpressionConverter filterExpressionConverter = new PineconeFilterExpressionConverter(); private final EmbeddingModel embeddingModel; @@ -236,7 +236,12 @@ public List similaritySearch(SearchRequest request, String namespace) var content = metadataStruct.getFieldsOrThrow(this.pineconeContentFieldName).getStringValue(); Map metadata = extractMetadata(metadataStruct); metadata.put(this.pineconeDistanceMetadataFieldName, 1 - scoredVector.getScore()); - return new Document(id, content, metadata); + return Document.builder() + .id(id) + .content(content) + .metadata(metadata) + .score((double) scoredVector.getScore()) + .build(); }) .toList(); } @@ -298,6 +303,8 @@ public static final class PineconeVectorStoreConfig { private final String contentFieldName; + // TODO: Why is this field configurable? Can we remove this after standardizing + // the key? private final String distanceMetadataFieldName; private final PineconeConnectionConfig connectionConfig; @@ -357,7 +364,7 @@ public static final class Builder { private String contentFieldName = CONTENT_FIELD_NAME; - private String distanceMetadataFieldName = DISTANCE_METADATA_FIELD_NAME; + private String distanceMetadataFieldName = DocumentMetadata.DISTANCE.value(); /** * Optional server-side timeout in seconds for all operations. Default: 20 diff --git a/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreIT.java b/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreIT.java index 6e1eddc5f32..712efa19a69 100644 --- a/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreIT.java +++ b/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreIT.java @@ -31,6 +31,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.transformers.TransformersEmbeddingModel; import org.springframework.ai.vectorstore.PineconeVectorStore.PineconeVectorStoreConfig; @@ -46,6 +47,7 @@ /** * @author Christian Tzolov + * @author Thomas Vitale */ @EnabledIfEnvironmentVariable(named = "PINECONE_API_KEY", matches = ".+") public class PineconeVectorStoreIT { @@ -109,7 +111,7 @@ public void addAndSearchTest() { assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); @@ -193,7 +195,7 @@ public void documentUpdateTest() { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); assertThat(resultDoc.getMetadata()).containsKey("meta1"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); Document sameIdDocument = new Document(document.getId(), "The World is Big and Salvation Lurks Around the Corner", @@ -214,7 +216,7 @@ public void documentUpdateTest() { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); // Remove all documents from the store vectorStore.delete(List.of(document.getId())); @@ -240,21 +242,22 @@ public void searchThresholdTest() { List fullResult = vectorStore .similaritySearch(SearchRequest.query("Depression").withTopK(5).withSimilarityThresholdAll()); - List distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList(); + List scores = fullResult.stream().map(Document::getScore).toList(); - assertThat(distances).hasSize(3); + assertThat(scores).hasSize(3); - float threshold = (distances.get(0) + distances.get(1)) / 2; + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2; - List results = vectorStore - .similaritySearch(SearchRequest.query("Depression").withTopK(5).withSimilarityThreshold(1 - threshold)); + List results = vectorStore.similaritySearch( + SearchRequest.query("Depression").withTopK(5).withSimilarityThreshold(similarityThreshold)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); diff --git a/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStore.java b/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStore.java index a2f97b756d3..bb4f6b3927d 100644 --- a/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStore.java +++ b/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStore.java @@ -35,6 +35,7 @@ import io.qdrant.client.grpc.Points.UpdateStatus; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; @@ -57,6 +58,7 @@ * @author Eddú Meléndez * @author Josh Long * @author Soby Chacko + * @author Thomas Vitale * @since 0.8.1 */ public class QdrantVectorStore extends AbstractObservationVectorStore implements InitializingBean { @@ -65,8 +67,6 @@ public class QdrantVectorStore extends AbstractObservationVectorStore implements private static final String CONTENT_FIELD_NAME = "doc_content"; - private static final String DISTANCE_FIELD_NAME = "distance"; - private final EmbeddingModel embeddingModel; private final QdrantClient qdrantClient; @@ -208,12 +208,17 @@ private Document toDocument(ScoredPoint point) { try { var id = point.getId().getUuid(); - var payload = QdrantObjectFactory.toObjectMap(point.getPayloadMap()); - payload.put(DISTANCE_FIELD_NAME, 1 - point.getScore()); + var metadata = QdrantObjectFactory.toObjectMap(point.getPayloadMap()); + metadata.put(DocumentMetadata.DISTANCE.value(), 1 - point.getScore()); - var content = (String) payload.remove(CONTENT_FIELD_NAME); + var content = (String) metadata.remove(CONTENT_FIELD_NAME); - return new Document(id, content, payload); + return Document.builder() + .id(id) + .content(content) + .metadata(metadata) + .score((double) point.getScore()) + .build(); } catch (Exception e) { throw new RuntimeException(e); diff --git a/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreIT.java b/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreIT.java index 995cd8cc3a5..0cb14fde45d 100644 --- a/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreIT.java +++ b/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreIT.java @@ -30,6 +30,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariables; +import org.springframework.ai.document.DocumentMetadata; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; import org.testcontainers.qdrant.QdrantContainer; @@ -107,7 +108,7 @@ public void addAndSearch() { assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).isEqualTo( "Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression"); - assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance"); + assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value()); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); @@ -185,7 +186,7 @@ public void documentUpdateTest() { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); assertThat(resultDoc.getMetadata()).containsKey("meta1"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); Document sameIdDocument = new Document(document.getId(), "The World is Big and Salvation Lurks Around the Corner", @@ -200,7 +201,7 @@ public void documentUpdateTest() { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); vectorStore.delete(List.of(document.getId())); }); @@ -218,13 +219,13 @@ public void searchThresholdTest() { var request = SearchRequest.query("Great").withTopK(5); List fullResult = vectorStore.similaritySearch(request.withSimilarityThresholdAll()); - List distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList(); + List scores = fullResult.stream().map(Document::getScore).toList(); - assertThat(distances).hasSize(3); + assertThat(scores).hasSize(3); - float threshold = (distances.get(0) + distances.get(1)) / 2; + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2; - List results = vectorStore.similaritySearch(request.withSimilarityThreshold(1 - threshold)); + List results = vectorStore.similaritySearch(request.withSimilarityThreshold(similarityThreshold)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); @@ -232,7 +233,8 @@ public void searchThresholdTest() { assertThat(resultDoc.getContent()).isEqualTo( "Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); diff --git a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisVectorStore.java b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisVectorStore.java index df277067162..8759437926d 100644 --- a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisVectorStore.java +++ b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisVectorStore.java @@ -30,6 +30,7 @@ import io.micrometer.observation.ObservationRegistry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.document.DocumentMetadata; import redis.clients.jedis.JedisPooled; import redis.clients.jedis.Pipeline; import redis.clients.jedis.json.Path2; @@ -240,14 +241,21 @@ public List doSimilaritySearch(SearchRequest request) { private Document toDocument(redis.clients.jedis.search.Document doc) { var id = doc.getId().substring(this.config.prefix.length()); - var content = doc.hasProperty(this.config.contentFieldName) ? doc.getString(this.config.contentFieldName) - : null; + var content = doc.hasProperty(this.config.contentFieldName) ? doc.getString(this.config.contentFieldName) : ""; Map metadata = this.config.metadataFields.stream() .map(MetadataField::name) .filter(doc::hasProperty) .collect(Collectors.toMap(Function.identity(), doc::getString)); + // TODO: this seems wrong. The key is named "vector_store", but the value is the + // distance. Can we remove this after standardizing the metadata? metadata.put(DISTANCE_FIELD_NAME, 1 - similarityScore(doc)); - return new Document(id, content, metadata); + metadata.put(DocumentMetadata.DISTANCE.value(), 1 - similarityScore(doc)); + return Document.builder() + .id(id) + .content(content) + .metadata(metadata) + .score((double) similarityScore(doc)) + .build(); } private float similarityScore(redis.clients.jedis.search.Document doc) { diff --git a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreIT.java b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreIT.java index 4153a82ea6e..6c389b672d6 100644 --- a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreIT.java +++ b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreIT.java @@ -26,6 +26,7 @@ import com.redis.testcontainers.RedisStackContainer; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.springframework.ai.document.DocumentMetadata; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; import redis.clients.jedis.JedisPooled; @@ -50,6 +51,7 @@ /** * @author Julien Ruaux * @author Eddú Meléndez + * @author Thomas Vitale */ @Testcontainers class RedisVectorStoreIT { @@ -105,8 +107,9 @@ void addAndSearch() { assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); - assertThat(resultDoc.getMetadata()).hasSize(2); - assertThat(resultDoc.getMetadata()).containsKeys("meta1", RedisVectorStore.DISTANCE_FIELD_NAME); + assertThat(resultDoc.getMetadata()).hasSize(3); + assertThat(resultDoc.getMetadata()).containsKeys("meta1", RedisVectorStore.DISTANCE_FIELD_NAME, + DocumentMetadata.DISTANCE.value()); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); @@ -190,6 +193,7 @@ void documentUpdate() { assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); assertThat(resultDoc.getMetadata()).containsKey("meta1"); assertThat(resultDoc.getMetadata()).containsKey(RedisVectorStore.DISTANCE_FIELD_NAME); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); Document sameIdDocument = new Document(document.getId(), "The World is Big and Salvation Lurks Around the Corner", @@ -205,6 +209,7 @@ void documentUpdate() { assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); assertThat(resultDoc.getMetadata()).containsKey(RedisVectorStore.DISTANCE_FIELD_NAME); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); vectorStore.delete(List.of(document.getId())); @@ -223,24 +228,23 @@ void searchWithThreshold() { List fullResult = vectorStore .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll()); - List distances = fullResult.stream() - .map(doc -> (Float) doc.getMetadata().get(RedisVectorStore.DISTANCE_FIELD_NAME)) - .toList(); + List scores = fullResult.stream().map(Document::getScore).toList(); - assertThat(distances).hasSize(3); + assertThat(scores).hasSize(3); - float threshold = (distances.get(0) + distances.get(1)) / 2; + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2; - List results = vectorStore - .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(1 - threshold)); + List results = vectorStore.similaritySearch( + SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(similarityThreshold)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); - assertThat(resultDoc.getMetadata()).containsKeys("meta1", RedisVectorStore.DISTANCE_FIELD_NAME); - + assertThat(resultDoc.getMetadata()).containsKeys("meta1", RedisVectorStore.DISTANCE_FIELD_NAME, + DocumentMetadata.DISTANCE.value()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); }); } diff --git a/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseVectorStore.java b/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseVectorStore.java index b0fc72f1360..b805be62123 100644 --- a/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseVectorStore.java +++ b/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseVectorStore.java @@ -26,6 +26,7 @@ import io.micrometer.observation.ObservationRegistry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.document.DocumentMetadata; import org.typesense.api.Client; import org.typesense.api.FieldTypes; import org.typesense.model.CollectionResponse; @@ -57,6 +58,7 @@ * @author Pablo Sanchidrian Herrera * @author Soby Chacko * @author Christian Tzolov + * @author Thomas Vitale */ public class TypesenseVectorStore extends AbstractObservationVectorStore implements InitializingBean { @@ -212,8 +214,13 @@ public List doSimilaritySearch(SearchRequest request) { String content = rawDocument.get(CONTENT_FIELD_NAME).toString(); Map metadata = rawDocument.get(METADATA_FIELD_NAME) instanceof Map ? (Map) rawDocument.get(METADATA_FIELD_NAME) : Map.of(); - metadata.put("distance", hit.getVectorDistance()); - return new Document(docId, content, metadata); + metadata.put(DocumentMetadata.DISTANCE.value(), hit.getVectorDistance()); + return Document.builder() + .id(docId) + .content(content) + .metadata(metadata) + .score(1.0 - hit.getVectorDistance()) + .build(); })) .toList(); diff --git a/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseVectorStoreIT.java b/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseVectorStoreIT.java index 23cd9f03bb6..6d05a17fab9 100644 --- a/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseVectorStoreIT.java +++ b/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseVectorStoreIT.java @@ -26,6 +26,7 @@ import java.util.UUID; import org.junit.jupiter.api.Test; +import org.springframework.ai.document.DocumentMetadata; import org.testcontainers.containers.GenericContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -96,7 +97,7 @@ void documentUpdate() { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); assertThat(resultDoc.getMetadata()).containsKey("meta1"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); Document sameIdDocument = new Document(document.getId(), "The World is Big and Salvation Lurks Around the Corner", @@ -114,7 +115,7 @@ void documentUpdate() { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); vectorStore.delete(List.of(document.getId())); @@ -211,21 +212,22 @@ void searchWithThreshold() { List fullResult = vectorStore .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll()); - List distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList(); + List scores = fullResult.stream().map(Document::getScore).toList(); - assertThat(distances).hasSize(3); + assertThat(scores).hasSize(3); - float threshold = (distances.get(0) + distances.get(1)) / 2; + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2; - List results = vectorStore - .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(1 - threshold)); + List results = vectorStore.similaritySearch( + SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(similarityThreshold)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); - assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); + assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); ((TypesenseVectorStore) vectorStore).dropCollection(); diff --git a/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateVectorStore.java b/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateVectorStore.java index 9ce01c9f059..049d4a4f1f8 100644 --- a/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateVectorStore.java +++ b/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateVectorStore.java @@ -45,6 +45,7 @@ import io.weaviate.client.v1.graphql.query.fields.Fields; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; @@ -73,11 +74,10 @@ * @author Eddú Meléndez * @author Josh Long * @author Soby Chacko + * @author Thomas Vitale */ public class WeaviateVectorStore extends AbstractObservationVectorStore { - public static final String DOCUMENT_METADATA_DISTANCE_KEY_NAME = "distance"; - private static final String METADATA_FIELD_PREFIX = "meta_"; private static final String CONTENT_FIELD_NAME = "content"; @@ -367,7 +367,7 @@ private Document toDocument(Map item) { // Metadata Map metadata = new HashMap<>(); - metadata.put(DOCUMENT_METADATA_DISTANCE_KEY_NAME, 1 - certainty); + metadata.put(DocumentMetadata.DISTANCE.value(), 1 - certainty); try { String metadataJson = (String) item.get(METADATA_FIELD_NAME); @@ -382,10 +382,13 @@ private Document toDocument(Map item) { // Content String content = (String) item.get(CONTENT_FIELD_NAME); - var document = new Document(id, content, metadata); - document.setEmbedding(EmbeddingUtils.toPrimitive(EmbeddingUtils.doubleToFloat(embedding))); - - return document; + return Document.builder() + .id(id) + .content(content) + .metadata(metadata) + .embedding(EmbeddingUtils.toPrimitive(EmbeddingUtils.doubleToFloat(embedding))) + .score(certainty) + .build(); } @Override diff --git a/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreIT.java b/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreIT.java index b474cdaeded..7239b319b61 100644 --- a/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreIT.java +++ b/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreIT.java @@ -26,6 +26,7 @@ import io.weaviate.client.Config; import io.weaviate.client.WeaviateClient; import org.junit.jupiter.api.Test; +import org.springframework.ai.document.DocumentMetadata; import org.testcontainers.containers.wait.strategy.Wait; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -101,7 +102,7 @@ public void addAndSearch() { assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); assertThat(resultDoc.getMetadata()).hasSize(2); - assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); + assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value()); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); @@ -186,7 +187,7 @@ public void documentUpdate() { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); assertThat(resultDoc.getMetadata()).containsKey("meta1"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); Document sameIdDocument = new Document(document.getId(), "The World is Big and Salvation Lurks Around the Corner", @@ -201,7 +202,7 @@ public void documentUpdate() { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); vectorStore.delete(List.of(document.getId())); @@ -222,23 +223,22 @@ public void searchWithThreshold() { List fullResult = vectorStore .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll()); - List distances = fullResult.stream() - .map(doc -> (Double) doc.getMetadata().get("distance")) - .toList(); + List scores = fullResult.stream().map(Document::getScore).toList(); - assertThat(distances).hasSize(3); + assertThat(scores).hasSize(3); - double threshold = (distances.get(0) + distances.get(1)) / 2; + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2; - List results = vectorStore - .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(1 - threshold)); + List results = vectorStore.similaritySearch( + SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(similarityThreshold)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); - assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); + assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); }); } From 3a740ec0f4d5efa20360fc96be76e084a1ba8c65 Mon Sep 17 00:00:00 2001 From: Thomas Vitale Date: Wed, 27 Nov 2024 22:29:33 +0100 Subject: [PATCH 2/2] Handle PR comments Signed-off-by: Thomas Vitale --- .../springframework/ai/document/Document.java | 36 +++++++++---------- .../vectorstore/ElasticsearchVectorStore.java | 7 ++-- .../ai/vectorstore/OpenSearchVectorStore.java | 7 ++-- 3 files changed, 24 insertions(+), 26 deletions(-) diff --git a/spring-ai-core/src/main/java/org/springframework/ai/document/Document.java b/spring-ai-core/src/main/java/org/springframework/ai/document/Document.java index 45c2a5602fc..646d671c653 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/document/Document.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/document/Document.java @@ -71,7 +71,7 @@ public class Document implements MediaContent { * measure. */ @Nullable - private Double score; + private final Double score; /** * Embedding of the document. Note: ephemeral field. @@ -90,10 +90,6 @@ public Document(@JsonProperty("content") String content) { this(content, new HashMap<>()); } - /** - * @deprecated Use builder instead: {@link Document#builder()}. - */ - @Deprecated(since = "1.0.0-M5", forRemoval = true) public Document(String content, Map metadata) { this(content, metadata, new RandomIdGenerator()); } @@ -114,10 +110,6 @@ public Document(String content, Map metadata, IdGenerator idGene this(idGenerator.generateId(content, metadata), content, metadata); } - /** - * @deprecated Use builder instead: {@link Document#builder()}. - */ - @Deprecated(since = "1.0.0-M5", forRemoval = true) public Document(String id, String content, Map metadata) { this(id, content, List.of(), metadata); } @@ -194,10 +186,6 @@ public Double getScore() { return this.score; } - public void setScore(@Nullable Double score) { - this.score = score; - } - /** * Return the embedding that were calculated. * @deprecated We are considering getting rid of this, please comment on @@ -233,20 +221,27 @@ public void setContentFormatter(ContentFormatter contentFormatter) { this.contentFormatter = contentFormatter; } - @Override - public int hashCode() { - return Objects.hash(id, content, media, metadata); + public Builder mutate() { + return new Builder().id(this.id) + .content(this.content) + .media(new ArrayList<>(this.media)) + .metadata(this.metadata) + .score(this.score); } @Override public boolean equals(Object o) { - if (this == o) - return true; if (o == null || getClass() != o.getClass()) return false; Document document = (Document) o; return Objects.equals(id, document.id) && Objects.equals(content, document.content) - && Objects.equals(media, document.media) && Objects.equals(metadata, document.metadata); + && Objects.equals(media, document.media) && Objects.equals(metadata, document.metadata) + && Objects.equals(score, document.score); + } + + @Override + public int hashCode() { + return Objects.hash(id, content, media, metadata, score); } @Override @@ -267,6 +262,7 @@ public static class Builder { private float[] embedding = new float[0]; + @Nullable private Double score; private IdGenerator idGenerator = new RandomIdGenerator(); @@ -314,7 +310,7 @@ public Builder embedding(float[] embedding) { return this; } - public Builder score(Double score) { + public Builder score(@Nullable Double score) { this.score = score; return this; } diff --git a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java index 895c1f924c4..ee26f4b54e4 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java @@ -211,11 +211,12 @@ private String getElasticsearchQueryString(Filter.Expression filterExpression) { private Document toDocument(Hit hit) { Document document = hit.source(); + Document.Builder documentBuilder = document.mutate(); if (hit.score() != null) { - document.getMetadata().put(DocumentMetadata.DISTANCE.value(), 1 - normalizeSimilarityScore(hit.score())); - document.setScore(normalizeSimilarityScore(hit.score())); + documentBuilder.metadata(DocumentMetadata.DISTANCE.value(), 1 - normalizeSimilarityScore(hit.score())); + documentBuilder.score(normalizeSimilarityScore(hit.score())); } - return document; + return documentBuilder.build(); } // more info on score/distance calculation diff --git a/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchVectorStore.java b/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchVectorStore.java index 6199040f85b..66a8e9e1669 100644 --- a/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchVectorStore.java +++ b/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchVectorStore.java @@ -232,11 +232,12 @@ private List similaritySearch(org.opensearch.client.opensearch.core.Se private Document toDocument(Hit hit) { Document document = hit.source(); + Document.Builder documentBuilder = document.mutate(); if (hit.score() != null) { - document.setScore(hit.score()); - document.getMetadata().put(DocumentMetadata.DISTANCE.value(), 1 - hit.score().floatValue()); + documentBuilder.metadata(DocumentMetadata.DISTANCE.value(), 1 - hit.score().floatValue()); + documentBuilder.score(hit.score()); } - return document; + return documentBuilder.build(); } public boolean exists(String targetIndex) {