Skip to content

Commit 7080738

Browse files
committed
Update DB providers to accept a preconfigured embeddings instance and derive the embedding dimensionality from it, for better modularity
1 parent d155f72 commit 7080738

File tree

9 files changed

+111
-135
lines changed

9 files changed

+111
-135
lines changed

.env

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ CHUNK_SIZE=1024
1515
CHUNK_OVERLAP=40
1616
DB_TYPE=DRYRUN
1717
EMBEDDING_MODEL=sentence-transformers/all-mpnet-base-v2
18-
EMBEDDING_LENGTH=768
1918

2019
# === Redis ===
2120
REDIS_URL=redis://localhost:6379

config.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Dict, List
66

77
from dotenv import load_dotenv
8+
from langchain_huggingface import HuggingFaceEmbeddings
89

910
from vector_db.db_provider import DBProvider
1011
from vector_db.dryrun_provider import DryRunProvider
@@ -108,40 +109,37 @@ def _init_db_provider(db_type: str) -> DBProvider:
108109
"""
109110
get = Config._get_required_env_var
110111
db_type = db_type.upper()
111-
embedding_model = get("EMBEDDING_MODEL")
112-
embedding_length = int(get("EMBEDDING_LENGTH"))
112+
embeddings = HuggingFaceEmbeddings(model_name=get("EMBEDDING_MODEL"))
113113

114114
if db_type == "REDIS":
115115
url = get("REDIS_URL")
116116
index = os.getenv("REDIS_INDEX", "docs")
117-
return RedisProvider(embedding_model, url, index)
117+
return RedisProvider(embeddings, url, index)
118118

119119
elif db_type == "ELASTIC":
120120
url = get("ELASTIC_URL")
121121
password = get("ELASTIC_PASSWORD")
122122
index = os.getenv("ELASTIC_INDEX", "docs")
123123
user = os.getenv("ELASTIC_USER", "elastic")
124-
return ElasticProvider(embedding_model, url, password, index, user)
124+
return ElasticProvider(embeddings, url, password, index, user)
125125

126126
elif db_type == "PGVECTOR":
127127
url = get("PGVECTOR_URL")
128128
collection = get("PGVECTOR_COLLECTION_NAME")
129-
return PGVectorProvider(embedding_model, url, collection, embedding_length)
129+
return PGVectorProvider(embeddings, url, collection)
130130

131131
elif db_type == "MSSQL":
132132
connection_string = get("MSSQL_CONNECTION_STRING")
133133
table = get("MSSQL_TABLE")
134-
return MSSQLProvider(
135-
embedding_model, connection_string, table, embedding_length
136-
)
134+
return MSSQLProvider(embeddings, connection_string, table)
137135

138136
elif db_type == "QDRANT":
139137
url = get("QDRANT_URL")
140138
collection = get("QDRANT_COLLECTION")
141-
return QdrantProvider(embedding_model, url, collection)
139+
return QdrantProvider(embeddings, url, collection)
142140

143141
elif db_type == "DRYRUN":
144-
return DryRunProvider(embedding_model)
142+
return DryRunProvider(embeddings)
145143

146144
raise ValueError(f"Unsupported DB_TYPE '{db_type}'")
147145

vector_db/db_provider.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,42 +11,43 @@ class DBProvider(ABC):
1111
Abstract base class for vector database providers.
1212
1313
This class standardizes how vector databases are initialized and how documents
14-
are added to them. All concrete implementations (e.g., Qdrant, FAISS) must
14+
are added to them. All concrete implementations (e.g., Qdrant, Redis) must
1515
subclass `DBProvider` and implement the `add_documents()` method.
1616
1717
Attributes:
18-
embeddings (Embeddings): An instance of HuggingFace embeddings based on the
19-
specified model.
18+
embeddings (HuggingFaceEmbeddings): An instance of HuggingFace embeddings.
19+
embedding_length (int): Dimensionality of the embedding vector.
2020
2121
Args:
22-
embedding_model (str): HuggingFace-compatible model name to be used for computing
23-
dense vector embeddings for documents.
22+
embeddings (HuggingFaceEmbeddings): A preconfigured HuggingFaceEmbeddings instance.
2423
2524
Example:
2625
>>> class MyProvider(DBProvider):
2726
... def add_documents(self, docs):
28-
... print(f"Would add {len(docs)} docs with model {self.embeddings.model_name}")
27+
... print(f"Would add {len(docs)} docs with vector size {self.embedding_length}")
2928
30-
>>> provider = MyProvider("BAAI/bge-small-en")
29+
>>> embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en")
30+
>>> provider = MyProvider(embeddings)
3131
>>> provider.add_documents([Document(page_content="Hello")])
3232
"""
3333

