Skip to content

Commit 7181900

Browse files
committed
GH-1826 Fix EmbeddingModel's usage on Document#embedding
- Since the Document object's reference to the `embedding` is deprecated and will be removed, the VectorStore implementations require a way to store the embedding of the corresponding Document objects - One way to fix this is, to have the EmbeddingModel#embed to return the embeddings in the same order as that of the Documents passed to it. - Since both the Document and embedding collections use the List object, their iteration operation will make sure to keep them in line with the same order. - A fix is required to preserve the order when batching strategy is applied. - Updated the Javadoc for BatchingStrategy - Fixed the Document List order in TokenCountBatchingStrategy - Refactored the vector store implementations to update this change Resolves #GH-1826
1 parent 23a3d13 commit 7181900

File tree

24 files changed

+115
-100
lines changed

24 files changed

+115
-100
lines changed

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,12 @@ void defaultEmbedding() {
6666
@Test
6767
void embeddingBatchDocuments() throws Exception {
6868
assertThat(this.embeddingModel).isNotNull();
69-
List<float[]> embedded = this.embeddingModel.embed(
69+
List<float[]> embeddings = this.embeddingModel.embed(
7070
List.of(new Document("Hello world"), new Document("Hello Spring"), new Document("Hello Spring AI!")),
7171
OpenAiEmbeddingOptions.builder().withModel(OpenAiApi.DEFAULT_EMBEDDING_MODEL).build(),
7272
new TokenCountBatchingStrategy());
73-
assertThat(embedded.size()).isEqualTo(3);
74-
embedded.forEach(embedding -> assertThat(embedding.length).isEqualTo(this.embeddingModel.dimensions()));
73+
assertThat(embeddings.size()).isEqualTo(3);
74+
embeddings.forEach(embedding -> assertThat(embedding.length).isEqualTo(this.embeddingModel.dimensions()));
7575
}
7676

7777
@Test

spring-ai-core/src/main/java/org/springframework/ai/embedding/BatchingStrategy.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@ public interface BatchingStrategy {
3131

3232
/**
3333
* {@link EmbeddingModel} implementations can call this method to optimize embedding
34-
* tokens. The incoming collection of {@link Document}s are split into su-batches.
34+
* tokens. The incoming collection of {@link Document}s are split into sub-batches. It
35+
* is important to preserve the order of the list of {@link Document}s when batching
36+
* as they are mapped to their corresponding embeddings by their order.
3537
* @param documents to batch
3638
* @return a list of sub-batches that contain {@link Document}s.
3739
*/

spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingModel.java

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,25 +78,23 @@ default List<float[]> embed(List<String> texts) {
7878
* @param options {@link EmbeddingOptions}.
7979
* @param batchingStrategy {@link BatchingStrategy}.
8080
* @return a list of float[] that represents the vectors for the incoming
81-
* {@link Document}s.
81+
* {@link Document}s. The returned list is expected to be in the same order of the
82+
* {@link Document} list.
8283
*/
8384
default List<float[]> embed(List<Document> documents, EmbeddingOptions options, BatchingStrategy batchingStrategy) {
8485
Assert.notNull(documents, "Documents must not be null");
85-
List<float[]> embeddings = new ArrayList<>();
86-
86+
List<float[]> embeddings = new ArrayList<>(documents.size());
8787
List<List<Document>> batch = batchingStrategy.batch(documents);
88-
8988
for (List<Document> subBatch : batch) {
9089
List<String> texts = subBatch.stream().map(Document::getContent).toList();
9190
EmbeddingRequest request = new EmbeddingRequest(texts, options);
9291
EmbeddingResponse response = this.call(request);
9392
for (int i = 0; i < subBatch.size(); i++) {
94-
Document document = subBatch.get(i);
95-
float[] output = response.getResults().get(i).getOutput();
96-
embeddings.add(output);
97-
document.setEmbedding(output);
93+
embeddings.add(response.getResults().get(i).getOutput());
9894
}
9995
}
96+
Assert.isTrue(embeddings.size() == documents.size(),
97+
"Embeddings must have the same number as that of the documents");
10098
return embeddings;
10199
}
102100

spring-ai-core/src/main/java/org/springframework/ai/embedding/TokenCountBatchingStrategy.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
package org.springframework.ai.embedding;
1818

1919
import java.util.ArrayList;
20-
import java.util.HashMap;
20+
import java.util.LinkedHashMap;
2121
import java.util.List;
2222
import java.util.Map;
2323

@@ -139,7 +139,9 @@ public List<List<Document>> batch(List<Document> documents) {
139139
List<List<Document>> batches = new ArrayList<>();
140140
int currentSize = 0;
141141
List<Document> currentBatch = new ArrayList<>();
142-
Map<Document, Integer> documentTokens = new HashMap<>();
142+
// Make sure the documentTokens' entry order is preserved by making it a
143+
// LinkedHashMap.
144+
Map<Document, Integer> documentTokens = new LinkedHashMap<>();
143145

144146
for (Document document : documents) {
145147
int tokenCount = this.tokenCountEstimator

vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBVectorStore.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -203,13 +203,14 @@ private JsonNode mapCosmosDocument(Document document, float[] queryEmbedding) {
203203
public void doAdd(List<Document> documents) {
204204

205205
// Batch the documents based on the batching strategy
206-
this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
206+
List<float[]> embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(),
207+
this.batchingStrategy);
207208

208209
// Create a list to hold both the CosmosItemOperation and the corresponding
209210
// document ID
210211
List<ImmutablePair<String, CosmosItemOperation>> itemOperationsWithIds = documents.stream().map(doc -> {
211-
CosmosItemOperation operation = CosmosBulkOperations
212-
.getCreateItemOperation(mapCosmosDocument(doc, doc.getEmbedding()), new PartitionKey(doc.getId()));
212+
CosmosItemOperation operation = CosmosBulkOperations.getCreateItemOperation(
213+
mapCosmosDocument(doc, embeddings.get(documents.indexOf(doc))), new PartitionKey(doc.getId()));
213214
return new ImmutablePair<>(doc.getId(), operation); // Pair the document ID
214215
// with the operation
215216
}).toList();

vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -224,12 +224,13 @@ public void doAdd(List<Document> documents) {
224224
return; // nothing to do;
225225
}
226226

227-
this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
227+
List<float[]> embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(),
228+
this.batchingStrategy);
228229

229230
final var searchDocuments = documents.stream().map(document -> {
230231
SearchDocument searchDocument = new SearchDocument();
231232
searchDocument.put(ID_FIELD_NAME, document.getId());
232-
searchDocument.put(EMBEDDING_FIELD_NAME, document.getEmbedding());
233+
searchDocument.put(EMBEDDING_FIELD_NAME, embeddings.get(documents.indexOf(document)));
233234
searchDocument.put(CONTENT_FIELD_NAME, document.getContent());
234235
searchDocument.put(METADATA_FIELD_NAME, new JSONObject(document.getMetadata()).toJSONString());
235236

@@ -324,7 +325,6 @@ public List<Document> doSimilaritySearch(SearchRequest request) {
324325
metadata.put(DISTANCE_METADATA_FIELD_NAME, 1 - (float) result.getScore());
325326

326327
final Document doc = new Document(entry.id(), entry.content(), metadata);
327-
doc.setEmbedding(EmbeddingUtils.toPrimitive(entry.embedding()));
328328

329329
return doc;
330330

vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,8 @@ private static Float[] toFloatArray(float[] embedding) {
182182
public void doAdd(List<Document> documents) {
183183
var futures = new CompletableFuture[documents.size()];
184184

185-
this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
185+
List<float[]> embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(),
186+
this.batchingStrategy);
186187

187188
int i = 0;
188189
for (Document d : documents) {
@@ -197,7 +198,8 @@ public void doAdd(List<Document> documents) {
197198

198199
builder = builder.setString(this.conf.schema.content(), d.getContent())
199200
.setVector(this.conf.schema.embedding(),
200-
CqlVector.newInstance(EmbeddingUtils.toList(d.getEmbedding())), Float.class);
201+
CqlVector.newInstance(EmbeddingUtils.toList(embeddings.get(documents.indexOf(d)))),
202+
Float.class);
201203

202204
for (var metadataColumn : this.conf.schema.metadataColumns()
203205
.stream()
@@ -260,11 +262,6 @@ public List<Document> doSimilaritySearch(SearchRequest request) {
260262
}
261263
}
262264
Document doc = new Document(getDocumentId(row), row.getString(this.conf.schema.content()), docFields);
263-
264-
if (this.conf.returnEmbeddings) {
265-
doc.setEmbedding(EmbeddingUtils
266-
.toPrimitive(row.getVector(this.conf.schema.embedding(), Float.class).stream().toList()));
267-
}
268265
documents.add(doc);
269266
}
270267
return documents;

vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStoreConfig.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ public final class CassandraVectorStoreConfig implements AutoCloseable {
9090

9191
final boolean disallowSchemaChanges;
9292

93+
// TODO: Remove this flag as the document no longer holds embeddings.
94+
@Deprecated(since = "1.0.0-M5", forRemoval = true)
9395
final boolean returnEmbeddings;
9496

9597
final DocumentIdTranslator documentIdTranslator;

vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreIT.java

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -121,18 +121,12 @@ void addAndSearch() {
121121

122122
List<Document> documents = documents();
123123
store.add(documents);
124-
for (Document d : documents) {
125-
assertThat(d.getEmbedding()).satisfiesAnyOf(e -> assertThat(e).isNotNull(),
126-
e -> assertThat(e).isNotEmpty());
127-
}
128124

129125
List<Document> results = store.similaritySearch(SearchRequest.query("Spring").withTopK(1));
130126

131127
assertThat(results).hasSize(1);
132128
Document resultDoc = results.get(0);
133129
assertThat(resultDoc.getId()).isEqualTo(documents().get(0).getId());
134-
assertThat(resultDoc.getEmbedding()).satisfiesAnyOf(e -> assertThat(e).isNull(),
135-
e -> assertThat(e).isEmpty());
136130

137131
assertThat(resultDoc.getContent()).contains(
138132
"Spring AI provides abstractions that serve as the foundation for developing AI applications.");
@@ -158,17 +152,12 @@ void addAndSearchReturnEmbeddings() {
158152
try (CassandraVectorStore store = createTestStore(context, builder)) {
159153
List<Document> documents = documents();
160154
store.add(documents);
161-
for (Document d : documents) {
162-
assertThat(d.getEmbedding()).satisfiesAnyOf(e -> assertThat(e).isNotNull(),
163-
e -> assertThat(e).isNotEmpty());
164-
}
165155

166156
List<Document> results = store.similaritySearch(SearchRequest.query("Spring").withTopK(1));
167157

168158
assertThat(results).hasSize(1);
169159
Document resultDoc = results.get(0);
170160
assertThat(resultDoc.getId()).isEqualTo(documents().get(0).getId());
171-
assertThat(resultDoc.getEmbedding()).isNotEmpty();
172161

173162
assertThat(resultDoc.getContent()).contains(
174163
"Spring AI provides abstractions that serve as the foundation for developing AI applications.");

vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -145,14 +145,14 @@ public void doAdd(@NonNull List<Document> documents) {
145145
List<String> contents = new ArrayList<>();
146146
List<float[]> embeddings = new ArrayList<>();
147147

148-
this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
148+
List<float[]> documentEmbeddings = this.embeddingModel.embed(documents,
149+
EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
149150

150151
for (Document document : documents) {
151152
ids.add(document.getId());
152153
metadatas.add(document.getMetadata());
153154
contents.add(document.getContent());
154-
document.setEmbedding(document.getEmbedding());
155-
embeddings.add(document.getEmbedding());
155+
embeddings.add(documentEmbeddings.get(documents.indexOf(document)));
156156
}
157157

158158
this.chromaApi.upsertEmbeddings(this.collectionId,
@@ -193,9 +193,7 @@ public Optional<Boolean> doDelete(@NonNull List<String> idList) {
193193
metadata = new HashMap<>();
194194
}
195195
metadata.put(DISTANCE_FIELD_NAME, distance);
196-
Document document = new Document(id, content, metadata);
197-
document.setEmbedding(chromaEmbedding.embedding());
198-
responseDocuments.add(document);
196+
responseDocuments.add(new Document(id, content, metadata));
199197
}
200198
}
201199

0 commit comments

Comments
 (0)