diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-postgres/llama_index/vector_stores/postgres/base.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-postgres/llama_index/vector_stores/postgres/base.py index 3b76cde59c..c3eda13e4d 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-postgres/llama_index/vector_stores/postgres/base.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-postgres/llama_index/vector_stores/postgres/base.py @@ -799,12 +799,16 @@ def _query_with_score( stmt = self._build_query(embedding, limit, metadata_filters, **kwargs) with self._session() as session, session.begin(): from sqlalchemy import text + from psycopg2 import sql if kwargs.get("ivfflat_probes"): ivfflat_probes = kwargs.get("ivfflat_probes") session.execute( - text(f"SET ivfflat.probes = :ivfflat_probes"), - {"ivfflat_probes": ivfflat_probes}, + text( + sql.SQL("SET ivfflat.probes = {}") + .format(sql.Literal(ivfflat_probes)) + .as_string(context=self._engine.raw_connection().connection), + ) ) if self.hnsw_kwargs: hnsw_ef_search = ( @@ -843,6 +847,7 @@ async def _aquery_with_score( stmt = self._build_query(embedding, limit, metadata_filters, **kwargs) async with self._async_session() as async_session, async_session.begin(): from sqlalchemy import text + from psycopg2 import sql if self.hnsw_kwargs: hnsw_ef_search = ( @@ -854,8 +859,11 @@ async def _aquery_with_score( if kwargs.get("ivfflat_probes"): ivfflat_probes = kwargs.get("ivfflat_probes") await async_session.execute( - text(f"SET ivfflat.probes = :ivfflat_probes"), - {"ivfflat_probes": ivfflat_probes}, + text( + sql.SQL("SET ivfflat.probes = {}") + .format(sql.Literal(ivfflat_probes)) + .as_string(context=self._engine.raw_connection().connection), + ) ) res = await async_session.execute(stmt) diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-postgres/pyproject.toml b/llama-index-integrations/vector_stores/llama-index-vector-stores-postgres/pyproject.toml index d6827d0a10..bc94c3051d 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-postgres/pyproject.toml +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-postgres/pyproject.toml @@ -27,7 +27,7 @@ dev = [ [project] name = "llama-index-vector-stores-postgres" -version = "0.6.8" +version = "0.6.9" description = "llama-index vector_stores postgres integration" authors = [{name = "Your Name", email = "you@example.com"}] requires-python = ">=3.9,<4.0" diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-postgres/tests/test_postgres.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-postgres/tests/test_postgres.py index 1ec2f1c6ac..2f977ad345 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-postgres/tests/test_postgres.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-postgres/tests/test_postgres.py @@ -2,6 +2,7 @@ from typing import Any, Dict, Generator, List, Union, Optional import pytest +from llama_index.core import VectorStoreIndex from llama_index.core.schema import ( BaseNode, IndexNode, @@ -9,6 +10,8 @@ RelatedNodeInfo, TextNode, ) +from llama_index.core.query_engine import RetrieverQueryEngine +from llama_index.core.response_synthesizers import get_response_synthesizer from llama_index.core.vector_stores.types import ( ExactMatchFilter, FilterOperator, @@ -1283,7 +1286,35 @@ async def test_delete_nodes_metadata( assert all(i not in res.ids for i in ["bbb", "aaa", "ddd"]) assert "ccc" in res.ids + +@pytest.mark.skipif(postgres_not_available, reason="postgres db is not available") +@pytest.mark.asyncio +@pytest.mark.parametrize("pg_fixture", ["pg", "pg_hybrid"], indirect=True) +@pytest.mark.parametrize("use_async", [True, False]) +async def test_set_ivfflat(pg_fixture: PGVectorStore, use_async: bool) -> None: + vector_index = VectorStoreIndex.from_vector_store(vector_store=pg_fixture) + + vector_retriever = vector_index.as_retriever( + vector_store_kwargs={"ivfflat_probes": 20} + ) + vector_response_synthesizer = get_response_synthesizer( + use_async=use_async, + ) + + vector_query_engine = RetrieverQueryEngine( + retriever=vector_retriever, + response_synthesizer=vector_response_synthesizer, + ) + + query = "lorem ipsum" + if use_async: + response = await vector_query_engine.aquery(query) + else: + response = vector_query_engine.query(query) + + assert response + @pytest.mark.skipif(postgres_not_available, reason="postgres db is not available") @pytest.mark.asyncio @pytest.mark.parametrize("use_async", [True, False])