34-
def __init__(self, embedding_model: str) -> None:
34+
def __init__(self, embeddings: HuggingFaceEmbeddings) -> None:
3535
"""
36-
Initialize a DB provider with a specific embedding model.
36+
Initialize a DB provider with a HuggingFaceEmbeddings instance.
3737
3838
Args:
39-
embedding_model (str): The HuggingFace model name to be used for generating embeddings.
39+
embeddings (HuggingFaceEmbeddings): The embeddings object used for vectorization.
4040
"""
41-
self.embeddings: Embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
41+
self.embeddings: HuggingFaceEmbeddings = embeddings
42+
self.embedding_length: int = len(self.embeddings.embed_query("query"))
4243

4344
@abstractmethod
4445
def add_documents(self, docs: List[Document]) -> None:
4546
"""
4647
Add documents to the vector database.
4748
4849
This method must be implemented by subclasses to define how documents
49-
(with or without precomputed embeddings) are stored in the backend vector DB.
50+
are embedded and stored in the backend vector DB.
5051
5152
Args:
5253
docs (List[Document]): A list of LangChain `Document` objects to be embedded and added.

vector_db/dryrun_provider.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import List
22

33
from langchain_core.documents import Document
4+
from langchain_huggingface import HuggingFaceEmbeddings
45

56
from vector_db.db_provider import DBProvider
67

@@ -9,36 +10,35 @@ class DryRunProvider(DBProvider):
910
"""
1011
A mock vector DB provider for debugging document loading and chunking.
1112
12-
`DryRunProvider` does not persist any documents or perform embedding operations.
13-
Instead, it prints a preview of the documents and their metadata to stdout,
14-
allowing users to validate chunking, structure, and metadata before pushing
15-
to a production vector store.
16-
17-
Useful for development, testing, or understanding how your documents are
18-
being processed.
13+
`DryRunProvider` does not persist any documents or embed them; the only embedding call is the base class's one-off probe at init, used to determine the vector dimensionality.
14+
It prints a preview of the documents and their metadata to stdout, allowing users
15+
to validate chunking, structure, and metadata before pushing to a production vector store.
1916
2017
Attributes:
21-
embeddings (Embeddings): HuggingFace embedding model for compatibility.
18+
embeddings (HuggingFaceEmbeddings): HuggingFace embedding instance, used for interface consistency.
19+
embedding_length (int): Dimensionality of embeddings (computed by the base class initializer; not otherwise used by this provider).
2220
2321
Args:
24-
embedding_model (str): The model name to initialize HuggingFaceEmbeddings.
25-
Used only for compatibility — no embeddings are generated.
22+
embeddings (HuggingFaceEmbeddings): A HuggingFace embedding model instance.
2623
2724
Example:
2825
>>> from langchain_core.documents import Document
29-
>>> provider = DryRunProvider("BAAI/bge-small-en")
26+
>>> from langchain_huggingface import HuggingFaceEmbeddings
27+
>>> from vector_db.dryrun_provider import DryRunProvider
28+
>>> embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en")
29+
>>> provider = DryRunProvider(embeddings)
3030
>>> docs = [Document(page_content="Hello world", metadata={"source": "test.txt"})]
3131
>>> provider.add_documents(docs)
3232
"""
3333

34-
def __init__(self, embedding_model: str):
34+
def __init__(self, embeddings: HuggingFaceEmbeddings):
3535
"""
3636
Initialize the dry run provider with a HuggingFace embeddings instance.
3737
3838
Args:
39-
embedding_model (str): The name of the embedding model (used for interface consistency).
39+
embeddings (HuggingFaceEmbeddings): A HuggingFace embedding model (used for compatibility).
4040
"""
41-
super().__init__(embedding_model)
41+
super().__init__(embeddings)
4242

4343
def add_documents(self, docs: List[Document]) -> None:
4444
"""

