diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java index 7035405c038..52ace259874 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java @@ -66,12 +66,12 @@ void defaultEmbedding() { @Test void embeddingBatchDocuments() throws Exception { assertThat(this.embeddingModel).isNotNull(); - List embedded = this.embeddingModel.embed( + List embeddings = this.embeddingModel.embed( List.of(new Document("Hello world"), new Document("Hello Spring"), new Document("Hello Spring AI!")), OpenAiEmbeddingOptions.builder().withModel(OpenAiApi.DEFAULT_EMBEDDING_MODEL).build(), new TokenCountBatchingStrategy()); - assertThat(embedded.size()).isEqualTo(3); - embedded.forEach(embedding -> assertThat(embedding.length).isEqualTo(this.embeddingModel.dimensions())); + assertThat(embeddings.size()).isEqualTo(3); + embeddings.forEach(embedding -> assertThat(embedding.length).isEqualTo(this.embeddingModel.dimensions())); } @Test diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/BatchingStrategy.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/BatchingStrategy.java index e354f1da87c..714d681a5d5 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/BatchingStrategy.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/BatchingStrategy.java @@ -31,7 +31,9 @@ public interface BatchingStrategy { /** * {@link EmbeddingModel} implementations can call this method to optimize embedding - * tokens. The incoming collection of {@link Document}s are split into su-batches. + * tokens. The incoming collection of {@link Document}s are split into sub-batches. 
It + * is important to preserve the order of the list of {@link Document}s when batching + * as they are mapped to their corresponding embeddings by their order. * @param documents to batch * @return a list of sub-batches that contain {@link Document}s. */ diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingModel.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingModel.java index c40ed34d20e..51a1ac03516 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingModel.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingModel.java @@ -78,25 +78,23 @@ default List embed(List texts) { * @param options {@link EmbeddingOptions}. * @param batchingStrategy {@link BatchingStrategy}. * @return a list of float[] that represents the vectors for the incoming - * {@link Document}s. + * {@link Document}s. The returned list is expected to be in the same order as the + * {@link Document} list. 
*/ default List embed(List documents, EmbeddingOptions options, BatchingStrategy batchingStrategy) { Assert.notNull(documents, "Documents must not be null"); - List embeddings = new ArrayList<>(); - + List embeddings = new ArrayList<>(documents.size()); List> batch = batchingStrategy.batch(documents); - for (List subBatch : batch) { List texts = subBatch.stream().map(Document::getContent).toList(); EmbeddingRequest request = new EmbeddingRequest(texts, options); EmbeddingResponse response = this.call(request); for (int i = 0; i < subBatch.size(); i++) { - Document document = subBatch.get(i); - float[] output = response.getResults().get(i).getOutput(); - embeddings.add(output); - document.setEmbedding(output); + embeddings.add(response.getResults().get(i).getOutput()); } } + Assert.isTrue(embeddings.size() == documents.size(), + "Embeddings must have the same number as that of the documents"); return embeddings; } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/TokenCountBatchingStrategy.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/TokenCountBatchingStrategy.java index 713eefb9650..7ffc01fee5d 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/TokenCountBatchingStrategy.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/TokenCountBatchingStrategy.java @@ -17,7 +17,7 @@ package org.springframework.ai.embedding; import java.util.ArrayList; -import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -139,7 +139,9 @@ public List> batch(List documents) { List> batches = new ArrayList<>(); int currentSize = 0; List currentBatch = new ArrayList<>(); - Map documentTokens = new HashMap<>(); + // Make sure the documentTokens' entry order is preserved by making it a + // LinkedHashMap. 
+ Map documentTokens = new LinkedHashMap<>(); for (Document document : documents) { int tokenCount = this.tokenCountEstimator diff --git a/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBVectorStore.java b/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBVectorStore.java index c0ad68e3600..914efd57a37 100644 --- a/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBVectorStore.java +++ b/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBVectorStore.java @@ -203,13 +203,14 @@ private JsonNode mapCosmosDocument(Document document, float[] queryEmbedding) { public void doAdd(List documents) { // Batch the documents based on the batching strategy - this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); + List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), + this.batchingStrategy); // Create a list to hold both the CosmosItemOperation and the corresponding // document ID List> itemOperationsWithIds = documents.stream().map(doc -> { - CosmosItemOperation operation = CosmosBulkOperations - .getCreateItemOperation(mapCosmosDocument(doc, doc.getEmbedding()), new PartitionKey(doc.getId())); + CosmosItemOperation operation = CosmosBulkOperations.getCreateItemOperation( + mapCosmosDocument(doc, embeddings.get(documents.indexOf(doc))), new PartitionKey(doc.getId())); return new ImmutablePair<>(doc.getId(), operation); // Pair the document ID // with the operation }).toList(); diff --git a/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java b/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java index 6200738b928..6107f7492e6 100644 --- 
a/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java +++ b/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java @@ -224,12 +224,13 @@ public void doAdd(List documents) { return; // nothing to do; } - this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); + List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), + this.batchingStrategy); final var searchDocuments = documents.stream().map(document -> { SearchDocument searchDocument = new SearchDocument(); searchDocument.put(ID_FIELD_NAME, document.getId()); - searchDocument.put(EMBEDDING_FIELD_NAME, document.getEmbedding()); + searchDocument.put(EMBEDDING_FIELD_NAME, embeddings.get(documents.indexOf(document))); searchDocument.put(CONTENT_FIELD_NAME, document.getContent()); searchDocument.put(METADATA_FIELD_NAME, new JSONObject(document.getMetadata()).toJSONString()); @@ -324,7 +325,6 @@ public List doSimilaritySearch(SearchRequest request) { metadata.put(DISTANCE_METADATA_FIELD_NAME, 1 - (float) result.getScore()); final Document doc = new Document(entry.id(), entry.content(), metadata); - doc.setEmbedding(EmbeddingUtils.toPrimitive(entry.embedding())); return doc; diff --git a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java index 3ad7f5a916d..7b9bf975281 100644 --- a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java +++ b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java @@ -182,7 +182,8 @@ private static Float[] toFloatArray(float[] embedding) { public void doAdd(List documents) { var futures = 
new CompletableFuture[documents.size()]; - this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); + List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), + this.batchingStrategy); int i = 0; for (Document d : documents) { @@ -197,7 +198,8 @@ public void doAdd(List documents) { builder = builder.setString(this.conf.schema.content(), d.getContent()) .setVector(this.conf.schema.embedding(), - CqlVector.newInstance(EmbeddingUtils.toList(d.getEmbedding())), Float.class); + CqlVector.newInstance(EmbeddingUtils.toList(embeddings.get(documents.indexOf(d)))), + Float.class); for (var metadataColumn : this.conf.schema.metadataColumns() .stream() @@ -260,11 +262,6 @@ public List doSimilaritySearch(SearchRequest request) { } } Document doc = new Document(getDocumentId(row), row.getString(this.conf.schema.content()), docFields); - - if (this.conf.returnEmbeddings) { - doc.setEmbedding(EmbeddingUtils - .toPrimitive(row.getVector(this.conf.schema.embedding(), Float.class).stream().toList())); - } documents.add(doc); } return documents; diff --git a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStoreConfig.java b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStoreConfig.java index 007d2c08c3a..518eb622ac4 100644 --- a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStoreConfig.java +++ b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStoreConfig.java @@ -90,6 +90,8 @@ public final class CassandraVectorStoreConfig implements AutoCloseable { final boolean disallowSchemaChanges; + // TODO: Remove this flag as the document no longer holds embeddings. 
+ @Deprecated(since = "1.0.0-M5", forRemoval = true) final boolean returnEmbeddings; final DocumentIdTranslator documentIdTranslator; diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreIT.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreIT.java index e17091d5ea0..d6298f166c6 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreIT.java +++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreIT.java @@ -121,18 +121,12 @@ void addAndSearch() { List documents = documents(); store.add(documents); - for (Document d : documents) { - assertThat(d.getEmbedding()).satisfiesAnyOf(e -> assertThat(e).isNotNull(), - e -> assertThat(e).isNotEmpty()); - } List results = store.similaritySearch(SearchRequest.query("Spring").withTopK(1)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(documents().get(0).getId()); - assertThat(resultDoc.getEmbedding()).satisfiesAnyOf(e -> assertThat(e).isNull(), - e -> assertThat(e).isEmpty()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); @@ -158,17 +152,12 @@ void addAndSearchReturnEmbeddings() { try (CassandraVectorStore store = createTestStore(context, builder)) { List documents = documents(); store.add(documents); - for (Document d : documents) { - assertThat(d.getEmbedding()).satisfiesAnyOf(e -> assertThat(e).isNotNull(), - e -> assertThat(e).isNotEmpty()); - } List results = store.similaritySearch(SearchRequest.query("Spring").withTopK(1)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(documents().get(0).getId()); - assertThat(resultDoc.getEmbedding()).isNotEmpty(); 
assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); diff --git a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java index 49deb147d65..923824ef92d 100644 --- a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java +++ b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java @@ -145,14 +145,14 @@ public void doAdd(@NonNull List documents) { List contents = new ArrayList<>(); List embeddings = new ArrayList<>(); - this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); + List documentEmbeddings = this.embeddingModel.embed(documents, + EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); for (Document document : documents) { ids.add(document.getId()); metadatas.add(document.getMetadata()); contents.add(document.getContent()); - document.setEmbedding(document.getEmbedding()); - embeddings.add(document.getEmbedding()); + embeddings.add(documentEmbeddings.get(documents.indexOf(document))); } this.chromaApi.upsertEmbeddings(this.collectionId, @@ -193,9 +193,7 @@ public Optional doDelete(@NonNull List idList) { metadata = new HashMap<>(); } metadata.put(DISTANCE_FIELD_NAME, distance); - Document document = new Document(id, content, metadata); - document.setEmbedding(chromaEmbedding.embedding()); - responseDocuments.add(document); + responseDocuments.add(new Document(id, content, metadata)); } } diff --git a/vector-stores/spring-ai-coherence-store/src/main/java/org/springframework/ai/vectorstore/CoherenceVectorStore.java b/vector-stores/spring-ai-coherence-store/src/main/java/org/springframework/ai/vectorstore/CoherenceVectorStore.java index 
e9c64e79461..b76618c9958 100644 --- a/vector-stores/spring-ai-coherence-store/src/main/java/org/springframework/ai/vectorstore/CoherenceVectorStore.java +++ b/vector-stores/spring-ai-coherence-store/src/main/java/org/springframework/ai/vectorstore/CoherenceVectorStore.java @@ -165,9 +165,9 @@ public CoherenceVectorStore setForcedNormalization(boolean forcedNormalization) public void add(final List documents) { Map chunks = new HashMap<>((int) Math.ceil(documents.size() / 0.75f)); for (Document doc : documents) { - doc.setEmbedding(this.embeddingModel.embed(doc)); var id = toChunkId(doc.getId()); - var chunk = new DocumentChunk(doc.getContent(), doc.getMetadata(), toFloat32Vector(doc.getEmbedding())); + var chunk = new DocumentChunk(doc.getContent(), doc.getMetadata(), + toFloat32Vector(this.embeddingModel.embed(doc))); chunks.put(id, chunk); } this.documentChunks.putAll(chunks); diff --git a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java index 9c058936b44..c91f920b10b 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java @@ -71,6 +71,7 @@ * @author Soby Chacko * @author Christian Tzolov * @author Thomas Vitale + * @author Ilayaperumal Gopinathan * @since 1.0.0 */ public class ElasticsearchVectorStore extends AbstractObservationVectorStore implements InitializingBean { @@ -132,11 +133,14 @@ public void doAdd(List documents) { } BulkRequest.Builder bulkRequestBuilder = new BulkRequest.Builder(); - this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); + List embeddings = this.embeddingModel.embed(documents, 
EmbeddingOptionsBuilder.builder().build(), + this.batchingStrategy); for (Document document : documents) { - bulkRequestBuilder.operations(op -> op - .index(idx -> idx.index(this.options.getIndexName()).id(document.getId()).document(document))); + ElasticSearchDocument doc = new ElasticSearchDocument(document.getId(), document.getContent(), + document.getMetadata(), embeddings.get(documents.indexOf(document))); + bulkRequestBuilder.operations( + op -> op.index(idx -> idx.index(this.options.getIndexName()).id(document.getId()).document(doc))); } BulkResponse bulkRequest = bulkRequest(bulkRequestBuilder.build()); if (bulkRequest.errors()) { @@ -277,4 +281,15 @@ private String getSimilarityMetric() { return SIMILARITY_TYPE_MAPPING.get(this.options.getSimilarity()).value(); } + /** + * The representation of {@link Document} along with its embedding. + * + * @param id The id of the document + * @param content The content of the document + * @param metadata The metadata of the document + * @param embedding The vectors representing the content of the document + */ + public record ElasticSearchDocument(String id, String content, Map metadata, float[] embedding) { + } + } diff --git a/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java b/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java index 753eb1e7bdb..52dd86eb7c6 100644 --- a/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java +++ b/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java @@ -209,10 +209,11 @@ public String getIndex() { @Override public void doAdd(List documents) { - this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); + List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), + 
this.batchingStrategy); UploadRequest upload = new UploadRequest(documents.stream() - .map(document -> new UploadRequest.Embedding(document.getId(), document.getEmbedding(), DOCUMENT_FIELD, - document.getContent(), document.getMetadata())) + .map(document -> new UploadRequest.Embedding(document.getId(), embeddings.get(documents.indexOf(document)), + DOCUMENT_FIELD, document.getContent(), document.getMetadata())) .toList()); String embeddingsJson = null; diff --git a/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusVectorStore.java b/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusVectorStore.java index 9ed8d8b0157..fbeccfa7828 100644 --- a/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusVectorStore.java +++ b/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusVectorStore.java @@ -160,7 +160,8 @@ public void doAdd(List documents) { List> embeddingArray = new ArrayList<>(); // TODO: Need to customize how we pass the embedding options - this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); + List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), + this.batchingStrategy); for (Document document : documents) { docIdArray.add(document.getId()); @@ -168,7 +169,7 @@ public void doAdd(List documents) { // the content used to compute the embeddings contentArray.add(document.getContent()); metadataArray.add(new JSONObject(document.getMetadata())); - embeddingArray.add(EmbeddingUtils.toList(document.getEmbedding())); + embeddingArray.add(EmbeddingUtils.toList(embeddings.get(documents.indexOf(document)))); } List fields = new ArrayList<>(); diff --git a/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStore.java 
b/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStore.java index 2e40db851c6..a8785f20d25 100644 --- a/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStore.java +++ b/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStore.java @@ -50,6 +50,7 @@ * @author Soby Chacko * @author Christian Tzolov * @author Thomas Vitale + * @author Ilayaperumal Gopinathan * @since 1.0.0 */ public class MongoDBAtlasVectorStore extends AbstractObservationVectorStore implements InitializingBean { @@ -173,18 +174,17 @@ private Document mapMongoDocument(org.bson.Document mongoDocument, float[] query String id = mongoDocument.getString(ID_FIELD_NAME); String content = mongoDocument.getString(CONTENT_FIELD_NAME); Map metadata = mongoDocument.get(METADATA_FIELD_NAME, org.bson.Document.class); - - Document document = new Document(id, content, metadata); - document.setEmbedding(queryEmbedding); - - return document; + return new Document(id, content, metadata); } @Override public void doAdd(List documents) { - this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); + List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), + this.batchingStrategy); for (Document document : documents) { - this.mongoTemplate.save(document, this.config.collectionName); + MongoDBDocument mdbDocument = new MongoDBDocument(document.getId(), document.getContent(), + document.getMetadata(), embeddings.get(documents.indexOf(document))); + this.mongoTemplate.save(mdbDocument, this.config.collectionName); } } @@ -332,4 +332,15 @@ public MongoDBVectorStoreConfig build() { } + /** + * The representation of {@link Document} along with its embedding. 
+ * + * @param id The id of the document + * @param content The content of the document + * @param metadata The metadata of the document + * @param embedding The vectors representing the content of the document + */ + public record MongoDBDocument(String id, String content, Map metadata, float[] embedding) { + } + } diff --git a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/Neo4jVectorStore.java b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/Neo4jVectorStore.java index 0493d893ca0..07c76f562b1 100644 --- a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/Neo4jVectorStore.java +++ b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/Neo4jVectorStore.java @@ -107,9 +107,12 @@ public Neo4jVectorStore(Driver driver, EmbeddingModel embeddingModel, Neo4jVecto @Override public void doAdd(List documents) { - this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); + List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), + this.batchingStrategy); - var rows = documents.stream().map(this::documentToRecord).toList(); + var rows = documents.stream() + .map(document -> documentToRecord(document, embeddings.get(documents.indexOf(document)))) + .toList(); try (var session = this.driver.session()) { var statement = """ @@ -203,8 +206,7 @@ public void afterPropertiesSet() { } } - private Map documentToRecord(Document document) { - document.setEmbedding(document.getEmbedding()); + private Map documentToRecord(Document document, float[] embedding) { var row = new HashMap(); @@ -216,7 +218,7 @@ private Map documentToRecord(Document document) { document.getMetadata().forEach((k, v) -> properties.put("metadata." 
+ k, Values.value(v))); row.put("properties", properties); - row.put(this.config.embeddingProperty, Values.value(document.getEmbedding())); + row.put(this.config.embeddingProperty, Values.value(embedding)); return row; } diff --git a/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/OracleVectorStore.java b/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/OracleVectorStore.java index a9e3b63eb79..d2ad15e564f 100644 --- a/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/OracleVectorStore.java +++ b/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/OracleVectorStore.java @@ -204,7 +204,8 @@ public OracleVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingMode @Override public void doAdd(final List documents) { - this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); + List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), + this.batchingStrategy); this.jdbcTemplate.batchUpdate(getIngestStatement(), new BatchPreparedStatementSetter() { @Override @@ -212,7 +213,7 @@ public void setValues(PreparedStatement ps, int i) throws SQLException { final Document document = documents.get(i); final String content = document.getContent(); final byte[] json = toJson(document.getMetadata()); - final VECTOR embeddingVector = toVECTOR(document.getEmbedding()); + final VECTOR embeddingVector = toVECTOR(embeddings.get(documents.indexOf(document))); org.springframework.jdbc.core.StatementCreatorUtils.setParameterValue(ps, 1, Types.VARCHAR, document.getId()); @@ -651,10 +652,7 @@ public Document mapRow(ResultSet rs, int rowNum) throws SQLException { final Map metadata = getMap(rs.getObject(3, OracleJsonValue.class)); metadata.put("distance", rs.getDouble(5)); - final Document document = new Document(rs.getString(1), rs.getString(2), metadata); - 
final float[] embedding = rs.getObject(4, float[].class); - document.setEmbedding(embedding); - return document; + return new Document(rs.getString(1), rs.getString(2), metadata); } private Map getMap(OracleJsonValue value) { diff --git a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java index 69cd79b41c4..d2dcaacf133 100644 --- a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java +++ b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java @@ -195,10 +195,11 @@ public PgDistanceType getDistanceType() { @Override public void doAdd(List documents) { - this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); + List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), + this.batchingStrategy); List> batchedDocuments = batchDocuments(documents); - batchedDocuments.forEach(this::insertOrUpdateBatch); + batchedDocuments.forEach(batchDocument -> insertOrUpdateBatch(batchDocument, documents, embeddings)); } private List> batchDocuments(List documents) { @@ -209,7 +210,7 @@ private List> batchDocuments(List documents) { return batches; } - private void insertOrUpdateBatch(List batch) { + private void insertOrUpdateBatch(List batch, List documents, List embeddings) { String sql = "INSERT INTO " + getFullyQualifiedTableName() + " (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) " + "ON CONFLICT (id) DO " + "UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? 
"; @@ -222,7 +223,7 @@ public void setValues(PreparedStatement ps, int i) throws SQLException { var document = batch.get(i); var content = document.getContent(); var json = toJson(document.getMetadata()); - var embedding = document.getEmbedding(); + var embedding = embeddings.get(documents.indexOf(document)); var pGvector = new PGvector(embedding); StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN, @@ -504,10 +505,7 @@ public Document mapRow(ResultSet rs, int rowNum) throws SQLException { Map metadata = toMap(pgMetadata); metadata.put(COLUMN_DISTANCE, distance); - Document document = new Document(id, content, metadata); - document.setEmbedding(toFloatArray(embedding)); - - return document; + return new Document(id, content, metadata); } private float[] toFloatArray(PGobject embedding) throws SQLException { diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreWithChatMemoryAdvisorIT.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreWithChatMemoryAdvisorIT.java index e2a37cd3761..6325ac71810 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreWithChatMemoryAdvisorIT.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreWithChatMemoryAdvisorIT.java @@ -146,9 +146,6 @@ void advisedChatShouldHaveSimilarMessagesFromVectorStore() throws Exception { EmbeddingModel embeddingModel = mock(EmbeddingModel.class); Mockito.doAnswer(invocationOnMock -> { - Object[] arguments = invocationOnMock.getArguments(); - List documents = (List) arguments[0]; - documents.forEach(d -> d.setEmbedding(this.embed)); return List.of(this.embed, this.embed); }).when(embeddingModel).embed(ArgumentMatchers.any(), any(), any()); given(embeddingModel.embed(any(String.class))).willReturn(this.embed); diff --git 
a/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStore.java b/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStore.java index 1093656370b..46f53ddf6a7 100644 --- a/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStore.java +++ b/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStore.java @@ -124,11 +124,12 @@ public PineconeVectorStore(PineconeVectorStoreConfig config, EmbeddingModel embe * @param namespace The namespace to add the documents to */ public void add(List documents, String namespace) { - this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); + List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), + this.batchingStrategy); List upsertVectors = documents.stream() .map(document -> Vector.newBuilder() .setId(document.getId()) - .addAllValues(EmbeddingUtils.toList(document.getEmbedding())) + .addAllValues(EmbeddingUtils.toList(embeddings.get(documents.indexOf(document)))) .setMetadata(metadataToStruct(document)) .build()) .toList(); diff --git a/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStore.java b/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStore.java index 825c50f7b5b..3541ef72793 100644 --- a/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStore.java +++ b/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStore.java @@ -127,12 +127,13 @@ public void doAdd(List documents) { try { // Compute and assign an embedding to the document. 
- this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); + List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), + this.batchingStrategy); List points = documents.stream() .map(document -> PointStruct.newBuilder() .setId(io.qdrant.client.PointIdFactory.id(UUID.fromString(document.getId()))) - .setVectors(io.qdrant.client.VectorsFactory.vectors(document.getEmbedding())) + .setVectors(io.qdrant.client.VectorsFactory.vectors(embeddings.get(documents.indexOf(document)))) .putAllPayload(toPayload(document)) .build()) .toList(); diff --git a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisVectorStore.java b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisVectorStore.java index df277067162..615643ce415 100644 --- a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisVectorStore.java +++ b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisVectorStore.java @@ -162,12 +162,12 @@ public JedisPooled getJedis() { public void doAdd(List documents) { try (Pipeline pipeline = this.jedis.pipelined()) { - this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); + List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), + this.batchingStrategy); for (Document document : documents) { - document.setEmbedding(document.getEmbedding()); var fields = new HashMap(); - fields.put(this.config.embeddingFieldName, document.getEmbedding()); + fields.put(this.config.embeddingFieldName, embeddings.get(documents.indexOf(document))); fields.put(this.config.contentFieldName, document.getContent()); fields.putAll(document.getMetadata()); pipeline.jsonSetWithEscape(key(document.getId()), JSON_SET_PATH, fields); diff --git 
a/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseVectorStore.java b/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseVectorStore.java index b0fc72f1360..c6a8a1466d9 100644 --- a/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseVectorStore.java +++ b/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseVectorStore.java @@ -123,14 +123,15 @@ public TypesenseVectorStore(Client client, EmbeddingModel embeddingModel, Typese public void doAdd(List documents) { Assert.notNull(documents, "Documents must not be null"); - this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); + List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), + this.batchingStrategy); List> documentList = documents.stream().map(document -> { HashMap typesenseDoc = new HashMap<>(); typesenseDoc.put(DOC_ID_FIELD_NAME, document.getId()); typesenseDoc.put(CONTENT_FIELD_NAME, document.getContent()); typesenseDoc.put(METADATA_FIELD_NAME, document.getMetadata()); - typesenseDoc.put(EMBEDDING_FIELD_NAME, document.getEmbedding()); + typesenseDoc.put(EMBEDDING_FIELD_NAME, embeddings.get(documents.indexOf(document))); return typesenseDoc; }).toList(); diff --git a/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateVectorStore.java b/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateVectorStore.java index 9ce01c9f059..4c88a91a992 100644 --- a/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateVectorStore.java +++ b/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateVectorStore.java @@ -197,9 +197,12 @@ public void doAdd(List documents) { return; } - 
this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); + List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), + this.batchingStrategy); - List weaviateObjects = documents.stream().map(this::toWeaviateObject).toList(); + List weaviateObjects = documents.stream() + .map(document -> toWeaviateObject(document, documents, embeddings)) + .toList(); Result response = this.weaviateClient.batch() .objectsBatcher() @@ -235,7 +238,7 @@ public void doAdd(List documents) { } } - private WeaviateObject toWeaviateObject(Document document) { + private WeaviateObject toWeaviateObject(Document document, List documents, List embeddings) { // https://weaviate.io/developers/weaviate/config-refs/datatypes Map fields = new HashMap<>(); @@ -259,7 +262,7 @@ private WeaviateObject toWeaviateObject(Document document) { return WeaviateObject.builder() .className(this.weaviateObjectClass) .id(document.getId()) - .vector(EmbeddingUtils.toFloatArray(document.getEmbedding())) + .vector(EmbeddingUtils.toFloatArray(embeddings.get(documents.indexOf(document)))) .properties(fields) .build(); } @@ -382,10 +385,7 @@ private Document toDocument(Map item) { // Content String content = (String) item.get(CONTENT_FIELD_NAME); - var document = new Document(id, content, metadata); - document.setEmbedding(EmbeddingUtils.toPrimitive(EmbeddingUtils.doubleToFloat(embedding))); - - return document; + return new Document(id, content, metadata); } @Override