diff --git a/langchain_ydb/vectorstores.py b/langchain_ydb/vectorstores.py index 923a13e..ea8c89c 100644 --- a/langchain_ydb/vectorstores.py +++ b/langchain_ydb/vectorstores.py @@ -55,8 +55,9 @@ class YDBSettings: database (str) : Database name to find the table. Defaults to '/local'. table (str) : Table name to operate on. Defaults to 'ydb_langchain_store'. column_map (Dict) : Column type map to project column name onto langchain - semantics. Must have keys: `text`, `id`, `vector`, - must be same size to number of columns. For example: + semantics. Must have keys: `id`, `document`, `embedding`, + `metadata`, must be same size to number of columns. + For example: .. code-block:: python { @@ -81,8 +82,6 @@ class YDBSettings: Default is 128. drop_existing_table (bool) : Flag to drop existing table while init. Defaults to False. - vector_pass_as_bytes (bool) : Flag to pass vectors as bytes to YDB. - Defaults to True. """ host: str = "localhost" @@ -107,7 +106,6 @@ class YDBSettings: index_config_clusters: int = 128 drop_existing_table: bool = False - vector_pass_as_bytes: bool = True class YDB(VectorStore): @@ -171,7 +169,13 @@ def __init__( self._execute_query(self._prepare_scheme_query(), ddl=True) - self._insert_query = self._prepare_insert_query() + self._bulk_upsert_type = ( + ydb.BulkUpsertColumns() + .add_column(self.config.column_map["id"], ydb.PrimitiveType.Utf8) + .add_column(self.config.column_map["document"], ydb.PrimitiveType.Utf8) + .add_column(self.config.column_map["embedding"], ydb.PrimitiveType.String) + .add_column(self.config.column_map["metadata"], ydb.PrimitiveType.Json) + ) self._add_index_query = self._prepare_add_index_query() @@ -180,23 +184,18 @@ def embeddings(self) -> Optional[Embeddings]: """Access the query embedding object if available.""" return self.embedding_function - def _convert_vector_to_bytes_if_needed( + def _convert_vector_to_bytes( self, vector: list[float] - ) -> bytes | list[float]: - if self.config.vector_pass_as_bytes: - b = struct.pack("f" * len(vector), *vector) - return b + b'\x01' - return vector - - def _get_vector_type(self) -> str: - if self.config.vector_pass_as_bytes: - return "String" - return "List" - - def _get_sdk_vector_type(self) -> ydb.PrimitiveType | ydb.ListType: - if self.config.vector_pass_as_bytes: - return ydb.PrimitiveType.String - return ydb.ListType(ydb.PrimitiveType.Float) + ) -> bytes: + b = struct.pack("f" * len(vector), *vector) + return b + b'\x01' + + def _bulk_upsert_documents(self, documents: list[dict]) -> None: + self.connection.bulk_upsert( + self.config.table, + documents, + self._bulk_upsert_type, + ) def _execute_query( self, @@ -293,32 +292,6 @@ def update_vector_index_if_needed(self) -> None: ], ) - def _prepare_insert_query(self) -> str: - embedding_select = "embedding" if self.config.vector_pass_as_bytes \ - else "Untag(Knn::ToBinaryStringFloat(embedding), 'FloatVector')" - - return f""" - DECLARE $documents AS List>; - - UPSERT INTO `{self.config.table}` - ( - {self.config.column_map["id"]}, - {self.config.column_map["document"]}, - {self.config.column_map["embedding"]}, - {self.config.column_map["metadata"]} - ) - SELECT - id, - document, - {embedding_select}, - metadata - FROM AS_TABLE($documents); - """ - def _prepare_search_query( self, k: int, @@ -344,21 +317,10 @@ def _prepare_search_query( if self.config.index_enabled: view_index = f"VIEW {self.config.index_name}" - if self.config.vector_pass_as_bytes: - declare_embedding = """ - DECLARE $embedding as String; - - $TargetEmbedding = $embedding; - """ - else: - declare_embedding = """ - DECLARE $embedding as List; - - $TargetEmbedding = Knn::ToBinaryStringFloat($embedding); - """ - return f""" - {declare_embedding} + DECLARE $embedding as String; + + $TargetEmbedding = $embedding; SELECT {self.config.column_map["id"]} as id, @@ -415,16 +377,6 @@ def add_texts( metadatas = metadatas if metadatas else [{} for _ in range(len(texts_))] - # Define struct type with proper member fields - document_struct_type = ydb.StructType() - document_struct_type.add_member('id', ydb.PrimitiveType.Utf8) - document_struct_type.add_member('document', ydb.PrimitiveType.Utf8) - document_struct_type.add_member( - 'embedding', - self._get_sdk_vector_type() - ) - document_struct_type.add_member('metadata', ydb.PrimitiveType.Json) - # Process in batches batch_ranges = range(0, len(texts_), batch_size) for i in self.pgbar( @@ -448,20 +400,16 @@ def add_texts( ): # Use dictionary format for struct values - YDB will convert them document = { - 'id': doc_id, - 'document': doc_text, - 'embedding': self._convert_vector_to_bytes_if_needed(doc_embedding), - 'metadata': json.dumps(doc_metadata) + self.config.column_map["id"]: doc_id, + self.config.column_map["document"]: doc_text, + self.config.column_map["embedding"]: self._convert_vector_to_bytes( + doc_embedding, + ), + self.config.column_map["metadata"]: json.dumps(doc_metadata) } documents.append(document) - # Execute the batch insert - self._execute_query( - self._insert_query, - { - "$documents": (documents, ydb.ListType(document_struct_type)) - }, - ) + self._bulk_upsert_documents(documents) self.update_vector_index_if_needed() @@ -551,8 +499,8 @@ def similarity_search_by_vector( query, params={ "$embedding": ( - self._convert_vector_to_bytes_if_needed(embedding), - self._get_sdk_vector_type() + self._convert_vector_to_bytes(embedding), + ydb.PrimitiveType.String, ) }, ) @@ -583,8 +531,8 @@ def similarity_search_with_score( query, params={ "$embedding": ( - self._convert_vector_to_bytes_if_needed(embedding), - self._get_sdk_vector_type() + self._convert_vector_to_bytes(embedding), + ydb.PrimitiveType.String, ) }, ) diff --git a/tests/test_vectorestore.py b/tests/test_vectorestore.py index 797b64d..80b2cb3 100644 --- a/tests/test_vectorestore.py +++ b/tests/test_vectorestore.py @@ -8,13 +8,11 @@ from .fake_embeddings import ConsistentFakeEmbeddings -@pytest.mark.parametrize("vector_pass_as_bytes", [True, False]) -def test_ydb(vector_pass_as_bytes: bool) -> None: +def test_ydb() -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] config = YDBSettings( drop_existing_table=True, - vector_pass_as_bytes=vector_pass_as_bytes, ) config.table = "test_ydb" docsearch = YDB.from_texts(texts, ConsistentFakeEmbeddings(), config=config) @@ -24,13 +22,11 @@ def test_ydb(vector_pass_as_bytes: bool) -> None: @pytest.mark.asyncio -@pytest.mark.parametrize("vector_pass_as_bytes", [True, False]) -async def test_ydb_async(vector_pass_as_bytes: bool) -> None: +async def test_ydb_async() -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] config = YDBSettings( drop_existing_table=True, - vector_pass_as_bytes=vector_pass_as_bytes, ) config.table = "test_ydb_async" docsearch = YDB.from_texts(