From d8f5dc310fde4cc1873e47b47060925a232f5125 Mon Sep 17 00:00:00 2001 From: Orhan Kislal Date: Tue, 30 Sep 2025 14:57:55 +0300 Subject: [PATCH 1/2] Add Async class for Azure PostgreSQL integration --- .../vector_stores/azure_postgres/__init__.py | 2 + .../azure_postgres/async_base.py | 304 +++++++ .../vector_stores/azure_postgres/base.py | 135 +-- .../azure_postgres/common/__init__.py | 6 + .../azure_postgres/common/_base.py | 838 +++++++++++++++++- .../tests/common/test_connection.py | 236 ++++- .../tests/conftest.py | 248 +++++- .../tests/llama_index/conftest.py | 132 ++- .../tests/llama_index/test_vectorstore.py | 277 +++++- .../uv.lock | 2 +- 10 files changed, 1992 insertions(+), 188 deletions(-) create mode 100644 llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/llama_index/vector_stores/azure_postgres/async_base.py diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/llama_index/vector_stores/azure_postgres/__init__.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/llama_index/vector_stores/azure_postgres/__init__.py index 491ecda7c5..1cbbb1f378 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/llama_index/vector_stores/azure_postgres/__init__.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/llama_index/vector_stores/azure_postgres/__init__.py @@ -1,7 +1,9 @@ """Common utilities and models for Azure Database for PostgreSQL operations.""" +from .async_base import AsyncAzurePGVectorStore from .base import AzurePGVectorStore __all__ = [ "AzurePGVectorStore", + "AsyncAzurePGVectorStore", ] diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/llama_index/vector_stores/azure_postgres/async_base.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/llama_index/vector_stores/azure_postgres/async_base.py new file mode 100644 index 0000000000..12fe1b71c7 --- /dev/null +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/llama_index/vector_stores/azure_postgres/async_base.py @@ -0,0 +1,304 @@ +"""VectorStore integration for Azure Database for PostgreSQL using LlamaIndex.""" + +import sys +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from typing import Any, Optional + +import numpy as np +from pgvector.psycopg import register_vector_async # type: ignore[import-untyped] +from psycopg import AsyncConnection, sql +from psycopg.rows import dict_row +from psycopg.types.json import Jsonb + +from llama_index.core.schema import BaseNode, MetadataMode +from llama_index.core.vector_stores.types import ( + BasePydanticVectorStore, + MetadataFilters, + VectorStoreQuery, + VectorStoreQueryResult, +) +from llama_index.core.vector_stores.utils import ( + metadata_dict_to_node, + node_to_metadata_dict, +) + +from .common import ( + Algorithm, + AsyncAzurePGConnectionPool, + AsyncBaseAzurePGVectorStore, + _table_row_to_node, + metadata_filters_to_sql, +) + +if sys.version_info < (3, 12): + from typing_extensions import override +else: + from typing import override + + +class AsyncAzurePGVectorStore(BasePydanticVectorStore, AsyncBaseAzurePGVectorStore): + """Azure PostgreSQL vector store for LlamaIndex.""" + + stores_text: bool = True + metadata_columns: str | None = "metadata" + + @classmethod + def class_name(cls) -> str: + """Return the class name for this vector 
store.""" + return "AzurePGVectorStore" + + @property + def client(self) -> None: + """Return the client property (not used for AzurePGVectorStore).""" + return + + @asynccontextmanager + async def _connection(self) -> AsyncGenerator[AsyncConnection, None]: + async with self.connection_pool.connection() as conn: + yield conn + + @override + @classmethod + def from_params( + cls, + connection_pool: AsyncAzurePGConnectionPool, + schema_name: str = "public", + table_name: str = "llamaindex_vectors", + embed_dim: int = 1536, + embedding_index: Algorithm | None = None, + ) -> "AsyncAzurePGVectorStore": + """Create an AsyncAzurePGVectorStore from connection and configuration parameters.""" + return cls( + connection_pool=connection_pool, + schema_name=schema_name, + table_name=table_name, + embed_dim=embed_dim, + embedding_index=embedding_index, + ) + + def _get_insert_sql_dict( + self, node: BaseNode, on_conflict_update: bool + ) -> tuple[sql.SQL, dict[str, Any]]: + """Get the SQL command and dictionary for inserting a node.""" + if on_conflict_update: + update = sql.SQL( + """ + UPDATE SET + {content_col} = EXCLUDED.{content_col}, + {embedding_col} = EXCLUDED.{embedding_col}, + {metadata_col} = EXCLUDED.{metadata_col} + """ + ).format( + id_col=sql.Identifier(self.id_column), + content_col=sql.Identifier(self.content_column), + embedding_col=sql.Identifier(self.embedding_column), + metadata_col=sql.Identifier(self.metadata_columns), + ) + else: + update = sql.SQL("nothing") + insert_sql = sql.SQL( + """ + INSERT INTO {schema}.{table} ({id_col}, {content_col}, {embedding_col}, {metadata_col}) + VALUES (%(id)s, %(content)s, %(embedding)s, %(metadata)s) + ON CONFLICT ({id_col}) DO {update} + """ + ).format( + schema=sql.Identifier(self.schema_name), + table=sql.Identifier(self.table_name), + id_col=sql.Identifier(self.id_column), + content_col=sql.Identifier(self.content_column), + embedding_col=sql.Identifier(self.embedding_column), + metadata_col=sql.Identifier(self.metadata_columns), + update=update, + ) + + return ( + insert_sql, + { + "id": node.node_id, + "content": node.get_content(metadata_mode=MetadataMode.NONE), + "embedding": np.array(node.get_embedding(), dtype=np.float32), + "metadata": Jsonb(node_to_metadata_dict(node)), + }, + ) + + @override + async def async_add( + self, + nodes: list[BaseNode], + **kwargs: Any, + ) -> list[str]: + """Asynchronously add nodes to vector store.""" + on_conflict_update = bool(kwargs.pop("on_conflict_update", None)) + ids = [] + async with self._connection() as conn: + await register_vector_async(conn) + async with conn.cursor(row_factory=dict_row) as cursor: + for node in nodes: + ids.append(node.node_id) + insert_sql, insert_dict = self._get_insert_sql_dict( + node, on_conflict_update=on_conflict_update + ) + await cursor.execute(insert_sql, insert_dict) + return ids + + @override + async def aquery( + self, + query: VectorStoreQuery, + **kwargs: Any, + ) -> VectorStoreQueryResult: + """Asynchronously query the vector store.""" + results = await self._similarity_search_by_vector_with_distance( + embedding=query.query_embedding, + k=query.similarity_top_k, + filter_expression=metadata_filters_to_sql(query.filters), + **kwargs, + ) + nodes = [] + similarities = [] + ids = [] + for row in results: + node = metadata_dict_to_node(row[0]["metadata"], text=row[0]["content"]) + nodes.append(node) + similarities.append(row[1]) + ids.append(row[0]["id"]) + + return VectorStoreQueryResult( + nodes=nodes, + similarities=similarities, + ids=ids, + ) + + 
@override + async def adelete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: + """Delete a node from the vector store by reference document ID. + + Args: + ref_doc_id: The reference document ID to delete. + **delete_kwargs: Additional keyword arguments. + """ + async with self.connection_pool.connection() as conn: + await register_vector_async(conn) + async with conn.cursor(row_factory=dict_row) as cursor: + delete_sql = sql.SQL( + "DELETE FROM {table} WHERE metadata ->> 'doc_id' = %s" + ).format(table=sql.Identifier(self.schema_name, self.table_name)) + await cursor.execute(delete_sql, (ref_doc_id,)) + + @override + async def adelete_nodes( + self, + node_ids: Optional[list[str]] = None, + filters: Optional[MetadataFilters] = None, + **delete_kwargs: Any, + ) -> None: + """Delete nodes from the vector store by node IDs or filters. + + Args: + node_ids: Optional list of node IDs to delete. + filters: Optional MetadataFilters to filter nodes for deletion. + **delete_kwargs: Additional keyword arguments. + """ + if not node_ids: + return + + await self._delete_rows_from_table( + ids=node_ids, filters=metadata_filters_to_sql(filters), **delete_kwargs + ) + + @override + async def aclear(self) -> None: + """Clear all data from the vector store table.""" + async with self.connection_pool.connection() as conn: + await register_vector_async(conn) + async with conn.cursor(row_factory=dict_row) as cursor: + stmt = sql.SQL("TRUNCATE TABLE {schema}.{table}").format( + schema=sql.Identifier(self.schema_name), + table=sql.Identifier(self.table_name), + ) + await cursor.execute(stmt) + await conn.commit() + + @override + async def aget_nodes( + self, + node_ids: Optional[list[str]] = None, + filters: Optional[MetadataFilters] = None, + **kwargs: Any, + ) -> list[BaseNode]: + """Retrieve nodes by IDs or filters. + + Args: + node_ids: Optional list of node IDs to retrieve. + filters: Optional MetadataFilters to filter nodes. + **kwargs: Additional keyword arguments. + + Returns: + List of BaseNode objects matching the criteria. + """ + # TODO: Implement filter handling + documents = await self._get_by_ids(node_ids) + nodes = [] + for doc in documents: + node = _table_row_to_node(doc) + nodes.append(node) + + return nodes + + @override + def add( + self, + nodes: list[BaseNode], + **kwargs: Any, + ) -> list[str]: + """Not implemented for AsyncAzurePGVectorStore; use AzurePGVectorStore instead.""" + raise NotImplementedError( + "Add interface is not implemented for AsyncAzurePGVectorStore: use AzurePGVectorStore, instead." + ) + + @override + def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: + """Not implemented for AsyncAzurePGVectorStore; use AzurePGVectorStore instead.""" + raise NotImplementedError( + "Delete interface is not implemented for AsyncAzurePGVectorStore: use AzurePGVectorStore, instead." + ) + + @override + def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult: + """Not implemented for AsyncAzurePGVectorStore; use AzurePGVectorStore instead.""" + raise NotImplementedError( + "Query interface is not implemented for AsyncAzurePGVectorStore: use AzurePGVectorStore, instead." 
+ ) + + @override + def get_nodes( + self, + node_ids: Optional[list[str]] = None, + filters: Optional[MetadataFilters] = None, + **kwargs: Any, + ) -> list[BaseNode]: + """Not implemented for AsyncAzurePGVectorStore; use AzurePGVectorStore instead.""" + raise NotImplementedError( + "get_nodes interface is not implemented for AsyncAzurePGVectorStore: use AzurePGVectorStore, instead." + ) + + @override + def clear(self) -> None: + """Not implemented for AsyncAzurePGVectorStore; use AzurePGVectorStore instead.""" + raise NotImplementedError( + "clear interface is not implemented for AsyncAzurePGVectorStore: use AzurePGVectorStore, instead." + ) + + @override + def delete_nodes( + self, + node_ids: Optional[list[str]] = None, + filters: Optional[MetadataFilters] = None, + **delete_kwargs: Any, + ) -> None: + """Not implemented for AsyncAzurePGVectorStore; use AzurePGVectorStore instead.""" + raise NotImplementedError( + "delete_nodes interface is not implemented for AsyncAzurePGVectorStore: use AzurePGVectorStore, instead." + ) diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/llama_index/vector_stores/azure_postgres/base.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/llama_index/vector_stores/azure_postgres/base.py index f7e1979fd2..e123606e40 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/llama_index/vector_stores/azure_postgres/base.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/llama_index/vector_stores/azure_postgres/base.py @@ -1,7 +1,7 @@ """VectorStore integration for Azure Database for PostgreSQL using LlamaIndex.""" import sys -from typing import Any, Optional, Union +from typing import Any, Optional import numpy as np from pgvector.psycopg import register_vector # type: ignore[import-untyped] @@ -12,8 +12,6 @@ from llama_index.core.schema import BaseNode, MetadataMode from llama_index.core.vector_stores.types import ( BasePydanticVectorStore, - FilterOperator, - MetadataFilter, MetadataFilters, VectorStoreQuery, VectorStoreQueryResult, @@ -27,6 +25,8 @@ Algorithm, AzurePGConnectionPool, BaseAzurePGVectorStore, + _table_row_to_node, + metadata_filters_to_sql, ) if sys.version_info < (3, 12): @@ -35,110 +35,6 @@ from typing import override -def metadata_filters_to_sql(filters: Optional[MetadataFilters]) -> sql.SQL: - """Convert LlamaIndex MetadataFilters to a SQL WHERE clause. - - Args: - filters: Optional MetadataFilters object. - - Returns: - sql.SQL: SQL WHERE clause representing the filters. 
- """ - if not filters or not filters.filters: - return sql.SQL("TRUE") - - def _filter_to_sql(filter_item: Union[MetadataFilter, MetadataFilters]) -> sql.SQL: - """Recursively convert MetadataFilter or MetadataFilters to SQL.""" - if isinstance(filter_item, MetadataFilters): - # Handle nested MetadataFilters - if not filter_item.filters: - return sql.SQL("TRUE") - - filter_sqls = [_filter_to_sql(f) for f in filter_item.filters] - if filter_item.condition.lower() == "and": - return sql.SQL("({})").format(sql.SQL(" AND ").join(filter_sqls)) - elif filter_item.condition.lower() == "or": - return sql.SQL("({})").format(sql.SQL(" OR ").join(filter_sqls)) - else: # NOT - if len(filter_sqls) == 1: - return sql.SQL("NOT ({})").format(filter_sqls[0]) - else: - # For multiple filters with NOT, apply NOT to the AND of all filters - return sql.SQL("NOT ({})").format( - sql.SQL(" AND ").join(filter_sqls) - ) - - elif isinstance(filter_item, MetadataFilter): - # Handle individual MetadataFilter - key = filter_item.key - value = filter_item.value - operator = filter_item.operator - - # Use JSONB path for metadata column - column_ref = sql.SQL("metadata ->> {}").format(sql.Literal(key)) - - if operator == FilterOperator.EQ: - return sql.SQL("{} = {}").format(column_ref, sql.Literal(str(value))) - elif operator == FilterOperator.NE: - return sql.SQL("{} != {}").format(column_ref, sql.Literal(str(value))) - elif operator == FilterOperator.GT: - return sql.SQL("({}) > {}").format(column_ref, sql.Literal(value)) - elif operator == FilterOperator.LT: - return sql.SQL("({}) < {}").format(column_ref, sql.Literal(value)) - elif operator == FilterOperator.GTE: - return sql.SQL("({}) >= {}").format(column_ref, sql.Literal(value)) - elif operator == FilterOperator.LTE: - return sql.SQL("({}) <= {}").format(column_ref, sql.Literal(value)) - elif operator == FilterOperator.IN: - if isinstance(value, list): - values = sql.SQL(", ").join([sql.Literal(str(v)) for v in value]) - return sql.SQL("{} IN ({})").format(column_ref, values) - else: - return sql.SQL("{} = {}").format( - column_ref, sql.Literal(str(value)) - ) - elif operator == FilterOperator.NIN: - if isinstance(value, list): - values = sql.SQL(", ").join([sql.Literal(str(v)) for v in value]) - return sql.SQL("{} NOT IN ({})").format(column_ref, values) - else: - return sql.SQL("{} != {}").format( - column_ref, sql.Literal(str(value)) - ) - elif operator == FilterOperator.CONTAINS: - # For JSONB array contains - return sql.SQL("metadata -> {} ? 
{}").format( - sql.Literal(key), sql.Literal(str(value)) - ) - elif operator == FilterOperator.TEXT_MATCH: - return sql.SQL("{} LIKE {}").format( - column_ref, sql.Literal(f"%{value}%") - ) - elif operator == FilterOperator.TEXT_MATCH_INSENSITIVE: - return sql.SQL("{} ILIKE {}").format( - column_ref, sql.Literal(f"%{value}%") - ) - elif operator == FilterOperator.IS_EMPTY: - return sql.SQL("({} IS NULL OR {} = '')").format(column_ref, column_ref) - else: - # Default to equality for unsupported operators - return sql.SQL("{} = {}").format(column_ref, sql.Literal(str(value))) - - return sql.SQL("TRUE") - - filter_sqls = [_filter_to_sql(f) for f in filters.filters] - - if filters.condition.lower() == "and": - return sql.SQL(" AND ").join(filter_sqls) - elif filters.condition.lower() == "or": - return sql.SQL(" OR ").join(filter_sqls) - else: # NOT - if len(filter_sqls) == 1: - return sql.SQL("NOT ({})").format(filter_sqls[0]) - else: - return sql.SQL("NOT ({})").format(sql.SQL(" AND ").join(filter_sqls)) - - class AzurePGVectorStore(BasePydanticVectorStore, BaseAzurePGVectorStore): """Azure PostgreSQL vector store for LlamaIndex.""" @@ -174,29 +70,6 @@ def from_params( embedding_index=embedding_index, ) - def _table_row_to_node(self, row: dict[str, Any]) -> BaseNode: - """Convert a table row dictionary to a BaseNode object.""" - metadata = row.get("metadata") - if metadata is None: - raise ValueError("Metadata not found in row data.") - - node = metadata_dict_to_node(metadata, text=row.get("content")) - # Convert UUID to string if needed - node_id = row.get("id") - if node_id is not None: - node.node_id = str(node_id) - embedding = row.get("embedding") - - if isinstance(embedding, str): - embedding = row.get("embedding").strip("[]").split(",") - node.embedding = list(map(float, embedding)) - elif embedding is not None: - node.embedding = embedding - else: - raise ValueError("Missing embedding value") - - return node - def _get_insert_sql_dict( self, node: BaseNode, on_conflict_update: bool ) -> tuple[sql.SQL, dict[str, Any]]: @@ -370,7 +243,7 @@ def get_nodes( documents = self._get_by_ids(node_ids) nodes = [] for doc in documents: - node = self._table_row_to_node(doc) + node = _table_row_to_node(row=doc) nodes.append(node) return nodes diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/llama_index/vector_stores/azure_postgres/common/__init__.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/llama_index/vector_stores/azure_postgres/common/__init__.py index c329d538c9..2c8539af67 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/llama_index/vector_stores/azure_postgres/common/__init__.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/llama_index/vector_stores/azure_postgres/common/__init__.py @@ -1,7 +1,10 @@ """Common utilities and models for Azure Database for PostgreSQL operations.""" from ._base import ( + AsyncBaseAzurePGVectorStore, BaseAzurePGVectorStore, + _table_row_to_node, + metadata_filters_to_sql, ) from ._connection import ( AzurePGConnectionPool, @@ -51,7 +54,10 @@ "VectorOpClass", "VectorType", # Base classes + "AsyncBaseAzurePGVectorStore", "BaseAzurePGVectorStore", + "metadata_filters_to_sql", + "_table_row_to_node", # Synchronous connection constructs "AzurePGConnectionPool", "ConnectionInfo", diff --git 
a/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/llama_index/vector_stores/azure_postgres/common/_base.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/llama_index/vector_stores/azure_postgres/common/_base.py index 9a30a9929a..32f98888d0 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/llama_index/vector_stores/azure_postgres/common/_base.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/llama_index/vector_stores/azure_postgres/common/_base.py @@ -4,14 +4,25 @@ import re import sys from collections.abc import Sequence -from typing import Any +from typing import Any, Optional, Union import numpy as np -from pgvector.psycopg import register_vector # type: ignore[import-untyped] +from pgvector.psycopg import ( # type: ignore[import-untyped] + register_vector, + register_vector_async, +) from psycopg import sql from psycopg.rows import dict_row from pydantic import BaseModel, ConfigDict, PositiveInt, model_validator +from llama_index.core.schema import BaseNode +from llama_index.core.vector_stores.types import ( + FilterOperator, + MetadataFilter, + MetadataFilters, +) +from llama_index.core.vector_stores.utils import metadata_dict_to_node + from ._connection import AzurePGConnectionPool from ._shared import ( HNSW, @@ -20,7 +31,9 @@ IVFFlat, VectorOpClass, VectorType, + run_coroutine_in_sync, ) +from .aio._connection import AsyncAzurePGConnectionPool if sys.version_info < (3, 11): from typing_extensions import Self @@ -30,6 +43,134 @@ _logger = logging.getLogger(__name__) +def metadata_filters_to_sql(filters: Optional[MetadataFilters]) -> sql.SQL: + """Convert LlamaIndex MetadataFilters to a SQL WHERE clause. + + Args: + filters: Optional MetadataFilters object. + + Returns: + sql.SQL: SQL WHERE clause representing the filters. 
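+
+    Example:
+        A single equality filter such as
+        ``MetadataFilters(filters=[MetadataFilter(key="author", value="alice")])``
+        renders as ``metadata ->> 'author' = 'alice'``.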
+ """ + if not filters or not filters.filters: + return sql.SQL("TRUE") + + def _filter_to_sql(filter_item: Union[MetadataFilter, MetadataFilters]) -> sql.SQL: + """Recursively convert MetadataFilter or MetadataFilters to SQL.""" + if isinstance(filter_item, MetadataFilters): + # Handle nested MetadataFilters + if not filter_item.filters: + return sql.SQL("TRUE") + + filter_sqls = [_filter_to_sql(f) for f in filter_item.filters] + if filter_item.condition.lower() == "and": + return sql.SQL("({})").format(sql.SQL(" AND ").join(filter_sqls)) + elif filter_item.condition.lower() == "or": + return sql.SQL("({})").format(sql.SQL(" OR ").join(filter_sqls)) + else: # NOT + if len(filter_sqls) == 1: + return sql.SQL("NOT ({})").format(filter_sqls[0]) + else: + # For multiple filters with NOT, apply NOT to the AND of all filters + return sql.SQL("NOT ({})").format( + sql.SQL(" AND ").join(filter_sqls) + ) + + elif isinstance(filter_item, MetadataFilter): + # Handle individual MetadataFilter + key = filter_item.key + value = filter_item.value + operator = filter_item.operator + + # Use JSONB path for metadata column + column_ref = sql.SQL("metadata ->> {}").format(sql.Literal(key)) + + if operator == FilterOperator.EQ: + return sql.SQL("{} = {}").format(column_ref, sql.Literal(str(value))) + elif operator == FilterOperator.NE: + return sql.SQL("{} != {}").format(column_ref, sql.Literal(str(value))) + elif operator == FilterOperator.GT: + return sql.SQL("({}) > {}").format(column_ref, sql.Literal(value)) + elif operator == FilterOperator.LT: + return sql.SQL("({}) < {}").format(column_ref, sql.Literal(value)) + elif operator == FilterOperator.GTE: + return sql.SQL("({}) >= {}").format(column_ref, sql.Literal(value)) + elif operator == FilterOperator.LTE: + return sql.SQL("({}) <= {}").format(column_ref, sql.Literal(value)) + elif operator == FilterOperator.IN: + if isinstance(value, list): + values = sql.SQL(", ").join([sql.Literal(str(v)) for v in value]) + return sql.SQL("{} IN ({})").format(column_ref, values) + else: + return sql.SQL("{} = {}").format( + column_ref, sql.Literal(str(value)) + ) + elif operator == FilterOperator.NIN: + if isinstance(value, list): + values = sql.SQL(", ").join([sql.Literal(str(v)) for v in value]) + return sql.SQL("{} NOT IN ({})").format(column_ref, values) + else: + return sql.SQL("{} != {}").format( + column_ref, sql.Literal(str(value)) + ) + elif operator == FilterOperator.CONTAINS: + # For JSONB array contains + return sql.SQL("metadata -> {} ? 
{}").format( + sql.Literal(key), sql.Literal(str(value)) + ) + elif operator == FilterOperator.TEXT_MATCH: + return sql.SQL("{} LIKE {}").format( + column_ref, sql.Literal(f"%{value}%") + ) + elif operator == FilterOperator.TEXT_MATCH_INSENSITIVE: + return sql.SQL("{} ILIKE {}").format( + column_ref, sql.Literal(f"%{value}%") + ) + elif operator == FilterOperator.IS_EMPTY: + return sql.SQL("({} IS NULL OR {} = '')").format(column_ref, column_ref) + else: + # Default to equality for unsupported operators + return sql.SQL("{} = {}").format(column_ref, sql.Literal(str(value))) + + return sql.SQL("TRUE") + + filter_sqls = [_filter_to_sql(f) for f in filters.filters] + + if filters.condition.lower() == "and": + return sql.SQL(" AND ").join(filter_sqls) + elif filters.condition.lower() == "or": + return sql.SQL(" OR ").join(filter_sqls) + else: # NOT + if len(filter_sqls) == 1: + return sql.SQL("NOT ({})").format(filter_sqls[0]) + else: + return sql.SQL("NOT ({})").format(sql.SQL(" AND ").join(filter_sqls)) + + +def _table_row_to_node(row: dict[str, Any]) -> BaseNode: + """Convert a table row dictionary to a BaseNode object.""" + metadata = row.get("metadata") + if metadata is None: + raise ValueError("Metadata not found in row data.") + + node = metadata_dict_to_node(metadata, text=row.get("content")) + # Convert UUID to string if needed + node_id = row.get("id") + if node_id is not None: + node.node_id = str(node_id) + embedding = row.get("embedding") + + if isinstance(embedding, str): + embedding = row.get("embedding").strip("[]").split(",") + node.embedding = list(map(float, embedding)) + elif embedding is not None: + node.embedding = embedding + else: + raise ValueError("Missing embedding value") + + return node + + class BaseAzurePGVectorStore(BaseModel): """Base Pydantic model for an Azure PostgreSQL-backed vector store. 
@@ -714,3 +855,696 @@ def _get_by_ids(self, ids: Sequence[str], /) -> list[dict[str, Any]]: for result in resultset ] return documents + + +class AsyncBaseAzurePGVectorStore(BaseModel): + connection_pool: AsyncAzurePGConnectionPool + schema_name: str = "public" + table_name: str = "vector_store" + id_column: str = "id" + content_column: str = "content" + embedding_column: str = "embedding" + embedding_type: VectorType | None = None + embedding_dimension: PositiveInt | None = None + embedding_index: Algorithm | None = None + metadata_column: str | None = "metadata" + + model_config = ConfigDict( + arbitrary_types_allowed=True, # Allow arbitrary types like Embeddings and AzurePGConnectionPool + ) + + @model_validator(mode="after") + def verify_and_init_store(self) -> Self: + # verify that metadata_columns is not empty if provided + if self.metadata_columns is not None and len(self.metadata_columns) == 0: + raise ValueError("'metadata_columns' cannot be empty if provided.") + + _logger.debug( + "checking if table '%s.%s' exists with the required columns", + self.schema_name, + self.table_name, + ) + + coroutine = self._ensure_table_verified() + run_coroutine_in_sync(coroutine) + + return self + + async def _ensure_table_verified(self) -> None: + async with ( + self.connection_pool.connection() as conn, + conn.cursor(row_factory=dict_row) as cursor, + ): + await cursor.execute( + sql.SQL( + """ + select a.attname as column_name, + format_type(a.atttypid, a.atttypmod) as column_type + from pg_attribute a + join pg_class c on a.attrelid = c.oid + join pg_namespace n on c.relnamespace = n.oid + where a.attnum > 0 + and not a.attisdropped + and n.nspname = %(schema_name)s + and c.relname = %(table_name)s + order by a.attnum asc + """ + ), + {"schema_name": self.schema_name, "table_name": self.table_name}, + ) + resultset = await cursor.fetchall() + existing_columns: dict[str, str] = { + row["column_name"]: row["column_type"] for row in resultset + } + + # if table exists, verify that required columns exist and have correct types + if len(existing_columns) > 0: + _logger.debug( + "table '%s.%s' exists with the following column mapping: %s", + self.schema_name, + self.table_name, + existing_columns, + ) + + id_column_type = existing_columns.get(self.id_column) + if id_column_type != "uuid": + raise ValueError( + f"Table '{self.schema_name}.{self.table_name}' must have a column '{self.id_column}' of type 'uuid'." + ) + + content_column_type = existing_columns.get(self.content_column) + if content_column_type is None or ( + content_column_type != "text" + and not content_column_type.startswith("varchar") + ): + raise ValueError( + f"Table '{self.schema_name}.{self.table_name}' must have a column '{self.content_column}' of type 'text' or 'varchar'." + ) + + embedding_column_type = existing_columns.get(self.embedding_column) + pattern = re.compile(r"(?P\w+)(?:\((?P\d+)\))?") + m = pattern.match(embedding_column_type if embedding_column_type else "") + parsed_type: str | None = m.group("type") if m else None + parsed_dim: PositiveInt | None = ( + PositiveInt(m.group("dim")) if m and m.group("dim") else None + ) + + vector_types = [t.value for t in VectorType.__members__.values()] + if parsed_type not in vector_types: + raise ValueError( + f"Column '{self.embedding_column}' in table '{self.schema_name}.{self.table_name}' must be one of the following types: {vector_types}." 
+ ) + elif ( + self.embedding_type is not None + and parsed_type != self.embedding_type.value + ): + raise ValueError( + f"Column '{self.embedding_column}' in table '{self.schema_name}.{self.table_name}' has type '{parsed_type}', but the specified embedding_type is '{self.embedding_type.value}'. They must match." + ) + elif self.embedding_type is None: + _logger.info( + "embedding_type is not specified, but the column '%s' in table '%s.%s' has type '%s'. Overriding embedding_type accordingly.", + self.embedding_column, + self.schema_name, + self.table_name, + parsed_type, + ) + self.embedding_type = VectorType(parsed_type) + + if parsed_dim is not None and self.embedding_dimension is None: + _logger.info( + "embedding_dimension is not specified, but the column '%s' in table '%s.%s' has a dimension of %d. Overriding embedding_dimension accordingly.", + self.embedding_column, + self.schema_name, + self.table_name, + parsed_dim, + ) + self.embedding_dimension = parsed_dim + elif ( + parsed_dim is not None + and self.embedding_dimension is not None + and parsed_dim != self.embedding_dimension + ): + raise ValueError( + f"Column '{self.embedding_column}' in table '{self.schema_name}.{self.table_name}' has a dimension of {parsed_dim}, but the specified embedding_dimension is {self.embedding_dimension}. They must match." + ) + + if self.metadata_column is not None: + existing_type = existing_columns.get(self.metadata_column) + if existing_type is None: + raise ValueError( + f"Column '{self.metadata_column}' does not exist in table '{self.schema_name}.{self.table_name}'." + ) + + async with ( + self.connection_pool.connection() as conn, + conn.cursor(row_factory=dict_row) as cursor, + ): + _logger.debug( + "checking if table '%s.%s' has a vector index on column '%s'", + self.schema_name, + self.table_name, + self.embedding_column, + ) + await cursor.execute( + sql.SQL( + """ + with cte as ( + select n.nspname as schema_name, + ct.relname as table_name, + ci.relname as index_name, + a.amname as index_type, + pg_get_indexdef( + ci.oid, -- index OID + generate_series(1, array_length(ii.indkey, 1)), -- column no + true -- pretty print + ) as index_column, + o.opcname as index_opclass, + ci.reloptions as index_opts + from pg_class ci + join pg_index ii on ii.indexrelid = ci.oid + join pg_am a on a.oid = ci.relam + join pg_class ct on ct.oid = ii.indrelid + join pg_namespace n on n.oid = ci.relnamespace + join pg_opclass o on o.oid = any(ii.indclass) + where ci.relkind = 'i' + and ct.relkind = 'r' + and ii.indisvalid + and ii.indisready + ) select schema_name, table_name, index_name, index_type, + index_column, index_opclass, index_opts + from cte + where schema_name = %(schema_name)s + and table_name = %(table_name)s + and index_column like %(embedding_column)s + and ( + index_opclass like '%%vector%%' + or index_opclass like '%%halfvec%%' + or index_opclass like '%%sparsevec%%' + or index_opclass like '%%bit%%' + ) + order by schema_name, table_name, index_name + """ + ), + { + "schema_name": self.schema_name, + "table_name": self.table_name, + "embedding_column": f"%{self.embedding_column}%", + }, + ) + resultset = await cursor.fetchall() + + if len(resultset) > 0: + _logger.debug( + "table '%s.%s' has %d vector index(es): %s", + self.schema_name, + self.table_name, + len(resultset), + resultset, + ) + + if self.embedding_index is None: + _logger.info( + "embedding_index is not specified, using the first found index: %s", + resultset[0], + ) + + index_type = resultset[0]["index_type"] + index_opclass 
= VectorOpClass(resultset[0]["index_opclass"]) + index_opts = { + opts.split("=")[0]: opts.split("=")[1] + for opts in resultset[0]["index_opts"] + } + + index = ( + DiskANN(op_class=index_opclass, **index_opts) + if index_type == "diskann" + else HNSW(op_class=index_opclass, **index_opts) + if index_type == "hnsw" + else IVFFlat(op_class=index_opclass, **index_opts) + ) + + self.embedding_index = index + else: + _logger.info( + "embedding_index is specified as '%s'; will try to find a matching index.", + self.embedding_index, + ) + + index_opclass = self.embedding_index.op_class.value # type: ignore[assignment] + if isinstance(self.embedding_index, DiskANN): + index_type = "diskann" + elif isinstance(self.embedding_index, HNSW): + index_type = "hnsw" + else: + index_type = "ivfflat" + + for row in resultset: + if ( + row["index_type"] == index_type + and row["index_opclass"] == index_opclass + ): + _logger.info( + "found a matching index: %s. overriding embedding_index.", + row, + ) + index_opts = { + opts.split("=")[0]: opts.split("=")[1] + for opts in row["index_opts"] + } + index = ( + DiskANN(op_class=index_opclass, **index_opts) + if index_type == "diskann" + else HNSW(op_class=index_opclass, **index_opts) + if index_type == "hnsw" + else IVFFlat(op_class=index_opclass, **index_opts) + ) + self.embedding_index = index + break + elif self.embedding_index is None: + _logger.info( + "embedding_index is not specified, and no vector index found in table '%s.%s'. defaulting to 'DiskANN' with 'vector_cosine_ops' opclass.", + self.schema_name, + self.table_name, + ) + self.embedding_index = DiskANN(op_class=VectorOpClass.vector_cosine_ops) + + # if table does not exist, create it + else: + _logger.debug( + "table '%s.%s' does not exist, creating it with the required columns", + self.schema_name, + self.table_name, + ) + + metadata_columns: list[tuple[str, str]] = [] # type: ignore[no-redef] + if self.metadata_columns is None: + _logger.warning( + "Metadata columns are not specified, defaulting to 'metadata' of type 'jsonb'." + ) + metadata_columns = [("metadata", "jsonb")] + elif isinstance(self.metadata_columns, str): + _logger.warning( + "Metadata columns are specified as a string, defaulting to 'jsonb' type." + ) + metadata_columns = [(self.metadata_columns, "jsonb")] + elif isinstance(self.metadata_columns, list): + _logger.warning( + "Metadata columns are specified as a list; defaulting to 'text' when type is not defined." + ) + metadata_columns = [ + (col[0], col[1]) if isinstance(col, tuple) else (col, "text") + for col in self.metadata_columns + ] + + if self.embedding_type is None: + _logger.warning( + "Embedding type is not specified, defaulting to 'vector'." + ) + self.embedding_type = VectorType.vector + + if self.embedding_dimension is None: + _logger.warning( + "Embedding dimension is not specified, defaulting to 1536." + ) + self.embedding_dimension = PositiveInt(1_536) + + if self.embedding_index is None: + _logger.warning( + "Embedding index is not specified, defaulting to 'DiskANN' with 'vector_cosine_ops' opclass." 
+ ) + self.embedding_index = DiskANN(op_class=VectorOpClass.vector_cosine_ops) + + async with ( + self.connection_pool.connection() as conn, + conn.cursor() as cursor, + ): + await cursor.execute( + sql.SQL( + """ + create table {table_name} ( + {id_column} uuid primary key, + {content_column} text, + {embedding_column} {embedding_type}({embedding_dimension}), + {metadata_columns} + ) + """ + ).format( + table_name=sql.Identifier(self.schema_name, self.table_name), + id_column=sql.Identifier(self.id_column), + content_column=sql.Identifier(self.content_column), + embedding_column=sql.Identifier(self.embedding_column), + embedding_type=sql.Identifier(self.embedding_type.value), + embedding_dimension=sql.Literal(self.embedding_dimension), + metadata_columns=sql.SQL(", ").join( + sql.SQL("{col} {type}").format( + col=sql.Identifier(col), + type=sql.SQL(type), # type: ignore[arg-type] + ) + for col, type in metadata_columns + ), + ) + ) + + async def _delete_rows_from_table( + self, ids: list[str] | None = None, **kwargs: Any + ) -> bool | None: + """Delete rows from the table by their IDs or truncate the table. + + Args: + ids (list[str] | None): List of IDs to delete. If None, truncates the table. + **kwargs: Additional options, such as 'restart' and 'cascade' for truncation. + + Returns: + bool | None: True if successful, False if an exception occurred, None otherwise. + """ + async with self.connection_pool.connection() as conn: + try: + async with conn.transaction() as _tx, conn.cursor() as cursor: + if ids is None: + restart = bool(kwargs.pop("restart", None)) + cascade = bool(kwargs.pop("cascade", None)) + await cursor.execute( + sql.SQL( + """ + truncate table {table_name} {restart} {cascade} + """ + ).format( + table_name=sql.Identifier( + self.schema_name, self.table_name + ), + restart=sql.SQL( + "restart identity" + if restart + else "continue identity" + ), + cascade=sql.SQL("cascade" if cascade else "restrict"), + ) + ) + else: + await cursor.execute( + sql.SQL( + """ + delete from {table_name} + where {id_column} = any(%(id)s) + """ + ).format( + table_name=sql.Identifier( + self.schema_name, self.table_name + ), + id_column=sql.Identifier(self.id_column), + ), + {"id": ids}, + ) + except Exception: + return False + else: + return True + + async def _similarity_search_by_vector_with_distance( + self, embedding: list[float], k: int = 4, **kwargs: Any + ) -> list[tuple[dict, float, np.ndarray | None]]: + """Perform a similarity search using a vector embedding and return results with distances. + + Args: + embedding (list[float]): The query embedding vector. + k (int): Number of top results to return. + **kwargs: Additional options such as 'return_embeddings', 'top_m', and 'filter_expression'. + + Returns: + list[tuple[dict, float, np.ndarray | None]]: List of tuples containing document dict, distance, and optionally the embedding. 
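+
+        Note:
+            When the configured index uses a quantized opclass (``bit_*`` or
+            ``halfvec_*``) or a product-quantized DiskANN index, the query first
+            over-fetches ``top_m`` candidates (default ``5 * k``) using the
+            quantized expression and then re-ranks them against the stored
+            embeddings before returning the top ``k`` rows.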
+ """ + assert self.embedding_index is not None, ( + "embedding_index should have already been set" + ) + return_embeddings = bool(kwargs.pop("return_embeddings", None)) + top_m = int(kwargs.pop("top_m", 5 * k)) + filter_expression: sql.SQL = kwargs.pop("filter_expression", sql.SQL("true")) + async with self.connection_pool.connection() as conn: + await register_vector_async(conn) + async with conn.cursor(row_factory=dict_row) as cursor: + metadata_column: list[str] + if isinstance(self.metadata_column, list): + metadata_column = [ + col if isinstance(col, str) else col[0] + for col in self.metadata_column + ] + elif isinstance(self.metadata_column, str): + metadata_column = [self.metadata_column] + else: + metadata_column = [] + + # do reranking for the following cases: + # - binary or scalar quantizations (for HNSW and IVFFlat), or + # - product quantization (for DiskANN) + if ( + self.embedding_index.op_class == VectorOpClass.bit_hamming_ops + or self.embedding_index.op_class == VectorOpClass.bit_jaccard_ops + or self.embedding_index.op_class == VectorOpClass.halfvec_cosine_ops + or self.embedding_index.op_class == VectorOpClass.halfvec_ip_ops + or self.embedding_index.op_class == VectorOpClass.halfvec_l1_ops + or self.embedding_index.op_class == VectorOpClass.halfvec_l2_ops + or ( + isinstance(self.embedding_index, DiskANN) + and self.embedding_index.product_quantized + ) + ): + sql_query = sql.SQL( + """ + select {outer_columns}, + {embedding_column} {op} %(query)s as distance, + {maybe_embedding_column} + from ( + select {inner_columns} + from {table_name} + where {filter_expression} + order by {expression} asc + limit %(top_m)s + ) i + order by {embedding_column} {op} %(query)s asc + limit %(top_k)s + """ + ).format( + outer_columns=sql.SQL(", ").join( + map( + sql.Identifier, + [ + self.id_column, + self.content_column, + *metadata_column, + ], + ) + ), + embedding_column=sql.Identifier(self.embedding_column), + op=( + sql.SQL( + VectorOpClass.vector_cosine_ops.to_operator() + ) # TODO(arda): Think of getting this from outside + if ( + self.embedding_index.op_class + in ( + VectorOpClass.bit_hamming_ops, + VectorOpClass.bit_jaccard_ops, + ) + ) + else sql.SQL(self.embedding_index.op_class.to_operator()) + ), + maybe_embedding_column=( + sql.Identifier(self.embedding_column) + if return_embeddings + else sql.SQL(" as ").join( + (sql.NULL, sql.Identifier(self.embedding_column)) + ) + ), + inner_columns=sql.SQL(", ").join( + map( + sql.Identifier, + [ + self.id_column, + self.content_column, + self.embedding_column, + *metadata_column, + ], + ) + ), + table_name=sql.Identifier(self.schema_name, self.table_name), + filter_expression=filter_expression, + expression=( + sql.SQL( + "binary_quantize({embedding_column})::bit({embedding_dim}) {op} binary_quantize({query})" + ).format( + embedding_column=sql.Identifier(self.embedding_column), + embedding_dim=sql.Literal(self.embedding_dimension), + op=sql.SQL(self.embedding_index.op_class.to_operator()), + query=sql.Placeholder("query"), + ) + if self.embedding_index.op_class + in ( + VectorOpClass.bit_hamming_ops, + VectorOpClass.bit_jaccard_ops, + ) + else ( + sql.SQL( + "{embedding_column}::halfvec({embedding_dim}) {op} {query}::halfvec({embedding_dim})" + ).format( + embedding_column=sql.Identifier( + self.embedding_column + ), + embedding_dim=sql.Literal(self.embedding_dimension), + op=sql.SQL( + self.embedding_index.op_class.to_operator() + ), + query=sql.Placeholder("query"), + ) + if self.embedding_index.op_class + in ( + 
VectorOpClass.halfvec_cosine_ops, + VectorOpClass.halfvec_ip_ops, + VectorOpClass.halfvec_l1_ops, + VectorOpClass.halfvec_l2_ops, + ) + else sql.SQL("{embedding_column} {op} {query}").format( + embedding_column=sql.Identifier( + self.embedding_column + ), + op=sql.SQL( + self.embedding_index.op_class.to_operator() + ), + query=sql.Placeholder("query"), + ) + ) + ), + ) + # otherwise (i.e., no quantization), do not do reranking + else: + sql_query = sql.SQL( + """ + select {outer_columns}, + {embedding_column} {op} %(query)s as distance, + {maybe_embedding_column} + from {table_name} + where {filter_expression} + order by {embedding_column} {op} %(query)s asc + limit %(top_k)s + """ + ).format( + outer_columns=sql.SQL(", ").join( + map( + sql.Identifier, + [ + self.id_column, + self.content_column, + *metadata_column, + ], + ) + ), + embedding_column=sql.Identifier(self.embedding_column), + op=sql.SQL(self.embedding_index.op_class.to_operator()), + maybe_embedding_column=( + sql.Identifier(self.embedding_column) + if return_embeddings + else sql.SQL(" as ").join( + (sql.NULL, sql.Identifier(self.embedding_column)) + ) + ), + table_name=sql.Identifier(self.schema_name, self.table_name), + filter_expression=filter_expression, + ) + await cursor.execute( + sql_query, + { + "query": np.array(embedding, dtype=np.float32), + "top_m": top_m, + "top_k": k, + }, + ) + resultset = await cursor.fetchall() + return [ + ( + { + "id": result[self.id_column], + "content": result[self.content_column], + "metadata": ( + result[metadata_column[0]] + if isinstance(self.metadata_column, str) + else {col: result[col] for col in metadata_column} + ), + }, + result["distance"], + result.get(self.embedding_column), # type: ignore[return-value] + ) + for result in resultset + ] + + async def _get_by_ids(self, ids: Sequence[str], /) -> list[dict[str, Any]]: + """Retrieve documents from the table by their IDs. + + Args: + ids (Sequence[str]): List of IDs to retrieve. + + Returns: + list[dict[str, Any]]: List of document dictionaries with id, content, embedding, and metadata. 
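+
+        Note:
+            If ``ids`` is ``None``, no ``WHERE`` clause is applied and every
+            row in the table is returned.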
+ """ + async with ( + self.connection_pool.connection() as conn, + conn.cursor(row_factory=dict_row) as cursor, + ): + metadata_column: list[str] + if isinstance(self.metadata_column, list): + metadata_column = [ + col if isinstance(col, str) else col[0] + for col in self.metadata_column + ] + elif isinstance(self.metadata_column, str): + metadata_column = [self.metadata_column] + else: + metadata_column = [] + + if ids is not None: + where_clause = sql.SQL(" where {id_column} = any(%(id)s)").format( + id_column=sql.Identifier(self.id_column) + ) + else: + where_clause = sql.SQL("") + + get_sql = sql.SQL( + """ + select {columns} + from {table_name} + {where_clause} + """ + ).format( + columns=sql.SQL(", ").join( + map( + sql.Identifier, + [ + self.id_column, + self.content_column, + self.embedding_column, + *metadata_column, + ], + ) + ), + table_name=sql.Identifier(self.schema_name, self.table_name), + where_clause=where_clause, + ) + + if ids is not None: + await cursor.execute(get_sql, {"id": ids}) + else: + await cursor.execute(get_sql) + resultset = await cursor.fetchall() + documents = [ + { + "id": result[self.id_column], + "content": result[self.content_column], + "embedding": result[self.embedding_column], + "metadata": ( + result[metadata_column[0]] + if isinstance(self.metadata_column, str) + else {col: result[col] for col in metadata_column} + ), + } + for result in resultset + ] + return documents diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/tests/common/test_connection.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/tests/common/test_connection.py index cb8f9a48db..c5865b3d8f 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/tests/common/test_connection.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/tests/common/test_connection.py @@ -1,11 +1,11 @@ -"""Synchronous connection handling tests for Azure Database for PostgreSQL.""" +"""Connection handling tests for Azure Database for PostgreSQL.""" -from collections.abc import Generator -from contextlib import contextmanager, nullcontext +from collections.abc import AsyncGenerator, Generator +from contextlib import asynccontextmanager, contextmanager, nullcontext from typing import Any import pytest -from psycopg import Connection, sql +from psycopg import AsyncConnection, Connection, sql from pydantic import BaseModel, ConfigDict from llama_index.vector_stores.azure_postgres.common import ( @@ -13,6 +13,10 @@ check_connection, create_extensions, ) +from llama_index.vector_stores.azure_postgres.common.aio import ( + async_check_connection, + async_create_extensions, +) class MockCursorBase(BaseModel): @@ -56,6 +60,23 @@ def fetchone(self) -> None | dict: return self.response +class AsyncMockCursor(MockCursorBase): + """A mock cursor for async tests.""" + + async def execute(self, query: str | sql.SQL, _params=None) -> None: + """Execute a SQL query and record it for later inspection.""" + self.last_query = query + + async def fetchone(self) -> None | dict: + """Return a single-row result dict.""" + assert self.last_query is not None, "No query executed." 
+ + if isinstance(self.last_query, str): + return None if self.broken else {"?column?": 1} + + return self.response + + @pytest.fixture def mock_cursor( connection: Connection, @@ -78,6 +99,24 @@ def mock_cursor(**_kwargs): monkeypatch.setattr(connection, "cursor", mock_cursor) +@pytest.fixture +async def async_mock_cursor( + async_connection: AsyncConnection, + monkeypatch: pytest.MonkeyPatch, + request: pytest.FixtureRequest, +): + """Fixture to mock `async_connection` to return `AsyncMockCursor` as a cursor.""" + assert isinstance(request.param, AsyncMockCursor), ( + "Expected an AsyncMockCursor instance." + ) + + @asynccontextmanager + async def async_mock_cursor(**_kwargs): + yield request.param + + monkeypatch.setattr(async_connection, "cursor", async_mock_cursor) + + class TestCheckConnection: """Tests for verifying the database connection and required extensions. @@ -86,10 +125,94 @@ class TestCheckConnection: extensions, version mismatches, and broken cursors. """ + async def test_async_it_works(self, async_connection: AsyncConnection) -> None: + """Ensure ``async_check_connection`` returns None on a healthy connection.""" + assert await async_check_connection(async_connection) is None + def test_it_works(self, connection: Connection) -> None: """Ensure ``check_connection`` returns None on a healthy connection.""" assert check_connection(connection) is None + @pytest.mark.parametrize( + ["extension", "async_mock_cursor", "expected_result"], + [ + ( + Extension(ext_name="test_ext", ext_version="1.0", schema_name="public"), + AsyncMockCursor( + broken=False, + response={ + "ext_name": "test_ext", + "ext_version": "1.0", + "schema_name": "public", + }, + ), + nullcontext(None), + ), + ( + Extension(ext_name="test_ext", ext_version="1.0", schema_name="public"), + AsyncMockCursor(broken=True, response=None), + pytest.raises(AssertionError, match="Connection check failed"), + ), + ( + Extension(ext_name="test_ext", ext_version="1.0", schema_name="public"), + AsyncMockCursor(broken=False, response=None), + pytest.raises( + RuntimeError, + match="Required extension 'test_ext' is not installed.", + ), + ), + ( + Extension(ext_name="test_ext", ext_version="1.0", schema_name="public"), + AsyncMockCursor( + broken=False, response={"ext_version": "wrong_version"} + ), + pytest.raises( + RuntimeError, + match="Required extension 'test_ext' version mismatch: expected 1.0, got wrong_version.", + ), + ), + ( + Extension(ext_name="test_ext", ext_version="1.0", schema_name="public"), + AsyncMockCursor( + broken=False, + response={"ext_version": "1.0", "schema_name": "wrong_schema"}, + ), + pytest.raises( + RuntimeError, + match="Required extension 'test_ext' is not installed in the expected schema: expected public, got wrong_schema.", + ), + ), + ], + ids=[ + "extension-installed", + "broken-cursor", + "extension-not-installed", + "version-mismatch", + "schema-mismatch", + ], + indirect=["async_mock_cursor"], + ) + async def test_async_mock_it_works( + self, + async_connection: AsyncConnection, + extension: Extension, + async_mock_cursor, + expected_result: nullcontext | pytest.RaisesExc, + ) -> None: + """Run parameterized checks of ``check_connection`` using mocked cursors. + + Parameterization covers installed extension, broken cursor, + missing extension, version mismatch, and schema mismatch cases. 
+ """ + with expected_result as e: + assert ( + await async_check_connection( + async_connection, + required_extensions=[extension], + ) + == e + ) + @pytest.mark.parametrize( ["extension", "mock_cursor", "expected_result"], [ @@ -163,6 +286,78 @@ def test_mock_it_works( assert check_connection(connection, required_extensions=[extension]) == e +@pytest.fixture( + params=[Extension(ext_name="vector")], + ids=["vector"], +) +async def async_extension_creatable( + async_connection: AsyncConnection, request: pytest.FixtureRequest +) -> AsyncGenerator[Extension, Any]: + """Fixture to check if an extension can be created.""" + assert isinstance(request.param, Extension), "Expected an Extension instance." + + ext_already_installed = False + + async with async_connection.cursor() as cursor: + await cursor.execute( + sql.SQL( + """ + select extname, extversion + from pg_extension + where extname = %(ext_name)s + """ + ), + {"ext_name": request.param.ext_name}, + ) + result = await cursor.fetchone() + ext_already_installed = result is not None + + try: + await cursor.execute( + sql.SQL( + """ + create extension if not exists {ext_name} + with {schema_expr} + {version_expr} + {cascade_expr} + """ + ).format( + ext_name=sql.Identifier(request.param.ext_name), + schema_expr=sql.SQL("schema {schema_name}").format( + schema_name=sql.Identifier(request.param.schema_name) + ) + if request.param.schema_name + else sql.SQL(""), + version_expr=sql.SQL("version {ext_version}").format( + ext_version=sql.Literal(request.param.ext_version) + ) + if request.param.ext_version + else sql.SQL(""), + cascade_expr=sql.SQL("cascade") + if request.param.cascade + else sql.SQL(""), + ) + ) + except Exception as e: + pytest.skip( + reason=f"Extension {request.param.ext_name} could not be created: {e}" + ) + + yield request.param + + if not ext_already_installed: + async with async_connection.cursor() as cursor: + await cursor.execute( + sql.SQL( + """ + drop extension if exists {ext_name} + """ + ).format( + ext_name=sql.Identifier(request.param.ext_name), + ) + ) + + @pytest.fixture def extension_creatable( connection: Connection, request: pytest.FixtureRequest @@ -245,6 +440,24 @@ class TestCreateExtensions: extension raises an informative exception. 
""" + @pytest.mark.parametrize( + "async_extension_creatable", + [Extension(ext_name="vector")], + ids=["vector"], + indirect=True, + ) + async def test_async_it_works( + self, async_connection: AsyncConnection, async_extension_creatable: Extension + ): + """Assert that creating a valid extension returns None (no error).""" + assert ( + await async_create_extensions( + async_connection, + required_extensions=[async_extension_creatable], + ) + is None + ) + @pytest.mark.parametrize( "extension_creatable", [Extension(ext_name="vector")], @@ -261,6 +474,21 @@ def test_it_works(self, connection: Connection, extension_creatable: Extension): is None ) + async def test_async_it_fails(self, async_connection: AsyncConnection): + """Verify that creating a missing extension raises an exception.""" + extension = Extension( + ext_name="non_existent_ext", + ext_version="1.0", + schema_name="public", + ) + with pytest.raises( + Exception, match='extension "non_existent_ext" is not available' + ): + await async_create_extensions( + async_connection, + required_extensions=[extension], + ) + def test_it_fails(self, connection: Connection): """Verify that creating a missing extension raises an exception.""" extension = Extension( diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/tests/conftest.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/tests/conftest.py index e6ffc6c97d..dd4f259568 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/tests/conftest.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/tests/conftest.py @@ -8,18 +8,21 @@ import logging import os -from collections.abc import Generator -from time import sleep +from collections.abc import AsyncGenerator, Generator from typing import Any import pytest from azure.core.credentials import TokenCredential +from azure.core.credentials_async import AsyncTokenCredential from azure.identity import DefaultAzureCredential -from psycopg import Connection, sql +from azure.identity.aio import DefaultAzureCredential as AsyncDefaultAzureCredential +from psycopg import AsyncConnection, Connection, sql from psycopg.rows import dict_row -from psycopg_pool import ConnectionPool, PoolTimeout +from psycopg_pool import AsyncConnectionPool, ConnectionPool, PoolTimeout from llama_index.vector_stores.azure_postgres.common import ( + AsyncAzurePGConnectionPool, + AsyncConnectionInfo, AzurePGConnectionPool, BasicAuth, ConnectionInfo, @@ -88,6 +91,177 @@ def pytest_addoption(parser: pytest.Parser) -> None: ) +@pytest.fixture +async def async_connection( + async_connection_pool: AsyncConnectionPool, +) -> AsyncGenerator[AsyncConnection, Any]: + """Fixture to provide an asynchronous PostgreSQL connection. + + :param async_connection_pool: The asynchronous connection pool (fixture) to use. + :type async_connection_pool: AsyncConnectionPool + :return: An asynchronous PostgreSQL connection. 
+ :rtype: AsyncConnection + """ + global postgres_not_available + if postgres_not_available: + pytest.skip("Could not reach the database") + try: + async with async_connection_pool.connection() as conn: + yield conn + except PoolTimeout as exc: + logger.warning("PoolTimeout %s", exc) + postgres_not_available = True + pytest.skip("Could not reach the database") + + +@pytest.fixture(scope="session") +async def async_connection_info( + async_credentials: BasicAuth | AsyncTokenCredential, + pytestconfig: pytest.Config, +) -> AsyncConnectionInfo: + """Fixture to provide asynchronous connection information for PostgreSQL. + + :param async_credentials: The asynchronous credentials (fixture) to use for authentication. + :type async_credentials: BasicAuth | AsyncTokenCredential + :param pytestconfig: The pytest configuration object. + :type pytestconfig: pytest.Config + :return: An asynchronous connection information object. + :rtype: AsyncConnectionInfo + """ + return AsyncConnectionInfo( + application_name=pytestconfig.getoption("pg_appname"), + host=pytestconfig.getoption("pg_host"), + dbname=pytestconfig.getoption("pg_database"), + port=pytestconfig.getoption("pg_port"), + sslmode=SSLMode.prefer, + credentials=async_credentials, + ) + + +@pytest.fixture(scope="session") +async def async_connection_pool( + async_connection_info: AsyncConnectionInfo, +) -> AsyncGenerator[AsyncConnectionPool, Any]: + """Fixture to provide an asynchronous PostgreSQL connection pool. + + :param async_connection_info: The asynchronous connection information (fixture) to use. + :type async_connection_info: AsyncConnectionInfo + :return: An asynchronous PostgreSQL connection pool. + :rtype: AsyncConnectionPool + """ + + # disable prepared statements during testing + async def disable_prepared_statements(async_conn: AsyncConnection) -> None: + async_conn.prepare_threshold = None + + credentials, host = async_connection_info.credentials, async_connection_info.host + assert host is not None, "Host must be provided for connection pool" + if isinstance(credentials, AsyncTokenCredential) and host.find("azure.com") == -1: + pytest.skip( + reason="Azure AD authentication requires an Azure PostgreSQL instance" + ) + async with AsyncAzurePGConnectionPool( + azure_conn_info=async_connection_info, configure=disable_prepared_statements + ) as pool: + yield pool + + +@pytest.fixture(scope="session", params=["azure-ad", "basic-auth"]) +async def async_credentials( + pytestconfig: pytest.Config, request: pytest.FixtureRequest +) -> BasicAuth | AsyncTokenCredential: + """Fixture to provide asynchronous credentials for PostgreSQL. + + This fixture supports both Azure AD authentication ("azure-ad" in `request.param`) + and basic authentication ("basic-auth" in `request.param`). When/if Azure AD + authentication is requested, it uses the `AsyncDefaultAzureCredential` to obtain + a token. For basic authentication, it retrieves the username and password from + the pytest configuration options. + + When/if Azure AD authentication is not available, it skips the test with a reason. + + :param pytestconfig: The pytest configuration object. + :type pytestconfig: pytest.Config + :param request: The pytest fixture request object. + :type request: pytest.FixtureRequest + :raises ValueError: If the authentication type is unknown. + :return: The asynchronous credentials for PostgreSQL. 
+ :rtype: BasicAuth | AsyncTokenCredential + """ + if request.param == "azure-ad": + try: + async_credentials = AsyncDefaultAzureCredential() + _token = await async_credentials.get_token(TOKEN_CREDENTIAL_SCOPE) + return async_credentials + except Exception: + pytest.skip(reason="Azure AD authentication not available") + elif request.param == "basic-auth": + username = pytestconfig.getoption("pg_user") + password = pytestconfig.getoption("pg_password") + return BasicAuth(username=username, password=password) + else: + raise ValueError(f"Unknown auth type: {request.param}") + + +@pytest.fixture(scope="session") +async def async_schema( + async_connection_pool: AsyncConnectionPool, +) -> AsyncGenerator[str, Any]: + """Fixture to create and drop a schema for testing purposes. + + :param async_connection_pool: The asynchronous connection pool (fixture) to use. + :type async_connection_pool: AsyncConnectionPool + :return: The name of the created schema. + :rtype: str + """ + global postgres_not_available + if postgres_not_available: + pytest.skip("Could not reach the database") + + try: + async with ( + async_connection_pool.connection() as conn, + conn.cursor(row_factory=dict_row) as cursor, + ): + await cursor.execute( + sql.SQL( + """ + select oid as schema_id, nspname as schema_name + from pg_namespace + """ + ) + ) + resultset = await cursor.fetchall() + schema_names = [row["schema_name"] for row in resultset] + except PoolTimeout as exc: + logger.warning("PoolTimeout %s", exc) + postgres_not_available = True + pytest.skip("Could not reach the database") + + _schema: str | None = None + for idx in range(100_000): + _schema_name = f"pytest-{idx:05d}" + if _schema_name not in schema_names: + _schema = _schema_name + break + if _schema is None: + pytest.fail("Could not find a unique schema name for testing") + + async with async_connection_pool.connection() as conn, conn.cursor() as cursor: + await cursor.execute( + sql.SQL("create schema {schema}").format(schema=sql.Identifier(_schema)) + ) + + yield _schema + + async with async_connection_pool.connection() as conn, conn.cursor() as cursor: + await cursor.execute( + sql.SQL("drop schema {schema} cascade").format( + schema=sql.Identifier(_schema) + ) + ) + + @pytest.fixture def connection(connection_pool: ConnectionPool) -> Generator[Connection, Any, None]: """Fixture to provide a PostgreSQL connection. 
@@ -100,22 +274,13 @@ def connection(connection_pool: ConnectionPool) -> Generator[Connection, Any, No global postgres_not_available if postgres_not_available: pytest.skip("Could not reach the database") - max_retries = 3 - backoff = 0.5 - for attempt in range(1, max_retries + 1): - try: - with connection_pool.connection() as conn: - yield conn - except PoolTimeout as exc: - logger.warning( - "PoolTimeout on attempt %d/%d: %s", attempt, max_retries, exc - ) - if attempt == max_retries: - logger.error("Exhausted retries acquiring DB connection") - - postgres_not_available = True - pytest.skip("Could not reach the database") - sleep(backoff * attempt) + try: + with connection_pool.connection() as conn: + yield conn + except PoolTimeout as exc: + logger.warning("PoolTimeout %s", exc) + postgres_not_available = True + pytest.skip("Could not reach the database") @pytest.fixture(scope="session") @@ -219,33 +384,26 @@ def schema(connection_pool: ConnectionPool) -> Generator[str, Any, None]: global postgres_not_available if postgres_not_available: pytest.skip("Could not reach the database") - max_retries = 3 - backoff = 0.5 - for attempt in range(1, max_retries + 1): - try: - with ( - connection_pool.connection() as conn, - conn.cursor(row_factory=dict_row) as cursor, - ): - cursor.execute( - sql.SQL( - """ - select oid as schema_id, nspname as schema_name - from pg_namespace - """ - ) + + try: + with ( + connection_pool.connection() as conn, + conn.cursor(row_factory=dict_row) as cursor, + ): + cursor.execute( + sql.SQL( + """ + select oid as schema_id, nspname as schema_name + from pg_namespace + """ ) - resultset = cursor.fetchall() - schema_names = [row["schema_name"] for row in resultset] - except PoolTimeout as exc: - logger.warning( - "PoolTimeout on attempt %d/%d: %s", attempt, max_retries, exc ) - if attempt == max_retries: - logger.error("Exhausted retries acquiring DB connection") - postgres_not_available = True - pytest.skip("Could not reach the database") - sleep(backoff * attempt) + resultset = cursor.fetchall() + schema_names = [row["schema_name"] for row in resultset] + except PoolTimeout as exc: + logger.warning("PoolTimeout %s", exc) + postgres_not_available = True + pytest.skip("Could not reach the database") _schema: str | None = None for idx in range(100_000): diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/tests/llama_index/conftest.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/tests/llama_index/conftest.py index 6af5403c51..d6b7ee205f 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/tests/llama_index/conftest.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/tests/llama_index/conftest.py @@ -1,11 +1,11 @@ """Pytest fixtures and Pydantic models used for Azure PostgreSQL vector store integration tests.""" -from collections.abc import Generator +from collections.abc import AsyncGenerator, Generator from typing import Any import pytest from psycopg import sql -from psycopg_pool import ConnectionPool +from psycopg_pool import AsyncConnectionPool, ConnectionPool from pydantic import BaseModel, PositiveInt from llama_index.core.schema import Node @@ -14,6 +14,7 @@ MetadataFilters, ) from llama_index.vector_stores.azure_postgres import ( + AsyncAzurePGVectorStore, AzurePGVectorStore, ) from llama_index.vector_stores.azure_postgres.common import ( @@ -23,7 +24,6 @@ ) _FIXTURE_PARAMS_TABLE: dict[str, Any] = { - 
"scope": "class", "params": [ { "existing": True, @@ -38,7 +38,6 @@ }, ], "ids": [ - # "non-existing-table-metadata-str", "existing-table-metadata-str", ], } @@ -108,6 +107,75 @@ class Table(BaseModel): metadata_column: str +@pytest.fixture(**_FIXTURE_PARAMS_TABLE) +async def async_table( + async_connection_pool: AsyncConnectionPool, + async_schema: str, + request: pytest.FixtureRequest, +) -> AsyncGenerator[Table, Any]: + """Fixture to provide a parametrized table configuration for asynchronous tests. + + This fixture yields a `Table` model with normalized metadata columns. When + the parameter `existing` is `True`, it creates the table in the provided + schema before yielding and drops it after the test class completes. + + :param async_connection_pool: The asynchronous connection pool to use for DDL. + :type async_connection_pool: AsyncConnectionPool + :param async_schema: The schema name where the table should be created. + :type async_schema: str + :param request: The pytest request object providing parametrization. + :type request: pytest.FixtureRequest + :return: An asynchronous generator yielding a `Table` configuration. + :rtype: AsyncGenerator[Table, Any] + """ + assert isinstance(request.param, dict), "Request param must be a dictionary" + + table = Table( + existing=request.param.get("existing", None), + schema_name=async_schema, + table_name=request.param.get("table_name", "llamaindex"), + id_column=request.param.get("id_column", "id"), + content_column=request.param.get("content_column", "content"), + embedding_column=request.param.get("embedding_column", "embedding"), + embedding_type=request.param.get("embedding_type", "vector"), + embedding_dimension=request.param.get("embedding_dimension", 1_536), + embedding_index=request.param.get("embedding_index", None), + metadata_column=request.param.get("metadata_column", "metadata"), + ) + + if table.existing: + async with async_connection_pool.connection() as conn, conn.cursor() as cur: + await cur.execute( + sql.SQL( + """ + create table {table_name} ( + {id_column} uuid primary key, + {content_column} text, + {embedding_column} {embedding_type}({embedding_dimension}), + {metadata_column} jsonb + ) + """ + ).format( + table_name=sql.Identifier(async_schema, table.table_name), + id_column=sql.Identifier(table.id_column), + content_column=sql.Identifier(table.content_column), + embedding_column=sql.Identifier(table.embedding_column), + embedding_type=sql.Identifier(table.embedding_type), + embedding_dimension=sql.Literal(table.embedding_dimension), + metadata_column=sql.Identifier(table.metadata_column), + ) + ) + + yield table + + async with async_connection_pool.connection() as conn, conn.cursor() as cur: + await cur.execute( + sql.SQL("drop table {table} cascade").format( + table=sql.Identifier(async_schema, table.table_name) + ) + ) + + @pytest.fixture(**_FIXTURE_PARAMS_TABLE) def table( connection_pool: ConnectionPool, @@ -219,6 +287,62 @@ def filters( return vsfilters +@pytest.fixture +async def async_vectorstore( + async_connection_pool: AsyncConnectionPool, async_table: Table +) -> AsyncAzurePGVectorStore: + """Define vectorstore with DiskANN.""" + diskann = DiskANN( + op_class="vector_cosine_ops", max_neighbors=32, l_value_ib=100, l_value_is=100 + ) + print(async_table) + vector_store = AsyncAzurePGVectorStore.from_params( + connection_pool=async_connection_pool, + schema_name=async_table.schema_name, + table_name=async_table.table_name, + embed_dim=async_table.embedding_dimension, + embedding_index=diskann, + ) + + # add 
several documents with deterministic embeddings for testing similarity + dim = int(async_table.embedding_dimension) + + nodes = [] + + n1 = Node() + n1.node_id = "00000000-0000-0000-0000-000000000001" + n1.set_content("Text 1 about cats") + n1.embedding = [1.0] * dim + n1.metadata = {"metadata_column1": "text1", "metadata_column2": 1} + nodes.append(n1) + + n2 = Node() + n2.node_id = "00000000-0000-0000-0000-000000000002" + n2.set_content("Text 2 about tigers") + # tigers should be close to cats + n2.embedding = [0.95] * dim + n2.metadata = {"metadata_column1": "text2", "metadata_column2": 2} + nodes.append(n2) + + n3 = Node() + n3.node_id = "00000000-0000-0000-0000-000000000003" + n3.set_content("Text 3 about dogs") + n3.embedding = [0.3] * dim + n3.metadata = {"metadata_column1": "text3", "metadata_column2": 3} + nodes.append(n3) + + n4 = Node() + n4.node_id = "00000000-0000-0000-0000-000000000004" + n4.set_content("Text 4 about plants") + n4.embedding = [-1.0] * dim + n4.metadata = {"metadata_column1": "text4", "metadata_column2": 4} + nodes.append(n4) + + await vector_store.async_add(nodes) + + return vector_store + + @pytest.fixture def vectorstore(connection_pool: ConnectionPool, table: Table) -> AzurePGVectorStore: """Define vectorstore with DiskANN.""" diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/tests/llama_index/test_vectorstore.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/tests/llama_index/test_vectorstore.py index 5c22b77fef..bbfd4b3c51 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/tests/llama_index/test_vectorstore.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/tests/llama_index/test_vectorstore.py @@ -17,7 +17,10 @@ MetadataFilters, VectorStoreQuery, ) -from llama_index.vector_stores.azure_postgres import AzurePGVectorStore +from llama_index.vector_stores.azure_postgres import ( + AsyncAzurePGVectorStore, + AzurePGVectorStore, +) from llama_index.vector_stores.azure_postgres.common import DiskANN from .conftest import Table @@ -376,3 +379,275 @@ def test_query( assert all("plants" not in c for c in contents), ( f"Expected 'plants' not to be in retrieved documents' contents for query: {query}" ) + + +class TestAsyncAzurePGVectorStore: + """Async integration tests for AsyncAzurePGVectorStore implementation.""" + + @pytest.mark.asyncio + async def test_table_creation_success(self, async_vectorstore, async_table): + """Verify the database table is created with the correct columns and types (async).""" + async with ( + async_vectorstore.connection_pool.connection() as conn, + conn.cursor(row_factory=dict_row) as cursor, + ): + await cursor.execute( + _GET_TABLE_COLUMNS_AND_TYPES, + { + "schema_name": async_table.schema_name, + "table_name": async_table.table_name, + }, + ) + resultset = await cursor.fetchall() + verify_table_created(async_table, resultset) + + @pytest.mark.asyncio + async def test_vectorstore_initialization_from_params( + self, + async_connection_pool, + async_schema: str, + ): + """Create a store using class factory `from_params` and assert type (async).""" + table_name = "vs_init_from_params_async" + embedding_dimension = 3 + + diskann = DiskANN( + op_class="vector_cosine_ops", + max_neighbors=32, + l_value_ib=100, + l_value_is=100, + ) + + vectorstore = AsyncAzurePGVectorStore.from_params( + connection_pool=async_connection_pool, + schema_name=async_schema, + table_name=table_name, + 
embed_dim=embedding_dimension, + embedding_index=diskann, + ) + assert isinstance(vectorstore, AsyncAzurePGVectorStore) + + @pytest.mark.asyncio + async def test_aget_nodes( + self, + async_vectorstore, + ): + """Retrieve all nodes and assert expected seeded node count (async).""" + in_nodes = await async_vectorstore.aget_nodes() + assert len(in_nodes) == 4, "Retrieved node count does not match expected" + + @pytest.mark.asyncio + @pytest.mark.parametrize( + ["node_tuple", "expected"], + [ + ("node-success", nullcontext(AsyncAzurePGVectorStore)), + ("node-not-found", pytest.raises(IndexError)), + ], + indirect=["node_tuple"], + ids=[ + "success", + "not-found", + ], + ) + async def test_aget_nodes_with_ids( + self, + async_vectorstore, + node_tuple, + expected, + ): + """Retrieve nodes by ID and validate returned node matches expected (async).""" + node, expected_node_id = node_tuple + in_nodes = await async_vectorstore.aget_nodes([node.node_id]) + with expected: + assert expected_node_id == in_nodes[0].node_id, ( + "Retrieved node ID does not match expected" + ) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + ["node_tuple", "expected"], + [ + ("node-success", nullcontext(AsyncAzurePGVectorStore)), + ], + indirect=["node_tuple"], + ids=[ + "success", + ], + ) + async def test_async_add( + self, + async_vectorstore, + node_tuple, + expected, + ): + """Add a node to the store and assert the returned ID matches (async).""" + node, expected_node_id = node_tuple + with expected: + assert node.node_id is not None, "Node ID must be provided for this test" + returned_ids = await async_vectorstore.async_add([node]) + assert returned_ids[0] == expected_node_id, "Inserted text IDs do not match" + + @pytest.mark.asyncio + @pytest.mark.parametrize( + ["doc_id"], + [ + ("1",), + ("10",), + ], + ids=["existing", "non-existing"], + ) + async def test_adelete( + self, + async_vectorstore, + doc_id, + ): + """Delete a node by reference doc id and assert it was removed (async).""" + await async_vectorstore.adelete(doc_id) + async with ( + async_vectorstore.connection_pool.connection() as conn, + conn.cursor(row_factory=dict_row) as cursor, + ): + await cursor.execute( + sql.SQL( + """ + select {metadata} ->> 'doc_id' as doc_id + from {table_name} + """ + ).format( + metadata=sql.Identifier(async_vectorstore.metadata_columns), + table_name=sql.Identifier( + async_vectorstore.schema_name, async_vectorstore.table_name + ), + ) + ) + resultset = await cursor.fetchall() + remaining_set = set(str(r["doc_id"]) for r in resultset) + assert doc_id not in remaining_set, ( + "Deleted document IDs should not exist in the remaining set" + ) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + ["node_tuple"], + [ + ("node-success",), + ("node-not-found",), + ], + indirect=["node_tuple"], + ids=[ + "success", + "not-found", + ], + ) + async def test_adelete_nodes( + self, + async_vectorstore, + node_tuple, + ): + """Delete a list of node IDs and assert they are removed from the table (async).""" + node, expected_node_id = node_tuple + await async_vectorstore.adelete_nodes([node.node_id]) + async with ( + async_vectorstore.connection_pool.connection() as conn, + conn.cursor(row_factory=dict_row) as cursor, + ): + await cursor.execute( + sql.SQL( + """ + select {id_column} as node_id + from {table_name} + """ + ).format( + id_column=sql.Identifier(async_vectorstore.id_column), + table_name=sql.Identifier( + async_vectorstore.schema_name, async_vectorstore.table_name + ), + ) + ) + resultset = await cursor.fetchall() 
+ remaining_set = set(str(r["node_id"]) for r in resultset) + assert expected_node_id not in remaining_set, ( + "Deleted document IDs should not exist in the remaining set" + ) + + @pytest.mark.asyncio + async def test_aclear( + self, + async_vectorstore, + ): + """Clear all nodes from the underlying table and verify none remain (async).""" + await async_vectorstore.aclear() + async with ( + async_vectorstore.connection_pool.connection() as conn, + conn.cursor(row_factory=dict_row) as cursor, + ): + await cursor.execute( + sql.SQL( + """ + select {id_column} as node_id + from {table_name} + """ + ).format( + id_column=sql.Identifier(async_vectorstore.id_column), + table_name=sql.Identifier( + async_vectorstore.schema_name, async_vectorstore.table_name + ), + ) + ) + resultset = await cursor.fetchall() + remaining_set = set(str(r["node_id"]) for r in resultset) + assert not remaining_set, "All document IDs should have been deleted" + + @pytest.mark.asyncio + @pytest.mark.parametrize( + ["query", "embedding", "k", "filters"], + [ + ("query about cats", [0.99] * 1536, 2, None), + ("query about animals", [0.5] * 1536, 3, None), + ("query about cats", [0.99] * 1536, 2, "filter1"), + ("query about cats", [0.99] * 1536, 2, "filter2"), + ], + indirect=["filters"], + ids=[ + "search-cats", + "search-animals", + "search-cats-filtered", + "search-cats-multifiltered", + ], + ) + async def test_aquery( + self, + async_vectorstore, + query, + embedding, + k, + filters, + ): + """Run a similarity query and assert returned documents match expectations (async).""" + vsquery = VectorStoreQuery( + query_str=query, + query_embedding=embedding, + similarity_top_k=k, + filters=filters, + ) + results = await async_vectorstore.aquery(query=vsquery) + results = results.nodes + contents = [row.get_content() for row in results] + if ("cats" in query) or ("animals" in query): + assert len(results) == k, f"Expected {k} results" + assert any("cats" in c for c in contents) or any( + "tigers" in c for c in contents + ), ( + f"Expected 'cats' or 'tigers' in retrieved documents' contents for query: {query}" + ) + if "cats" in query: + assert all("dogs" not in c for c in contents), ( + f"Expected 'dogs' not to be in retrieved documents' contents for query: {query}" + ) + elif "animals" in query: + assert any("dogs" in c for c in contents), ( + f"Expected 'dogs' to be in retrieved documents' contents for query: {query}" + ) + assert all("plants" not in c for c in contents), ( + f"Expected 'plants' not to be in retrieved documents' contents for query: {query}" + ) diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/uv.lock b/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/uv.lock index abafd4ae21..02cf9583ed 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/uv.lock +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/uv.lock @@ -1889,7 +1889,7 @@ type = [ requires-dist = [ { name = "aiohttp", specifier = "~=3.0" }, { name = "azure-identity", specifier = "~=1.0" }, - { name = "llama-index-core", specifier = "~=0.13.0" }, + { name = "llama-index-core", specifier = ">=0.13,<0.15" }, { name = "numpy", specifier = "~=2.0" }, { name = "pgvector", specifier = "~=0.4.0" }, { name = "psycopg", extras = ["binary", "pool"], specifier = "~=3.0" }, From e8951cb2cd1600253e65006aeab60ceb2e7db036 Mon Sep 17 00:00:00 2001 From: Orhan Kislal Date: Fri, 3 Oct 2025 12:00:43 +0300 Subject: [PATCH 
2/2] Update version number --- .../llama-index-vector-stores-azurepostgresql/pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/pyproject.toml b/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/pyproject.toml index f841abb88f..b6b7e218ae 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/pyproject.toml +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-azurepostgresql/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "hatchling.build" [dependency-groups] dev = [ "jupyterlab~=4.0", - "llama-index~=0.13.0", + "llama-index-core>=0.13,<0.15", "openai~=1.0", ] lint = ["ruff~=0.12.0"] @@ -29,7 +29,7 @@ type = ["mypy~=1.0"] [project] name = "llama-index-vector-stores-azurepostgres" -version = "0.1.0" +version = "0.2.0" description = "AI framework integrations for Azure Database for PostgreSQL" readme = "README.md" license = {text = "MIT"}
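
For reviewers who want to exercise the new async store outside the test suite, a minimal usage sketch follows. It is not part of the patch; it only strings together the API surface the patch adds (AsyncConnectionInfo, BasicAuth, AsyncAzurePGConnectionPool, DiskANN, AsyncAzurePGVectorStore.from_params, async_add, aquery), mirroring the shapes used in the new conftest fixtures. The host, credentials, application name, table name, and 3-dimensional embeddings are placeholders, and it assumes the store provisions its table on first use the same way the synchronous store does; if it does not, create the table beforehand as the async_table fixture does.

    import asyncio

    from llama_index.core.schema import Node
    from llama_index.core.vector_stores.types import VectorStoreQuery
    from llama_index.vector_stores.azure_postgres import AsyncAzurePGVectorStore
    from llama_index.vector_stores.azure_postgres.common import (
        AsyncAzurePGConnectionPool,
        AsyncConnectionInfo,
        BasicAuth,
        DiskANN,
        SSLMode,
    )


    async def main() -> None:
        # Placeholder connection details -- replace with a reachable server and real credentials.
        conn_info = AsyncConnectionInfo(
            application_name="example-app",
            host="example.postgres.database.azure.com",
            dbname="postgres",
            port=5432,
            sslmode=SSLMode.prefer,
            credentials=BasicAuth(username="myuser", password="mypassword"),
        )

        async with AsyncAzurePGConnectionPool(azure_conn_info=conn_info) as pool:
            # embed_dim=3 keeps the example tiny; the new tests use the same
            # dimension with a DiskANN index in test_vectorstore_initialization_from_params.
            store = AsyncAzurePGVectorStore.from_params(
                connection_pool=pool,
                schema_name="public",
                table_name="llamaindex_vectors",
                embed_dim=3,
                embedding_index=DiskANN(
                    op_class="vector_cosine_ops",
                    max_neighbors=32,
                    l_value_ib=100,
                    l_value_is=100,
                ),
            )

            # Seed one node with a precomputed embedding, then run a similarity query.
            node = Node()
            node.node_id = "00000000-0000-0000-0000-000000000001"
            node.set_content("Text about cats")
            node.embedding = [1.0, 0.0, 0.0]
            await store.async_add([node])

            result = await store.aquery(
                VectorStoreQuery(
                    query_str="query about cats",
                    query_embedding=[0.9, 0.1, 0.0],
                    similarity_top_k=1,
                )
            )
            print(result.ids, result.similarities)


    if __name__ == "__main__":
        asyncio.run(main())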