Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 35 additions & 87 deletions langchain_ydb/vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@ class YDBSettings:
database (str) : Database name to find the table. Defaults to '/local'.
table (str) : Table name to operate on. Defaults to 'ydb_langchain_store'.
column_map (Dict) : Column type map to project column name onto langchain
semantics. Must have keys: `text`, `id`, `vector`,
must be same size to number of columns. For example:
semantics. Must have keys: `id`, `document`, `embedding`,
`metadata`, must be same size to number of columns.
For example:
.. code-block:: python

{
Expand All @@ -81,8 +82,6 @@ class YDBSettings:
Default is 128.
drop_existing_table (bool) : Flag to drop existing table while init.
Defaults to False.
vector_pass_as_bytes (bool) : Flag to pass vectors as bytes to YDB.
Defaults to True.
"""

host: str = "localhost"
Expand All @@ -107,7 +106,6 @@ class YDBSettings:
index_config_clusters: int = 128

drop_existing_table: bool = False
vector_pass_as_bytes: bool = True


class YDB(VectorStore):
Expand Down Expand Up @@ -171,7 +169,13 @@ def __init__(

self._execute_query(self._prepare_scheme_query(), ddl=True)

self._insert_query = self._prepare_insert_query()
self._bulk_upsert_type = (
ydb.BulkUpsertColumns()
.add_column(self.config.column_map["id"], ydb.PrimitiveType.Utf8)
.add_column(self.config.column_map["document"], ydb.PrimitiveType.Utf8)
.add_column(self.config.column_map["embedding"], ydb.PrimitiveType.String)
.add_column(self.config.column_map["metadata"], ydb.PrimitiveType.Json)
)

self._add_index_query = self._prepare_add_index_query()

Expand All @@ -180,23 +184,18 @@ def embeddings(self) -> Optional[Embeddings]:
"""Access the query embedding object if available."""
return self.embedding_function

def _convert_vector_to_bytes_if_needed(
def _convert_vector_to_bytes(
self, vector: list[float]
) -> bytes | list[float]:
if self.config.vector_pass_as_bytes:
b = struct.pack("f" * len(vector), *vector)
return b + b'\x01'
return vector

def _get_vector_type(self) -> str:
if self.config.vector_pass_as_bytes:
return "String"
return "List<Float>"

def _get_sdk_vector_type(self) -> ydb.PrimitiveType | ydb.ListType:
if self.config.vector_pass_as_bytes:
return ydb.PrimitiveType.String
return ydb.ListType(ydb.PrimitiveType.Float)
) -> bytes:
b = struct.pack("f" * len(vector), *vector)
return b + b'\x01'

def _bulk_upsert_documents(self, documents: list[dict]) -> None:
self.connection.bulk_upsert(
self.config.table,
documents,
self._bulk_upsert_type,
)

def _execute_query(
self,
Expand Down Expand Up @@ -293,32 +292,6 @@ def update_vector_index_if_needed(self) -> None:
],
)

def _prepare_insert_query(self) -> str:
embedding_select = "embedding" if self.config.vector_pass_as_bytes \
else "Untag(Knn::ToBinaryStringFloat(embedding), 'FloatVector')"

return f"""
DECLARE $documents AS List<Struct<
id: Utf8,
document: Utf8,
embedding: {self._get_vector_type()},
metadata: Json>>;

UPSERT INTO `{self.config.table}`
(
{self.config.column_map["id"]},
{self.config.column_map["document"]},
{self.config.column_map["embedding"]},
{self.config.column_map["metadata"]}
)
SELECT
id,
document,
{embedding_select},
metadata
FROM AS_TABLE($documents);
"""

def _prepare_search_query(
self,
k: int,
Expand All @@ -344,21 +317,10 @@ def _prepare_search_query(
if self.config.index_enabled:
view_index = f"VIEW {self.config.index_name}"

if self.config.vector_pass_as_bytes:
declare_embedding = """
DECLARE $embedding as String;

$TargetEmbedding = $embedding;
"""
else:
declare_embedding = """
DECLARE $embedding as List<Float>;

$TargetEmbedding = Knn::ToBinaryStringFloat($embedding);
"""

return f"""
{declare_embedding}
DECLARE $embedding as String;

$TargetEmbedding = $embedding;

SELECT
{self.config.column_map["id"]} as id,
Expand Down Expand Up @@ -415,16 +377,6 @@ def add_texts(

metadatas = metadatas if metadatas else [{} for _ in range(len(texts_))]

# Define struct type with proper member fields
document_struct_type = ydb.StructType()
document_struct_type.add_member('id', ydb.PrimitiveType.Utf8)
document_struct_type.add_member('document', ydb.PrimitiveType.Utf8)
document_struct_type.add_member(
'embedding',
self._get_sdk_vector_type()
)
document_struct_type.add_member('metadata', ydb.PrimitiveType.Json)

# Process in batches
batch_ranges = range(0, len(texts_), batch_size)
for i in self.pgbar(
Expand All @@ -448,20 +400,16 @@ def add_texts(
):
# Use dictionary format for struct values - YDB will convert them
document = {
'id': doc_id,
'document': doc_text,
'embedding': self._convert_vector_to_bytes_if_needed(doc_embedding),
'metadata': json.dumps(doc_metadata)
self.config.column_map["id"]: doc_id,
self.config.column_map["document"]: doc_text,
self.config.column_map["embedding"]: self._convert_vector_to_bytes(
doc_embedding,
),
self.config.column_map["metadata"]: json.dumps(doc_metadata)
}
documents.append(document)

# Execute the batch insert
self._execute_query(
self._insert_query,
{
"$documents": (documents, ydb.ListType(document_struct_type))
},
)
self._bulk_upsert_documents(documents)

self.update_vector_index_if_needed()

Expand Down Expand Up @@ -551,8 +499,8 @@ def similarity_search_by_vector(
query,
params={
"$embedding": (
self._convert_vector_to_bytes_if_needed(embedding),
self._get_sdk_vector_type()
self._convert_vector_to_bytes(embedding),
ydb.PrimitiveType.String,
)
},
)
Expand Down Expand Up @@ -583,8 +531,8 @@ def similarity_search_with_score(
query,
params={
"$embedding": (
self._convert_vector_to_bytes_if_needed(embedding),
self._get_sdk_vector_type()
self._convert_vector_to_bytes(embedding),
ydb.PrimitiveType.String,
)
},
)
Expand Down
8 changes: 2 additions & 6 deletions tests/test_vectorestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,11 @@
from .fake_embeddings import ConsistentFakeEmbeddings


@pytest.mark.parametrize("vector_pass_as_bytes", [True, False])
def test_ydb(vector_pass_as_bytes: bool) -> None:
def test_ydb() -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
config = YDBSettings(
drop_existing_table=True,
vector_pass_as_bytes=vector_pass_as_bytes,
)
config.table = "test_ydb"
docsearch = YDB.from_texts(texts, ConsistentFakeEmbeddings(), config=config)
Expand All @@ -24,13 +22,11 @@ def test_ydb(vector_pass_as_bytes: bool) -> None:


@pytest.mark.asyncio
@pytest.mark.parametrize("vector_pass_as_bytes", [True, False])
async def test_ydb_async(vector_pass_as_bytes: bool) -> None:
async def test_ydb_async() -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
config = YDBSettings(
drop_existing_table=True,
vector_pass_as_bytes=vector_pass_as_bytes,
)
config.table = "test_ydb_async"
docsearch = YDB.from_texts(
Expand Down
Loading