Skip to content

Commit 3174c9a

Browse files
committed
update pgvector to use pgvectorstore
1 parent 1cb544c commit 3174c9a

File tree

2 files changed

+16
-9
lines changed

2 files changed

+16
-9
lines changed

config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def _init_db_provider(db_type: str) -> DBProvider:
127127
elif db_type == "PGVECTOR":
128128
url = get("PGVECTOR_URL")
129129
collection = get("PGVECTOR_COLLECTION_NAME")
130-
return PGVectorProvider(embedding_model, url, collection)
130+
return PGVectorProvider(embedding_model, url, collection, embedding_length)
131131

132132
elif db_type == "MSSQL":
133133
connection_string = get("MSSQL_CONNECTION_STRING")

vector_db/pgvector_provider.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from urllib.parse import urlparse
44

55
from langchain_core.documents import Document
6-
from langchain_postgres import PGVector
6+
from langchain_postgres import PGEngine, PGVectorStore
77

88
from vector_db.db_provider import DBProvider
99

@@ -26,18 +26,26 @@ class PGVectorProvider(DBProvider):
2626
embedding_model (str): The model name to use for computing embeddings.
2727
url (str): PostgreSQL connection string (e.g. "postgresql://user:pass@host:5432/db").
2828
collection_name (str): Name of the table/collection used for storing vectors.
29+
embedding_length (int): Dimensionality of the embeddings (e.g., 768 for all-mpnet-base-v2).
2930
3031
Example:
3132
>>> from vector_db.pgvector_provider import PGVectorProvider
3233
>>> provider = PGVectorProvider(
3334
... embedding_model="BAAI/bge-base-en-v1.5",
3435
... url="postgresql://user:pass@localhost:5432/vector_db",
35-
... collection_name="rag_chunks"
36+
... collection_name="rag_chunks",
37+
... embedding_length=768
3638
... )
3739
>>> provider.add_documents(docs)
3840
"""
3941

40-
def __init__(self, embedding_model: str, url: str, collection_name: str):
42+
def __init__(
43+
self,
44+
embedding_model: str,
45+
url: str,
46+
collection_name: str,
47+
embedding_length: int,
48+
):
4149
"""
4250
Initialize a PGVectorProvider for use with PostgreSQL.
4351
@@ -48,11 +56,10 @@ def __init__(self, embedding_model: str, url: str, collection_name: str):
4856
"""
4957
super().__init__(embedding_model)
5058

51-
self.db = PGVector(
52-
connection=url,
53-
collection_name=collection_name,
54-
embeddings=self.embeddings,
55-
)
59+
engine = PGEngine.from_connection_string(url)
60+
engine.init_vectorstore_table(collection_name, embedding_length)
61+
62+
self.db = PGVectorStore.create_sync(engine, self.embeddings, collection_name)
5663

5764
parsed = urlparse(url)
5865
postgres_location = (

0 commit comments

Comments
 (0)