Skip to content

Commit 83d7f9a

Browse files
murathany7 and vgvoleg authored
Add support for batch document insertion to YDB (#7)
* Add batch document insertion functionality
* Update test_vectorestore.py
* Default size changed and using more efficient method
* Update test_vectorestore.py
* Update vectorstores.py

---------

Co-authored-by: Oleg Ovcharuk <[email protected]>
1 parent cad6536 commit 83d7f9a

File tree

3 files changed

+152
-23
lines changed

3 files changed

+152
-23
lines changed

docker/docker-compose.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
version: "3.3"
22
services:
33
ydb:
4-
image: ydbplatform/local-ydb:trunk
4+
image: ydbplatform/local-ydb:24.3.13.12
55
restart: always
66
ports:
77
- "2136:2136"

langchain_ydb/vectorstores.py

Lines changed: 58 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -218,10 +218,11 @@ def _escape_str(self, text: str) -> str:
218218

219219
def _prepare_insert_query(self) -> str:
220220
return f"""
221-
DECLARE $id AS Utf8;
222-
DECLARE $document as Utf8;
223-
DECLARE $embedding as List<Float>;
224-
DECLARE $metadata as Json;
221+
DECLARE $documents AS List<Struct<
222+
id: Utf8,
223+
document: Utf8,
224+
embedding: List<Float>,
225+
metadata: Json>>;
225226
226227
UPSERT INTO `{self.config.table}`
227228
(
@@ -230,13 +231,12 @@ def _prepare_insert_query(self) -> str:
230231
{self.config.column_map["embedding"]},
231232
{self.config.column_map["metadata"]}
232233
)
233-
VALUES
234-
(
235-
$id,
236-
$document,
237-
Untag(Knn::ToBinaryStringFloat($embedding), "FloatVector"),
238-
$metadata
239-
);
234+
SELECT
235+
id,
236+
document,
237+
Untag(Knn::ToBinaryStringFloat(embedding), "FloatVector"),
238+
metadata
239+
FROM AS_TABLE($documents);
240240
"""
241241

242242
def _prepare_search_query(
@@ -285,6 +285,7 @@ def add_texts(
285285
metadatas: Optional[list[dict]] = None,
286286
*,
287287
ids: Optional[list[str]] = None,
288+
batch_size: int = 32,
288289
**kwargs: Any,
289290
) -> list[str]:
290291
"""Run more texts through the embeddings and add to the vectorstore.
@@ -293,6 +294,7 @@ def add_texts(
293294
texts: Iterable of strings to add to the vectorstore.
294295
metadatas: Optional list of metadatas associated with the texts.
295296
ids: Optional list of IDs associated with the texts.
297+
batch_size: Number of texts to process in a single batch. Defaults to 32.
296298
**kwargs: vectorstore specific parameters.
297299
One of the kwargs should be `ids` which is a list of ids
298300
associated with the texts.
@@ -315,20 +317,52 @@ def add_texts(
315317
metadatas = metadatas if metadatas else [{} for _ in range(len(texts_))]
316318

317319
ydb = self._ydb_lib
318-
319-
for id, text, metadata in self.pgbar(
320-
zip(ids, texts, metadatas),
321-
desc="Inserting data...",
322-
total=len(ids),
320+
321+
# Define struct type with proper member fields
322+
document_struct_type = ydb.StructType()
323+
document_struct_type.add_member('id', ydb.PrimitiveType.Utf8)
324+
document_struct_type.add_member('document', ydb.PrimitiveType.Utf8)
325+
document_struct_type.add_member(
326+
'embedding',
327+
ydb.ListType(ydb.PrimitiveType.Float)
328+
)
329+
document_struct_type.add_member('metadata', ydb.PrimitiveType.Json)
330+
331+
# Process in batches
332+
batch_ranges = range(0, len(texts_), batch_size)
333+
for i in self.pgbar(
334+
batch_ranges,
335+
desc="Processing batches...",
336+
total=len(batch_ranges)
323337
):
324-
embedding = self.embedding_function.embed_query(text)
338+
batch_texts = texts_[i:i+batch_size]
339+
batch_ids = ids[i:i+batch_size]
340+
batch_metadatas = metadatas[i:i+batch_size]
341+
342+
# Generate embeddings for the batch
343+
embeddings = self.embedding_function.embed_documents(
344+
batch_texts, # type: ignore
345+
)
346+
347+
# Create a list of document structs
348+
documents = []
349+
for doc_id, doc_text, doc_embedding, doc_metadata in zip(
350+
batch_ids, batch_texts, embeddings, batch_metadatas
351+
):
352+
# Use dictionary format for struct values - YDB will convert them
353+
document = {
354+
'id': doc_id,
355+
'document': doc_text,
356+
'embedding': doc_embedding,
357+
'metadata': json.dumps(doc_metadata)
358+
}
359+
documents.append(document)
360+
361+
# Execute the batch insert
325362
self._execute_query(
326363
self._insert_query,
327364
{
328-
"$id": id,
329-
"$document": text,
330-
"$embedding": (embedding, ydb.ListType(ydb.PrimitiveType.Float)),
331-
"$metadata": (json.dumps(metadata), ydb.PrimitiveType.Json),
365+
"$documents": (documents, ydb.ListType(document_struct_type))
332366
},
333367
)
334368

@@ -343,6 +377,7 @@ def from_texts(
343377
*,
344378
config: Optional[YDBSettings] = None,
345379
ids: Optional[list[str]] = None,
380+
batch_size: int = 32,
346381
**kwargs: Any,
347382
) -> YDB:
348383
"""Return YDB VectorStore initialized from texts and embeddings.
@@ -353,13 +388,14 @@ def from_texts(
353388
metadatas: Optional list of metadatas associated with the texts.
354389
Default is None.
355390
ids: Optional list of IDs associated with the texts.
391+
batch_size: Number of texts to process in a single batch. Defaults to 32.
356392
kwargs: Additional keyword arguments.
357393
358394
Returns:
359395
VectorStore: VectorStore initialized from texts and embeddings.
360396
"""
361397
vs = cls(embedding, config, **kwargs)
362-
vs.add_texts(texts=texts, metadatas=metadatas, ids=ids)
398+
vs.add_texts(texts=texts, metadatas=metadatas, ids=ids, batch_size=batch_size)
363399
return vs
364400

365401
def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> Optional[bool]:

tests/test_vectorestore.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,3 +328,96 @@ def test_search_from_retriever_interface_with_filter() -> None:
328328
assert output == [Document(page_content="bar", metadata={"page": "1"})]
329329

330330
docsearch.drop()
331+
332+
333+
@pytest.mark.parametrize("n", [10, 50, 100])
334+
def test_batch_insertion(n: int) -> None:
335+
"""Test batch insertion with different document counts."""
336+
# Create documents
337+
texts = [f"text_{i}" for i in range(n)]
338+
metadatas = [{"index": str(i)} for i in range(n)]
339+
340+
# Create vectorstore
341+
config = YDBSettings(drop_existing_table=True)
342+
config.table = f"test_ydb_batch_{n}"
343+
docsearch = YDB.from_texts(
344+
texts=texts,
345+
embedding=ConsistentFakeEmbeddings(),
346+
config=config,
347+
metadatas=metadatas,
348+
)
349+
350+
# Verify total count matches expected
351+
all_results = docsearch.similarity_search("text", k=n + 1)
352+
assert len(all_results) == n
353+
354+
# Clean up
355+
docsearch.drop()
356+
357+
@pytest.mark.parametrize("n,batch_size", [(25, None), (50, 10), (100, 50)])
358+
def test_batch_insertion_with_add_texts(n: int, batch_size: int) -> None:
359+
"""Test add_texts with different document counts and batch sizes."""
360+
# Setup
361+
config = YDBSettings(drop_existing_table=True)
362+
config.table = f"test_ydb_add_texts_batch_{n}_{batch_size}"
363+
docsearch = YDB(
364+
embedding=ConsistentFakeEmbeddings(),
365+
config=config,
366+
)
367+
368+
# Create test data
369+
texts = [f"text_{i}" for i in range(n)]
370+
metadatas = [{"index": str(i)} for i in range(n)]
371+
372+
# Mock the embedding and execute functions to verify batch behavior
373+
with pytest.MonkeyPatch.context() as mp:
374+
# Track batches
375+
processed_batches = []
376+
377+
# Mock embedding function to track batch sizes
378+
def mock_embed_documents(texts):
379+
processed_batches.append(len(texts))
380+
# Return fake embeddings of appropriate length
381+
return [[0.1] * 5 for _ in range(len(texts))]
382+
383+
# Mock execute query to avoid actual database operations
384+
def mock_execute_query(query, params=None, ddl=False):
385+
return None
386+
387+
# Apply mocks
388+
mp.setattr(
389+
docsearch.embedding_function, "embed_documents", mock_embed_documents
390+
)
391+
mp.setattr(docsearch, "_execute_query", mock_execute_query)
392+
393+
# Execute add_texts with specified batch size
394+
kwargs = {}
395+
if batch_size is not None:
396+
kwargs["batch_size"] = batch_size
397+
398+
ids = docsearch.add_texts(
399+
texts=texts,
400+
metadatas=metadatas,
401+
**kwargs
402+
)
403+
404+
# Verify results
405+
assert len(ids) == n # Correct number of IDs returned
406+
407+
# Verify correct batch sizes were used
408+
expected_batch_size = batch_size if batch_size is not None else 32
409+
expected_num_batches = (n + expected_batch_size - 1) // expected_batch_size
410+
411+
assert len(processed_batches) == expected_num_batches
412+
413+
# Verify all texts were processed in total
414+
assert sum(processed_batches) == n
415+
416+
# Verify most batches are of the expected size (except possibly the last one)
417+
for i, batch_size in enumerate(processed_batches):
418+
if i < len(processed_batches) - 1:
419+
# All but the last batch should be full
420+
assert batch_size == expected_batch_size
421+
else:
422+
# Last batch can be smaller
423+
assert batch_size <= expected_batch_size

0 commit comments

Comments (0)