Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,12 @@ void defaultEmbedding() {
@Test
void embeddingBatchDocuments() throws Exception {
assertThat(this.embeddingModel).isNotNull();
List<float[]> embedded = this.embeddingModel.embed(
List<float[]> embeddings = this.embeddingModel.embed(
List.of(new Document("Hello world"), new Document("Hello Spring"), new Document("Hello Spring AI!")),
OpenAiEmbeddingOptions.builder().withModel(OpenAiApi.DEFAULT_EMBEDDING_MODEL).build(),
new TokenCountBatchingStrategy());
assertThat(embedded.size()).isEqualTo(3);
embedded.forEach(embedding -> assertThat(embedding.length).isEqualTo(this.embeddingModel.dimensions()));
assertThat(embeddings.size()).isEqualTo(3);
embeddings.forEach(embedding -> assertThat(embedding.length).isEqualTo(this.embeddingModel.dimensions()));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ public interface BatchingStrategy {

/**
* {@link EmbeddingModel} implementations can call this method to optimize embedding
* tokens. The incoming collection of {@link Document}s are split into su-batches.
* tokens. The incoming collection of {@link Document}s is split into sub-batches. It
* is important to preserve the order of the list of {@link Document}s when batching
* because each {@link Document} is matched to its embedding by position.
* @param documents to batch
* @return a list of sub-batches that contain {@link Document}s.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,25 +78,23 @@ default List<float[]> embed(List<String> texts) {
* @param options {@link EmbeddingOptions}.
* @param batchingStrategy {@link BatchingStrategy}.
* @return a list of float[] that represents the vectors for the incoming
* {@link Document}s.
* {@link Document}s. The returned list is expected to be in the same order as the
* {@link Document} list.
*/
default List<float[]> embed(List<Document> documents, EmbeddingOptions options, BatchingStrategy batchingStrategy) {
Assert.notNull(documents, "Documents must not be null");
List<float[]> embeddings = new ArrayList<>();

List<float[]> embeddings = new ArrayList<>(documents.size());
List<List<Document>> batch = batchingStrategy.batch(documents);

for (List<Document> subBatch : batch) {
List<String> texts = subBatch.stream().map(Document::getContent).toList();
EmbeddingRequest request = new EmbeddingRequest(texts, options);
EmbeddingResponse response = this.call(request);
for (int i = 0; i < subBatch.size(); i++) {
Document document = subBatch.get(i);
float[] output = response.getResults().get(i).getOutput();
embeddings.add(output);
document.setEmbedding(output);
embeddings.add(response.getResults().get(i).getOutput());
}
}
Assert.isTrue(embeddings.size() == documents.size(),
"Embeddings must have the same number as that of the documents");
return embeddings;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package org.springframework.ai.embedding;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -139,7 +139,9 @@ public List<List<Document>> batch(List<Document> documents) {
List<List<Document>> batches = new ArrayList<>();
int currentSize = 0;
List<Document> currentBatch = new ArrayList<>();
Map<Document, Integer> documentTokens = new HashMap<>();
// Make sure the documentTokens' entry order is preserved by making it a
// LinkedHashMap.
Map<Document, Integer> documentTokens = new LinkedHashMap<>();

for (Document document : documents) {
int tokenCount = this.tokenCountEstimator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,13 +203,14 @@ private JsonNode mapCosmosDocument(Document document, float[] queryEmbedding) {
public void doAdd(List<Document> documents) {

// Batch the documents based on the batching strategy
this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
List<float[]> embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(),
this.batchingStrategy);

// Create a list to hold both the CosmosItemOperation and the corresponding
// document ID
List<ImmutablePair<String, CosmosItemOperation>> itemOperationsWithIds = documents.stream().map(doc -> {
CosmosItemOperation operation = CosmosBulkOperations
.getCreateItemOperation(mapCosmosDocument(doc, doc.getEmbedding()), new PartitionKey(doc.getId()));
CosmosItemOperation operation = CosmosBulkOperations.getCreateItemOperation(
mapCosmosDocument(doc, embeddings.get(documents.indexOf(doc))), new PartitionKey(doc.getId()));
return new ImmutablePair<>(doc.getId(), operation); // Pair the document ID
// with the operation
}).toList();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,12 +224,13 @@ public void doAdd(List<Document> documents) {
return; // nothing to do;
}

this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
List<float[]> embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(),
this.batchingStrategy);

final var searchDocuments = documents.stream().map(document -> {
SearchDocument searchDocument = new SearchDocument();
searchDocument.put(ID_FIELD_NAME, document.getId());
searchDocument.put(EMBEDDING_FIELD_NAME, document.getEmbedding());
searchDocument.put(EMBEDDING_FIELD_NAME, embeddings.get(documents.indexOf(document)));
searchDocument.put(CONTENT_FIELD_NAME, document.getContent());
searchDocument.put(METADATA_FIELD_NAME, new JSONObject(document.getMetadata()).toJSONString());

Expand Down Expand Up @@ -324,7 +325,6 @@ public List<Document> doSimilaritySearch(SearchRequest request) {
metadata.put(DISTANCE_METADATA_FIELD_NAME, 1 - (float) result.getScore());

final Document doc = new Document(entry.id(), entry.content(), metadata);
doc.setEmbedding(EmbeddingUtils.toPrimitive(entry.embedding()));

return doc;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ private static Float[] toFloatArray(float[] embedding) {
public void doAdd(List<Document> documents) {
var futures = new CompletableFuture[documents.size()];

this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
List<float[]> embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(),
this.batchingStrategy);

int i = 0;
for (Document d : documents) {
Expand All @@ -197,7 +198,8 @@ public void doAdd(List<Document> documents) {

builder = builder.setString(this.conf.schema.content(), d.getContent())
.setVector(this.conf.schema.embedding(),
CqlVector.newInstance(EmbeddingUtils.toList(d.getEmbedding())), Float.class);
CqlVector.newInstance(EmbeddingUtils.toList(embeddings.get(documents.indexOf(d)))),
Float.class);

for (var metadataColumn : this.conf.schema.metadataColumns()
.stream()
Expand Down Expand Up @@ -260,11 +262,6 @@ public List<Document> doSimilaritySearch(SearchRequest request) {
}
}
Document doc = new Document(getDocumentId(row), row.getString(this.conf.schema.content()), docFields);

if (this.conf.returnEmbeddings) {
doc.setEmbedding(EmbeddingUtils
.toPrimitive(row.getVector(this.conf.schema.embedding(), Float.class).stream().toList()));
}
documents.add(doc);
}
return documents;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ public final class CassandraVectorStoreConfig implements AutoCloseable {

final boolean disallowSchemaChanges;

// TODO: Remove this flag as the document no longer holds embeddings.
@Deprecated(since = "1.0.0-M5", forRemoval = true)
final boolean returnEmbeddings;

final DocumentIdTranslator documentIdTranslator;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,18 +121,12 @@ void addAndSearch() {

List<Document> documents = documents();
store.add(documents);
for (Document d : documents) {
assertThat(d.getEmbedding()).satisfiesAnyOf(e -> assertThat(e).isNotNull(),
e -> assertThat(e).isNotEmpty());
}

List<Document> results = store.similaritySearch(SearchRequest.query("Spring").withTopK(1));

assertThat(results).hasSize(1);
Document resultDoc = results.get(0);
assertThat(resultDoc.getId()).isEqualTo(documents().get(0).getId());
assertThat(resultDoc.getEmbedding()).satisfiesAnyOf(e -> assertThat(e).isNull(),
e -> assertThat(e).isEmpty());

assertThat(resultDoc.getContent()).contains(
"Spring AI provides abstractions that serve as the foundation for developing AI applications.");
Expand All @@ -158,17 +152,12 @@ void addAndSearchReturnEmbeddings() {
try (CassandraVectorStore store = createTestStore(context, builder)) {
List<Document> documents = documents();
store.add(documents);
for (Document d : documents) {
assertThat(d.getEmbedding()).satisfiesAnyOf(e -> assertThat(e).isNotNull(),
e -> assertThat(e).isNotEmpty());
}

List<Document> results = store.similaritySearch(SearchRequest.query("Spring").withTopK(1));

assertThat(results).hasSize(1);
Document resultDoc = results.get(0);
assertThat(resultDoc.getId()).isEqualTo(documents().get(0).getId());
assertThat(resultDoc.getEmbedding()).isNotEmpty();

assertThat(resultDoc.getContent()).contains(
"Spring AI provides abstractions that serve as the foundation for developing AI applications.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,14 +145,14 @@ public void doAdd(@NonNull List<Document> documents) {
List<String> contents = new ArrayList<>();
List<float[]> embeddings = new ArrayList<>();

this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
List<float[]> documentEmbeddings = this.embeddingModel.embed(documents,
EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);

for (Document document : documents) {
ids.add(document.getId());
metadatas.add(document.getMetadata());
contents.add(document.getContent());
document.setEmbedding(document.getEmbedding());
embeddings.add(document.getEmbedding());
embeddings.add(documentEmbeddings.get(documents.indexOf(document)));
}

this.chromaApi.upsertEmbeddings(this.collectionId,
Expand Down Expand Up @@ -193,9 +193,7 @@ public Optional<Boolean> doDelete(@NonNull List<String> idList) {
metadata = new HashMap<>();
}
metadata.put(DISTANCE_FIELD_NAME, distance);
Document document = new Document(id, content, metadata);
document.setEmbedding(chromaEmbedding.embedding());
responseDocuments.add(document);
responseDocuments.add(new Document(id, content, metadata));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,9 @@ public CoherenceVectorStore setForcedNormalization(boolean forcedNormalization)
public void add(final List<Document> documents) {
Map<DocumentChunk.Id, DocumentChunk> chunks = new HashMap<>((int) Math.ceil(documents.size() / 0.75f));
for (Document doc : documents) {
doc.setEmbedding(this.embeddingModel.embed(doc));
var id = toChunkId(doc.getId());
var chunk = new DocumentChunk(doc.getContent(), doc.getMetadata(), toFloat32Vector(doc.getEmbedding()));
var chunk = new DocumentChunk(doc.getContent(), doc.getMetadata(),
toFloat32Vector(this.embeddingModel.embed(doc)));
chunks.put(id, chunk);
}
this.documentChunks.putAll(chunks);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
* @author Soby Chacko
* @author Christian Tzolov
* @author Thomas Vitale
* @author Ilayaperumal Gopinathan
* @since 1.0.0
*/
public class ElasticsearchVectorStore extends AbstractObservationVectorStore implements InitializingBean {
Expand Down Expand Up @@ -132,11 +133,14 @@ public void doAdd(List<Document> documents) {
}
BulkRequest.Builder bulkRequestBuilder = new BulkRequest.Builder();

this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
List<float[]> embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(),
this.batchingStrategy);

for (Document document : documents) {
bulkRequestBuilder.operations(op -> op
.index(idx -> idx.index(this.options.getIndexName()).id(document.getId()).document(document)));
ElasticSearchDocument doc = new ElasticSearchDocument(document.getId(), document.getContent(),
document.getMetadata(), embeddings.get(documents.indexOf(document)));
bulkRequestBuilder.operations(
op -> op.index(idx -> idx.index(this.options.getIndexName()).id(document.getId()).document(doc)));
}
BulkResponse bulkRequest = bulkRequest(bulkRequestBuilder.build());
if (bulkRequest.errors()) {
Expand Down Expand Up @@ -277,4 +281,15 @@ private String getSimilarityMetric() {
return SIMILARITY_TYPE_MAPPING.get(this.options.getSimilarity()).value();
}

/**
* The representation of a {@link Document} along with its embedding. Instances are
* created from a {@link Document} and its computed embedding, and are used as the
* document payload of the bulk index operation instead of the {@link Document} itself.
*
* @param id The id of the document
* @param content The content of the document
* @param metadata The metadata of the document
* @param embedding The embedding vector representing the content of the document; the
* array reference is stored as-is, without a defensive copy
*/
public record ElasticSearchDocument(String id, String content, Map<String, Object> metadata, float[] embedding) {
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -209,10 +209,11 @@ public String getIndex() {

@Override
public void doAdd(List<Document> documents) {
this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
List<float[]> embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(),
this.batchingStrategy);
UploadRequest upload = new UploadRequest(documents.stream()
.map(document -> new UploadRequest.Embedding(document.getId(), document.getEmbedding(), DOCUMENT_FIELD,
document.getContent(), document.getMetadata()))
.map(document -> new UploadRequest.Embedding(document.getId(), embeddings.get(documents.indexOf(document)),
DOCUMENT_FIELD, document.getContent(), document.getMetadata()))
.toList());

String embeddingsJson = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,15 +160,16 @@ public void doAdd(List<Document> documents) {
List<List<Float>> embeddingArray = new ArrayList<>();

// TODO: Need to customize how we pass the embedding options
this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
List<float[]> embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(),
this.batchingStrategy);

for (Document document : documents) {
docIdArray.add(document.getId());
// Use a (future) DocumentTextLayoutFormatter instance to extract
// the content used to compute the embeddings
contentArray.add(document.getContent());
metadataArray.add(new JSONObject(document.getMetadata()));
embeddingArray.add(EmbeddingUtils.toList(document.getEmbedding()));
embeddingArray.add(EmbeddingUtils.toList(embeddings.get(documents.indexOf(document))));
}

List<InsertParam.Field> fields = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
* @author Soby Chacko
* @author Christian Tzolov
* @author Thomas Vitale
* @author Ilayaperumal Gopinathan
* @since 1.0.0
*/
public class MongoDBAtlasVectorStore extends AbstractObservationVectorStore implements InitializingBean {
Expand Down Expand Up @@ -173,18 +174,17 @@ private Document mapMongoDocument(org.bson.Document mongoDocument, float[] query
String id = mongoDocument.getString(ID_FIELD_NAME);
String content = mongoDocument.getString(CONTENT_FIELD_NAME);
Map<String, Object> metadata = mongoDocument.get(METADATA_FIELD_NAME, org.bson.Document.class);

Document document = new Document(id, content, metadata);
document.setEmbedding(queryEmbedding);

return document;
return new Document(id, content, metadata);
}

@Override
public void doAdd(List<Document> documents) {
this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
List<float[]> embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(),
this.batchingStrategy);
for (Document document : documents) {
this.mongoTemplate.save(document, this.config.collectionName);
MongoDBDocument mdbDocument = new MongoDBDocument(document.getId(), document.getContent(),
document.getMetadata(), embeddings.get(documents.indexOf(document)));
this.mongoTemplate.save(mdbDocument, this.config.collectionName);
}
}

Expand Down Expand Up @@ -332,4 +332,15 @@ public MongoDBVectorStoreConfig build() {

}

/**
* The representation of a {@link Document} along with its embedding. Instances are
* created from a {@link Document} and its computed embedding, and are what gets saved
* to the configured collection instead of the {@link Document} itself.
*
* @param id The id of the document
* @param content The content of the document
* @param metadata The metadata of the document
* @param embedding The embedding vector representing the content of the document; the
* array reference is stored as-is, without a defensive copy
*/
public record MongoDBDocument(String id, String content, Map<String, Object> metadata, float[] embedding) {
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,12 @@ public Neo4jVectorStore(Driver driver, EmbeddingModel embeddingModel, Neo4jVecto
@Override
public void doAdd(List<Document> documents) {

this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
List<float[]> embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(),
this.batchingStrategy);

var rows = documents.stream().map(this::documentToRecord).toList();
var rows = documents.stream()
.map(document -> documentToRecord(document, embeddings.get(documents.indexOf(document))))
.toList();

try (var session = this.driver.session()) {
var statement = """
Expand Down Expand Up @@ -203,8 +206,7 @@ public void afterPropertiesSet() {
}
}

private Map<String, Object> documentToRecord(Document document) {
document.setEmbedding(document.getEmbedding());
private Map<String, Object> documentToRecord(Document document, float[] embedding) {

var row = new HashMap<String, Object>();

Expand All @@ -216,7 +218,7 @@ private Map<String, Object> documentToRecord(Document document) {
document.getMetadata().forEach((k, v) -> properties.put("metadata." + k, Values.value(v)));
row.put("properties", properties);

row.put(this.config.embeddingProperty, Values.value(document.getEmbedding()));
row.put(this.config.embeddingProperty, Values.value(embedding));
return row;
}

Expand Down
Loading