Skip to content

Commit 478e34b

Browse files
authored
Variable embedder (#3)
* move log level parsing to its own function
* tidy config parsing
* add embedding model arg to db providers
* fix logging in config
1 parent cdc365e commit 478e34b

File tree

9 files changed

+87
-47
lines changed

9 files changed

+87
-47
lines changed

.env

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ WEB_SOURCES=["https://ai-on-openshift.io/getting-started/openshift/", "https://a
1414
CHUNK_SIZE=1024
1515
CHUNK_OVERLAP=40
1616
DB_TYPE=DRYRUN
17+
EMBEDDING_MODEL=sentence-transformers/all-mpnet-base-v2
1718

1819
# === Redis ===
1920
REDIS_URL=redis://localhost:6379

config.py

Lines changed: 47 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -42,41 +42,67 @@ def _get_required_env_var(key: str) -> str:
4242
raise ValueError(f"{key} environment variable is required.")
4343
return value
4444

45+
@staticmethod
46+
def _parse_log_level(log_level_name: str) -> int:
47+
log_levels = {
48+
"DEBUG": logging.DEBUG,
49+
"INFO": logging.INFO,
50+
"WARNING": logging.WARNING,
51+
"ERROR": logging.ERROR,
52+
"CRITICAL": logging.CRITICAL,
53+
}
54+
if log_level_name not in log_levels:
55+
raise ValueError(
56+
f"Invalid LOG_LEVEL: '{log_level_name}'. Must be one of: {', '.join(log_levels.keys())}"
57+
)
58+
return log_levels[log_level_name]
59+
4560
@staticmethod
4661
def _init_db_provider(db_type: str) -> DBProvider:
4762
"""
4863
Initialize the correct DBProvider subclass based on DB_TYPE.
4964
"""
65+
get = Config._get_required_env_var
5066
db_type = db_type.upper()
67+
embedding_model = get("EMBEDDING_MODEL")
5168

5269
if db_type == "REDIS":
53-
url = Config._get_required_env_var("REDIS_URL")
70+
url = get("REDIS_URL")
5471
index = os.getenv("REDIS_INDEX", "docs")
5572
schema = os.getenv("REDIS_SCHEMA", "redis_schema.yaml")
56-
return RedisProvider(url, index, schema)
73+
return RedisProvider(embedding_model, url, index, schema)
5774

5875
elif db_type == "ELASTIC":
59-
url = Config._get_required_env_var("ELASTIC_URL")
60-
password = Config._get_required_env_var("ELASTIC_PASSWORD")
76+
url = get("ELASTIC_URL")
77+
password = get("ELASTIC_PASSWORD")
6178
index = os.getenv("ELASTIC_INDEX", "docs")
6279
user = os.getenv("ELASTIC_USER", "elastic")
63-
return ElasticProvider(url, password, index, user)
80+
return ElasticProvider(embedding_model, url, password, index, user)
6481

6582
elif db_type == "PGVECTOR":
66-
url = Config._get_required_env_var("PGVECTOR_URL")
67-
collection = Config._get_required_env_var("PGVECTOR_COLLECTION_NAME")
68-
return PGVectorProvider(url, collection)
83+
url = get("PGVECTOR_URL")
84+
collection = get("PGVECTOR_COLLECTION_NAME")
85+
return PGVectorProvider(embedding_model, url, collection)
6986

7087
elif db_type == "SQLSERVER":
71-
return SQLServerProvider() # Handles its own env var loading
88+
host = get("SQLSERVER_HOST")
89+
port = get("SQLSERVER_PORT")
90+
user = get("SQLSERVER_USER")
91+
password = get("SQLSERVER_PASSWORD")
92+
database = get("SQLSERVER_DB")
93+
table = get("SQLSERVER_TABLE")
94+
driver = get("SQLSERVER_DRIVER")
95+
return SQLServerProvider(
96+
embedding_model, host, port, user, password, database, table, driver
97+
)
7298

7399
elif db_type == "QDRANT":
74-
url = Config._get_required_env_var("QDRANT_URL")
75-
collection = Config._get_required_env_var("QDRANT_COLLECTION")
76-
return QdrantProvider(url, collection)
100+
url = get("QDRANT_URL")
101+
collection = get("QDRANT_COLLECTION")
102+
return QdrantProvider(embedding_model, url, collection)
77103

78104
elif db_type == "DRYRUN":
79-
return DryRunProvider()
105+
return DryRunProvider(embedding_model)
80106

81107
raise ValueError(f"Unsupported DB_TYPE '{db_type}'")
82108

@@ -99,44 +125,32 @@ def load() -> "Config":
99125
get = Config._get_required_env_var
100126

101127
# Initialize logger
102-
log_level_name = get("LOG_LEVEL").lower()
103-
log_levels = {
104-
"debug": 10,
105-
"info": 20,
106-
"warning": 30,
107-
"error": 40,
108-
"critical": 50,
109-
}
110-
if log_level_name not in log_levels:
111-
raise ValueError(
112-
f"Invalid LOG_LEVEL: '{log_level_name}'. Must be one of: {', '.join(log_levels)}"
113-
)
114-
log_level = log_levels[log_level_name]
115-
logging.basicConfig(level=log_level)
128+
log_level = get("LOG_LEVEL").upper()
129+
logging.basicConfig(level=Config._parse_log_level(log_level))
116130
logger = logging.getLogger(__name__)
117-
logger.debug("Logging initialized at level: %s", log_level_name.upper())
131+
logger.debug("Logging initialized at level: %s", log_level)
118132

119133
# Initialize db
120134
db_type = get("DB_TYPE")
121135
db_provider = Config._init_db_provider(db_type)
122136

123137
# Web URLs
124-
web_sources_raw = get("WEB_SOURCES")
125138
try:
126-
web_sources = json.loads(web_sources_raw)
139+
web_sources = json.loads(get("WEB_SOURCES"))
127140
except json.JSONDecodeError as e:
128141
raise ValueError(f"WEB_SOURCES must be a valid JSON list: {e}")
129142

130143
# Repo sources
131-
repo_sources_json = get("REPO_SOURCES")
132144
try:
133-
repo_sources = json.loads(repo_sources_json)
145+
repo_sources = json.loads(get("REPO_SOURCES"))
134146
except json.JSONDecodeError as e:
135147
raise ValueError(f"Invalid REPO_SOURCES JSON: {e}") from e
136148

137-
# Misc
149+
# Embedding settings
138150
chunk_size = int(get("CHUNK_SIZE"))
139151
chunk_overlap = int(get("CHUNK_OVERLAP"))
152+
153+
# Misc
140154
temp_dir = get("TEMP_DIR")
141155

142156
return Config(

vector_db/db_provider.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,13 @@ class DBProvider(ABC):
1010
"""
1111
Abstract base class for vector DB providers.
1212
Subclasses must implement `add_documents`.
13+
14+
Args:
15+
embedding_model (str): Model name passed to HuggingFaceEmbeddings as `model_name` (e.g. "sentence-transformers/all-mpnet-base-v2")
1316
"""
1417

15-
def __init__(self) -> None:
16-
self.embeddings: Embeddings = HuggingFaceEmbeddings()
18+
def __init__(self, embedding_model: str) -> None:
19+
self.embeddings: Embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
1720

1821
@abstractmethod
1922
def add_documents(self, docs: List[Document]) -> None:

vector_db/dryrun_provider.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,17 @@ class DryRunProvider(DBProvider):
1313
chunked documents to stdout. It is useful for debugging document loading,
1414
chunking, and metadata before committing to a real embedding operation.
1515
16+
Args:
17+
embedding_model (str): Embedding model to use
18+
1619
Example:
1720
>>> from vector_db.dryrun_provider import DryRunProvider
18-
>>> provider = DryRunProvider()
21+
>>> provider = DryRunProvider("sentence-transformers/all-mpnet-base-v2")
1922
>>> provider.add_documents(docs) # docs is a List[Document]
2023
"""
2124

22-
def __init__(self):
23-
super().__init__() # ensures embeddings are initialized
25+
def __init__(self, embedding_model: str):
26+
super().__init__(embedding_model)
2427

2528
def add_documents(self, docs: List[Document]) -> None:
2629
"""

vector_db/elastic_provider.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@ class ElasticProvider(DBProvider):
1414
Elasticsearch-based vector DB provider using LangChain's ElasticsearchStore.
1515
1616
Args:
17+
embedding_model (str): Embedding model to use
1718
url (str): Full URL to the Elasticsearch cluster (e.g. http://localhost:9200)
1819
password (str): Authentication password for the cluster
1920
index (str): Index name to use for vector storage
2021
user (str): Username for Elasticsearch (required here; Config passes "elastic" when ELASTIC_USER is unset)
2122
2223
Example:
2324
>>> provider = ElasticProvider(
25+
... embedding_model="sentence-transformers/all-mpnet-base-v2",
2426
... url="http://localhost:9200",
2527
... password="changeme",
2628
... index="docs",
@@ -29,8 +31,10 @@ class ElasticProvider(DBProvider):
2931
>>> provider.add_documents(chunks)
3032
"""
3133

32-
def __init__(self, url: str, password: str, index: str, user: str):
33-
super().__init__()
34+
def __init__(
35+
self, embedding_model: str, url: str, password: str, index: str, user: str
36+
):
37+
super().__init__(embedding_model)
3438

3539
self.db = ElasticsearchStore(
3640
embedding=self.embeddings,

vector_db/pgvector_provider.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,21 @@ class PGVectorProvider(DBProvider):
1818
document embeddings in a PostgreSQL-compatible backend with pgvector enabled.
1919
2020
Args:
21+
embedding_model (str): Embedding model to use
2122
url (str): PostgreSQL connection string (e.g. postgresql://user:pass@host:5432/db)
2223
collection_name (str): Name of the pgvector table or collection
2324
2425
Example:
2526
>>> provider = PGVectorProvider(
27+
... embedding_model="sentence-transformers/all-mpnet-base-v2",
2628
... url="postgresql://user:pass@localhost:5432/mydb",
2729
... collection_name="documents"
2830
... )
2931
>>> provider.add_documents(chunks)
3032
"""
3133

32-
def __init__(self, url: str, collection_name: str):
33-
super().__init__()
34+
def __init__(self, embedding_model: str, url: str, collection_name: str):
35+
super().__init__(embedding_model)
3436

3537
self.db = PGVector(
3638
connection=url,

vector_db/qdrant_provider.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ class QdrantProvider(DBProvider):
1616
Qdrant-based vector DB provider using LangChain's QdrantVectorStore.
1717
1818
Args:
19+
embedding_model (str): Embedding model to use
1920
url (str): Base URL of the Qdrant service (e.g., http://localhost:6333)
2021
collection (str): Name of the vector collection to use or create
2122
api_key (Optional[str]): API key if authentication is required (optional)
@@ -24,15 +25,22 @@ class QdrantProvider(DBProvider):
2425
2526
Example:
2627
>>> provider = QdrantProvider(
28+
... embedding_model="sentence-transformers/all-mpnet-base-v2",
2729
... url="http://localhost:6333",
2830
... collection="embedded_docs",
2931
... api_key=None
3032
... )
3133
>>> provider.add_documents(docs)
3234
"""
3335

34-
def __init__(self, url: str, collection: str, api_key: Optional[str] = None):
35-
super().__init__()
36+
def __init__(
37+
self,
38+
embedding_model: str,
39+
url: str,
40+
collection: str,
41+
api_key: Optional[str] = None,
42+
):
43+
super().__init__(embedding_model)
3644
self.collection = collection
3745
self.url = url
3846

vector_db/redis_provider.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ class RedisProvider(DBProvider):
1515
Redis-based vector DB provider using RediSearch and LangChain's Redis integration.
1616
1717
Args:
18+
embedding_model (str): Embedding model to use
1819
url (str): Redis connection string (e.g. redis://localhost:6379)
1920
index (str): RediSearch index name (Config falls back to "docs" when REDIS_INDEX is unset)
2021
schema (str): Path to RediSearch schema YAML file (Config falls back to "redis_schema.yaml" when REDIS_SCHEMA is unset)
@@ -24,15 +25,16 @@ class RedisProvider(DBProvider):
2425
2526
Example:
2627
>>> provider = RedisProvider(
28+
... embedding_model="sentence-transformers/all-mpnet-base-v2",
2729
... url="redis://localhost:6379",
2830
... index="docs",
2931
... schema="redis_schema.yaml"
3032
... )
3133
>>> provider.add_documents(chunks)
3234
"""
3335

34-
def __init__(self, url: str, index: str, schema: str):
35-
super().__init__()
36+
def __init__(self, embedding_model: str, url: str, index: str, schema: str):
37+
super().__init__(embedding_model)
3638
self.url = url
3739
self.index = index
3840
self.schema = schema

vector_db/sqlserver_provider.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ class SQLServerProvider(DBProvider):
1515
SQL Server-based vector DB provider using LangChain's SQLServer_VectorStore.
1616
1717
Args:
18+
embedding_model (str): Embedding model to use
1819
host (str): Hostname of the SQL Server
1920
port (str): Port number
2021
user (str): SQL login username
@@ -25,6 +26,7 @@ class SQLServerProvider(DBProvider):
2526
2627
Example:
2728
>>> provider = SQLServerProvider(
29+
... embedding_model="sentence-transformers/all-mpnet-base-v2",
2830
... host="localhost",
2931
... port="1433",
3032
... user="sa",
@@ -38,6 +40,7 @@ class SQLServerProvider(DBProvider):
3840

3941
def __init__(
4042
self,
43+
embedding_model: str,
4144
host: str,
4245
port: str,
4346
user: str,
@@ -46,7 +49,7 @@ def __init__(
4649
table: str,
4750
driver: str,
4851
) -> None:
49-
super().__init__()
52+
super().__init__(embedding_model)
5053

5154
self.host = host
5255
self.port = port

0 commit comments

Comments
 (0)