vector_db/elastic_provider.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from langchain_core.documents import Document
55
from langchain_elasticsearch.vectorstores import ElasticsearchStore
6+
from langchain_huggingface import HuggingFaceEmbeddings
67

78
from vector_db.db_provider import DBProvider
89

@@ -13,25 +14,27 @@ class ElasticProvider(DBProvider):
1314
"""
1415
Vector database provider backed by Elasticsearch using LangChain's ElasticsearchStore.
1516
16-
This provider allows storing and querying vectorized documents in an Elasticsearch
17-
cluster. Documents are embedded using a HuggingFace model and stored with associated
18-
metadata in the specified index.
17+
This provider stores and queries vectorized documents in an Elasticsearch cluster.
18+
Documents are embedded using the provided HuggingFace embeddings model and stored
19+
with associated metadata in the specified index.
1920
2021
Attributes:
21-
db (ElasticsearchStore): LangChain-compatible wrapper around Elasticsearch vector storage.
22-
embeddings (Embeddings): HuggingFace embedding model for generating document vectors.
22+
db (ElasticsearchStore): LangChain-compatible Elasticsearch vector store.
23+
embeddings (HuggingFaceEmbeddings): HuggingFace embedding model instance.
2324
2425
Args:
25-
embedding_model (str): HuggingFace model name for computing embeddings.
26-
url (str): Full URL to the Elasticsearch cluster (e.g. "http://localhost:9200").
26+
embeddings (HuggingFaceEmbeddings): Pre-initialized embeddings instance.
27+
url (str): Full URL to the Elasticsearch cluster (e.g., "http://localhost:9200").
2728
password (str): Password for the Elasticsearch user.
2829
index (str): The index name where documents will be stored.
2930
user (str): Elasticsearch username (default is typically "elastic").
3031
3132
Example:
33+
>>> from langchain_huggingface import HuggingFaceEmbeddings
3234
>>> from vector_db.elastic_provider import ElasticProvider
35+
>>> embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en")
3336
>>> provider = ElasticProvider(
34-
... embedding_model="BAAI/bge-small-en",
37+
... embeddings=embeddings,
3538
... url="http://localhost:9200",
3639
... password="changeme",
3740
... index="rag-docs",
@@ -42,7 +45,7 @@ class ElasticProvider(DBProvider):
4245

4346
def __init__(
4447
self,
45-
embedding_model: str,
48+
embeddings: HuggingFaceEmbeddings,
4649
url: str,
4750
password: str,
4851
index: str,
@@ -52,13 +55,13 @@ def __init__(
5255
Initialize an Elasticsearch-based vector DB provider.
5356
5457
Args:
55-
embedding_model (str): The model name for computing embeddings.
58+
embeddings (HuggingFaceEmbeddings): HuggingFace embeddings instance.
5659
url (str): Full URL of the Elasticsearch service.
5760
password (str): Elasticsearch user's password.
5861
index (str): Name of the Elasticsearch index to use.
5962
user (str): Elasticsearch username (e.g., "elastic").
6063
"""
61-
super().__init__(embedding_model)
64+
super().__init__(embeddings)
6265

