Skip to content

Commit 8ea6668

Browse files
committed
add embedding model arg to db providers
1 parent 9037611 commit 8ea6668

File tree

9 files changed

+67
-28
lines changed

9 files changed

+67
-28
lines changed

.env

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ WEB_SOURCES=["https://ai-on-openshift.io/getting-started/openshift/", "https://a
1414
CHUNK_SIZE=1024
1515
CHUNK_OVERLAP=40
1616
DB_TYPE=DRYRUN
17+
EMBEDDING_MODEL=sentence-transformers/all-mpnet-base-v2
1718

1819
# === Redis ===
1920
REDIS_URL=redis://localhost:6379

config.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -62,36 +62,47 @@ def _init_db_provider(db_type: str) -> DBProvider:
6262
"""
6363
Initialize the correct DBProvider subclass based on DB_TYPE.
6464
"""
65+
get = Config._get_required_env_var
6566
db_type = db_type.upper()
67+
embedding_model = get("EMBEDDING_MODEL")
6668

6769
if db_type == "REDIS":
68-
url = Config._get_required_env_var("REDIS_URL")
70+
url = get("REDIS_URL")
6971
index = os.getenv("REDIS_INDEX", "docs")
7072
schema = os.getenv("REDIS_SCHEMA", "redis_schema.yaml")
71-
return RedisProvider(url, index, schema)
73+
return RedisProvider(embedding_model, url, index, schema)
7274

7375
elif db_type == "ELASTIC":
74-
url = Config._get_required_env_var("ELASTIC_URL")
75-
password = Config._get_required_env_var("ELASTIC_PASSWORD")
76+
url = get("ELASTIC_URL")
77+
password = get("ELASTIC_PASSWORD")
7678
index = os.getenv("ELASTIC_INDEX", "docs")
7779
user = os.getenv("ELASTIC_USER", "elastic")
78-
return ElasticProvider(url, password, index, user)
80+
return ElasticProvider(embedding_model, url, password, index, user)
7981

8082
elif db_type == "PGVECTOR":
81-
url = Config._get_required_env_var("PGVECTOR_URL")
82-
collection = Config._get_required_env_var("PGVECTOR_COLLECTION_NAME")
83-
return PGVectorProvider(url, collection)
83+
url = get("PGVECTOR_URL")
84+
collection = get("PGVECTOR_COLLECTION_NAME")
85+
return PGVectorProvider(embedding_model, url, collection)
8486

8587
elif db_type == "SQLSERVER":
86-
return SQLServerProvider() # Handles its own env var loading
88+
host = get("SQLSERVER_HOST")
89+
port = get("SQLSERVER_PORT")
90+
user = get("SQLSERVER_USER")
91+
password = get("SQLSERVER_PASSWORD")
92+
database = get("SQLSERVER_DB")
93+
table = get("SQLSERVER_TABLE")
94+
driver = get("SQLSERVER_DRIVER")
95+
return SQLServerProvider(
96+
embedding_model, host, port, user, password, database, table, driver
97+
)
8798

8899
elif db_type == "QDRANT":
89-
url = Config._get_required_env_var("QDRANT_URL")
90-
collection = Config._get_required_env_var("QDRANT_COLLECTION")
91-
return QdrantProvider(url, collection)
100+
url = get("QDRANT_URL")
101+
collection = get("QDRANT_COLLECTION")
102+
return QdrantProvider(embedding_model, url, collection)
92103

93104
elif db_type == "DRYRUN":
94-
return DryRunProvider()
105+
return DryRunProvider(embedding_model)
95106

96107
raise ValueError(f"Unsupported DB_TYPE '{db_type}'")
97108

@@ -135,9 +146,11 @@ def load() -> "Config":
135146
except json.JSONDecodeError as e:
136147
raise ValueError(f"Invalid REPO_SOURCES JSON: {e}") from e
137148

138-
# Misc
149+
# Embedding settings
139150
chunk_size = int(get("CHUNK_SIZE"))
140151
chunk_overlap = int(get("CHUNK_OVERLAP"))
152+
153+
# Misc
141154
temp_dir = get("TEMP_DIR")
142155

143156
return Config(

vector_db/db_provider.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,13 @@ class DBProvider(ABC):
1010
"""
1111
Abstract base class for vector DB providers.
1212
Subclasses must implement `add_documents`.
13+
14+
Args:
15+
embedding_model (str): Embedding model to use
1316
"""
1417

15-
def __init__(self) -> None:
16-
self.embeddings: Embeddings = HuggingFaceEmbeddings()
18+
def __init__(self, embedding_model: str) -> None:
19+
self.embeddings: Embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
1720

1821
@abstractmethod
1922
def add_documents(self, docs: List[Document]) -> None:

vector_db/dryrun_provider.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,17 @@ class DryRunProvider(DBProvider):
1313
chunked documents to stdout. It is useful for debugging document loading,
1414
chunking, and metadata before committing to a real embedding operation.
1515
16+
Args:
17+
embedding_model (str): Embedding model to use
18+
1619
Example:
1720
>>> from vector_db.dry_run_provider import DryRunProvider
18-
>>> provider = DryRunProvider()
21+
>>> provider = DryRunProvider("sentence-transformers/all-mpnet-base-v2")
1922
>>> provider.add_documents(docs) # docs is a List[Document]
2023
"""
2124

22-
def __init__(self):
23-
super().__init__() # ensures embeddings are initialized
25+
def __init__(self, embedding_model: str):
26+
super().__init__(embedding_model)
2427

2528
def add_documents(self, docs: List[Document]) -> None:
2629
"""

vector_db/elastic_provider.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@ class ElasticProvider(DBProvider):
1414
Elasticsearch-based vector DB provider using LangChain's ElasticsearchStore.
1515
1616
Args:
17+
embedding_model (str): Embedding model to use
1718
url (str): Full URL to the Elasticsearch cluster (e.g. http://localhost:9200)
1819
password (str): Authentication password for the cluster
1920
index (str): Index name to use for vector storage
2021
user (str): Username for Elasticsearch (default: "elastic")
2122
2223
Example:
2324
>>> provider = ElasticProvider(
25+
... embedding_model="sentence-transformers/all-mpnet-base-v2",
2426
... url="http://localhost:9200",
2527
... password="changeme",
2628
... index="docs",
@@ -29,8 +31,10 @@ class ElasticProvider(DBProvider):
2931
>>> provider.add_documents(chunks)
3032
"""
3133

32-
def __init__(self, url: str, password: str, index: str, user: str):
33-
super().__init__()
34+
def __init__(
35+
self, embedding_model: str, url: str, password: str, index: str, user: str
36+
):
37+
super().__init__(embedding_model)
3438

3539
self.db = ElasticsearchStore(
3640
embedding=self.embeddings,

vector_db/pgvector_provider.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,21 @@ class PGVectorProvider(DBProvider):
1818
document embeddings in a PostgreSQL-compatible backend with pgvector enabled.
1919
2020
Args:
21+
embedding_model (str): Embedding model to use
2122
url (str): PostgreSQL connection string (e.g. postgresql://user:pass@host:5432/db)
2223
collection_name (str): Name of the pgvector table or collection
2324
2425
Example:
2526
>>> provider = PGVectorProvider(
27+
... embedding_model="sentence-transformers/all-mpnet-base-v2",
2628
... url="postgresql://user:pass@localhost:5432/mydb",
2729
... collection_name="documents"
2830
... )
2931
>>> provider.add_documents(chunks)
3032
"""
3133

32-
def __init__(self, url: str, collection_name: str):
33-
super().__init__()
34+
def __init__(self, embedding_model: str, url: str, collection_name: str):
35+
super().__init__(embedding_model)
3436

3537
self.db = PGVector(
3638
connection=url,

vector_db/qdrant_provider.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ class QdrantProvider(DBProvider):
1616
Qdrant-based vector DB provider using LangChain's QdrantVectorStore.
1717
1818
Args:
19+
embedding_model (str): Embedding model to use
1920
url (str): Base URL of the Qdrant service (e.g., http://localhost:6333)
2021
collection (str): Name of the vector collection to use or create
2122
api_key (Optional[str]): API key if authentication is required (optional)
@@ -24,15 +25,22 @@ class QdrantProvider(DBProvider):
2425
2526
Example:
2627
>>> provider = QdrantProvider(
28+
... embedding_model="sentence-transformers/all-mpnet-base-v2",
2729
... url="http://localhost:6333",
2830
... collection="embedded_docs",
2931
... api_key=None
3032
... )
3133
>>> provider.add_documents(docs)
3234
"""
3335

34-
def __init__(self, url: str, collection: str, api_key: Optional[str] = None):
35-
super().__init__()
36+
def __init__(
37+
self,
38+
embedding_model: str,
39+
url: str,
40+
collection: str,
41+
api_key: Optional[str] = None,
42+
):
43+
super().__init__(embedding_model)
3644
self.collection = collection
3745
self.url = url
3846

vector_db/redis_provider.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ class RedisProvider(DBProvider):
1515
Redis-based vector DB provider using RediSearch and LangChain's Redis integration.
1616
1717
Args:
18+
embedding_model (str): Embedding model to use
1819
url (str): Redis connection string (e.g. redis://localhost:6379)
1920
index (str): RediSearch index name (must be provided via .env)
2021
schema (str): Path to RediSearch schema YAML file (must be provided via .env)
@@ -24,15 +25,16 @@ class RedisProvider(DBProvider):
2425
2526
Example:
2627
>>> provider = RedisProvider(
28+
... embedding_model="sentence-transformers/all-mpnet-base-v2",
2729
... url="redis://localhost:6379",
2830
... index="docs",
2931
... schema="redis_schema.yaml"
3032
... )
3133
>>> provider.add_documents(chunks)
3234
"""
3335

34-
def __init__(self, url: str, index: str, schema: str):
35-
super().__init__()
36+
def __init__(self, embedding_model: str, url: str, index: str, schema: str):
37+
super().__init__(embedding_model)
3638
self.url = url
3739
self.index = index
3840
self.schema = schema

vector_db/sqlserver_provider.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ class SQLServerProvider(DBProvider):
1515
SQL Server-based vector DB provider using LangChain's SQLServer_VectorStore.
1616
1717
Args:
18+
embedding_model (str): Embedding model to use
1819
host (str): Hostname of the SQL Server
1920
port (str): Port number
2021
user (str): SQL login username
@@ -25,6 +26,7 @@ class SQLServerProvider(DBProvider):
2526
2627
Example:
2728
>>> provider = SQLServerProvider(
29+
... embedding_model="sentence-transformers/all-mpnet-base-v2",
2830
... host="localhost",
2931
... port="1433",
3032
... user="sa",
@@ -38,6 +40,7 @@ class SQLServerProvider(DBProvider):
3840

3941
def __init__(
4042
self,
43+
embedding_model: str,
4144
host: str,
4245
port: str,
4346
user: str,
@@ -46,7 +49,7 @@ def __init__(
4649
table: str,
4750
driver: str,
4851
) -> None:
49-
super().__init__()
52+
super().__init__(embedding_model)
5053

5154
self.host = host
5255
self.port = port

0 commit comments

Comments
 (0)