6366
self.db = ElasticsearchStore(
6467
embedding=self.embeddings,
@@ -74,8 +77,8 @@ def add_documents(self, docs: List[Document]) -> None:
7477
"""
7578
Add a batch of LangChain documents to the Elasticsearch index.
7679
77-
Each document will be embedded using the configured model and stored
78-
in the specified index with any associated metadata.
80+
Each document is embedded using the provided model and stored
81+
in the specified index with its associated metadata.
7982
8083
Args:
8184
docs (List[Document]): List of documents to index.

vector_db/mssql_provider.py

Lines changed: 12 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import pyodbc
66
from langchain_core.documents import Document
7+
from langchain_huggingface import HuggingFaceEmbeddings
78
from langchain_sqlserver import SQLServer_VectorStore
89

910
from vector_db.db_provider import DBProvider
@@ -16,49 +17,45 @@ class MSSQLProvider(DBProvider):
1617
SQL Server-based vector DB provider using LangChain's SQLServer_VectorStore integration.
1718
1819
This provider connects to a Microsoft SQL Server instance using a full ODBC connection string,
19-
and stores document embeddings in a specified table. If the target database does not exist,
20-
it will be created automatically.
20+
and stores document embeddings in a specified table. The target database will be created if it
21+
does not already exist.
2122
2223
Attributes:
2324
db (SQLServer_VectorStore): Underlying LangChain-compatible vector store.
2425
connection_string (str): Full ODBC connection string to the SQL Server instance.
2526
2627
Args:
27-
embedding_model (str): HuggingFace-compatible embedding model to use.
28+
embeddings (HuggingFaceEmbeddings): Pre-initialized embeddings instance.
2829
connection_string (str): Full ODBC connection string (including target DB).
2930
table (str): Table name to store vector embeddings.
30-
embedding_length (int): Dimensionality of the embeddings (e.g., 768 for all-mpnet-base-v2).
3131
3232
Example:
33+
>>> from langchain_huggingface import HuggingFaceEmbeddings
34+
>>> from vector_db.mssql_provider import MSSQLProvider
35+
>>> embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-large-en-v1.5")
3336
>>> provider = MSSQLProvider(
34-
... embedding_model="BAAI/bge-large-en-v1.5",
37+
... embeddings=embeddings,
3538
... connection_string="Driver={ODBC Driver 18 for SQL Server};Server=localhost,1433;Database=docs;UID=sa;PWD=StrongPassword!;TrustServerCertificate=yes;Encrypt=no;",
3639
... table="embedded_docs",
37-
... embedding_length=768,
3840
... )
3941
>>> provider.add_documents(docs)
4042
"""
4143

4244
def __init__(
4345
self,
44-
embedding_model: str,
46+
embeddings: HuggingFaceEmbeddings,
4547
connection_string: str,
4648
table: str,
47-
embedding_length: int,
4849
) -> None:
4950
"""
5051
Initialize the MSSQLProvider.
5152
5253
Args:
53-
embedding_model (str): HuggingFace-compatible embedding model to use for generating embeddings.
54+
embeddings (HuggingFaceEmbeddings): HuggingFace-compatible embedding model instance.
5455
connection_string (str): Full ODBC connection string including target database name.
5556
table (str): Table name to store document embeddings.
56-
embedding_length (int): Size of the embeddings (number of dimensions).
57-
58-
Raises:
59-
RuntimeError: If the database specified in the connection string cannot be found or created.
6057
"""
61-
super().__init__(embedding_model)
58+
super().__init__(embeddings)
6259

6360
self.connection_string = connection_string
6461
self.table = table
@@ -77,36 +74,18 @@ def __init__(
7774
connection_string=self.connection_string,
7875
embedding_function=self.embeddings,
7976
table_name=self.table,
80-
embedding_length=embedding_length,
77+
embedding_length=self.embedding_length,
8178
)
8279

8380
def _extract_server_address(self) -> str:
84-
"""
85-
Extract the server address (host,port) from the connection string.
86-
87-
Returns:
88-
str: The server address portion ("host,port") or "unknown" if not found.
89-
"""
9081
match = re.search(r"Server=([^;]+)", self.connection_string, re.IGNORECASE)
9182
return match.group(1) if match else "unknown"
9283

9384
def _extract_database_name(self) -> Optional[str]:
94-
"""
95-
Extract the database name from the connection string.
96-
97-
Returns:
98-
str: Database name if found, else None.
99-
"""
10085
match = re.search(r"Database=([^;]+)", self.connection_string, re.IGNORECASE)
10186
return match.group(1) if match else None
10287

10388
def _build_connection_string_for_master(self) -> str:
104-
"""
105-
Modify the connection string to point to the 'master' database.
106-
107-
Returns:
108-
str: Modified connection string.
109-
"""
11089
parts = self.connection_string.split(";")
11190
updated_parts = [
11291
"Database=master" if p.strip().lower().startswith("database=") else p
@@ -116,12 +95,6 @@ def _build_connection_string_for_master(self) -> str:
11695
return ";".join(updated_parts) + ";"
11796

11897
def _ensure_database_exists(self) -> None:
119-
"""
120-
Connect to the SQL Server master database and create the target database if missing.
121-
122-
Raises:
123-
RuntimeError: If the database cannot be created or accessed.
124-
"""
12598
database = self._extract_database_name()
12699
if not database:
127100
raise RuntimeError("No database name found in connection string.")

0 commit comments

Comments
 (0)