Skip to content

Commit 9136bd5

Browse files
committed
update mssql provider
1 parent 65014e3 commit 9136bd5

File tree

4 files changed

+168
-162
lines changed

4 files changed

+168
-162
lines changed

.env

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ CHUNK_SIZE=1024
1515
CHUNK_OVERLAP=40
1616
DB_TYPE=DRYRUN
1717
EMBEDDING_MODEL=sentence-transformers/all-mpnet-base-v2
18+
EMBEDDING_LENGTH=768
1819

1920
# === Redis ===
2021
REDIS_URL=redis://localhost:6379
@@ -32,13 +33,8 @@ PGVECTOR_URL=postgresql://user:pass@localhost:5432/mydb
3233
PGVECTOR_COLLECTION_NAME=documents
3334

3435
# === SQL Server ===
35-
SQLSERVER_HOST=localhost
36-
SQLSERVER_PORT=1433
37-
SQLSERVER_USER=sa
38-
SQLSERVER_PASSWORD=StrongPassword!
39-
SQLSERVER_DB=docs
40-
SQLSERVER_TABLE=vector_table
41-
SQLSERVER_DRIVER=ODBC Driver 18 for SQL Server
36+
MSSQL_CONNECTION_STRING="Driver={ODBC Driver 18 for SQL Server};Server=localhost,1433;Database=embeddings;UID=sa;PWD=StrongPassword!;TrustServerCertificate=yes;Encrypt=no;"
37+
MSSQL_TABLE=docs
4238

4339
# === Qdrant ===
4440
QDRANT_URL=http://localhost:6333

config.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
from vector_db.db_provider import DBProvider
1010
from vector_db.dryrun_provider import DryRunProvider
1111
from vector_db.elastic_provider import ElasticProvider
12+
from vector_db.mssql_provider import MSSQLProvider
1213
from vector_db.pgvector_provider import PGVectorProvider
1314
from vector_db.qdrant_provider import QdrantProvider
1415
from vector_db.redis_provider import RedisProvider
15-
from vector_db.sqlserver_provider import SQLServerProvider
1616

1717

1818
@dataclass
@@ -109,6 +109,7 @@ def _init_db_provider(db_type: str) -> DBProvider:
109109
get = Config._get_required_env_var
110110
db_type = db_type.upper()
111111
embedding_model = get("EMBEDDING_MODEL")
112+
embedding_length = get("EMBEDDING_LENGTH")
112113

113114
if db_type == "REDIS":
114115
url = get("REDIS_URL")
@@ -128,16 +129,11 @@ def _init_db_provider(db_type: str) -> DBProvider:
128129
collection = get("PGVECTOR_COLLECTION_NAME")
129130
return PGVectorProvider(embedding_model, url, collection)
130131

131-
elif db_type == "SQLSERVER":
132-
host = get("SQLSERVER_HOST")
133-
port = get("SQLSERVER_PORT")
134-
user = get("SQLSERVER_USER")
135-
password = get("SQLSERVER_PASSWORD")
136-
database = get("SQLSERVER_DB")
137-
table = get("SQLSERVER_TABLE")
138-
driver = get("SQLSERVER_DRIVER")
139-
return SQLServerProvider(
140-
embedding_model, host, port, user, password, database, table, driver
132+
elif db_type == "MSSQL":
133+
connection_string = get("MSSQL_CONNECTION_STRING")
134+
table = get("MSSQL_TABLE")
135+
return MSSQLProvider(
136+
embedding_model, connection_string, table, embedding_length
141137
)
142138

143139
elif db_type == "QDRANT":

vector_db/mssql_provider.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
import logging
2+
import re
3+
from typing import List, Optional
4+
5+
import pyodbc
6+
from langchain_core.documents import Document
7+
from langchain_sqlserver import SQLServer_VectorStore
8+
9+
from vector_db.db_provider import DBProvider
10+
11+
logger = logging.getLogger(__name__)
12+
13+
14+
class MSSQLProvider(DBProvider):
    """
    SQL Server-based vector DB provider using LangChain's SQLServer_VectorStore integration.

    This provider connects to a Microsoft SQL Server instance using a full ODBC connection string,
    and stores document embeddings in a specified table. If the target database does not exist,
    it will be created automatically.

    Attributes:
        db (SQLServer_VectorStore): Underlying LangChain-compatible vector store.
        connection_string (str): Full ODBC connection string to the SQL Server instance.
        table (str): Table name used to store vector embeddings.
        BATCH_SIZE (int): Number of documents sent to the store per insert call.

    Args:
        embedding_model (str): HuggingFace-compatible embedding model to use.
        connection_string (str): Full ODBC connection string (including target DB).
        table (str): Table name to store vector embeddings.
        embedding_length (int): Dimensionality of the embeddings (e.g., 768 for all-mpnet-base-v2).

    Example:
        >>> provider = MSSQLProvider(
        ...     embedding_model="sentence-transformers/all-mpnet-base-v2",
        ...     connection_string="Driver={ODBC Driver 18 for SQL Server};Server=localhost,1433;Database=docs;UID=sa;PWD=StrongPassword!;TrustServerCertificate=yes;Encrypt=no;",
        ...     table="embedded_docs",
        ...     embedding_length=768,
        ... )
        >>> provider.add_documents(docs)
    """

    # Documents are inserted in slices of this size; override on a subclass or
    # instance to tune insert granularity.
    BATCH_SIZE = 50

    def __init__(
        self,
        embedding_model: str,
        connection_string: str,
        table: str,
        embedding_length: int,
    ) -> None:
        """
        Initialize the MSSQLProvider.

        Args:
            embedding_model (str): HuggingFace-compatible embedding model to use for generating embeddings.
            connection_string (str): Full ODBC connection string including target database name.
            table (str): Table name to store document embeddings.
            embedding_length (int): Size of the embeddings (number of dimensions). A numeric
                string (e.g. a value read from an environment variable) is accepted and
                converted to ``int``.

        Raises:
            ValueError: If ``embedding_length`` is not a positive integer.
            RuntimeError: If the database specified in the connection string cannot be found or created.
        """
        super().__init__(embedding_model)

        self.connection_string = connection_string
        self.table = table

        # Values sourced from environment variables arrive as strings; normalise
        # before touching the server so a bad value fails fast.
        try:
            embedding_length = int(embedding_length)
        except (TypeError, ValueError) as e:
            raise ValueError(
                f"embedding_length must be an integer, got {embedding_length!r}"
            ) from e
        if embedding_length <= 0:
            raise ValueError(
                f"embedding_length must be positive, got {embedding_length}"
            )

        self._ensure_database_exists()

        self.db = SQLServer_VectorStore(
            connection_string=self.connection_string,
            embedding_function=self.embeddings,
            table_name=self.table,
            embedding_length=embedding_length,
        )

        # Logged only after the store has been constructed successfully.
        logger.info(
            "Connected to MSSQL instance at %s (table: %s)",
            self._extract_server_address(),
            self.table,
        )

    def _extract_connection_value(self, key: str) -> Optional[str]:
        """
        Return the value of ``key`` from the connection string.

        The key must start a ``;``-separated segment (so e.g. ``Database`` does not
        match inside a longer key name); matching is case-insensitive and tolerates
        whitespace around the key and ``=``.

        Args:
            key (str): Connection-string keyword to look up (e.g. "Server").

        Returns:
            Optional[str]: The stripped value, or None if the key is absent or empty.
        """
        pattern = rf"(?:^|;)\s*{re.escape(key)}\s*=\s*([^;]+)"
        match = re.search(pattern, self.connection_string, re.IGNORECASE)
        if not match:
            return None
        value = match.group(1).strip()
        return value or None

    def _extract_server_address(self) -> str:
        """
        Extract the server address (host,port) from the connection string.

        Returns:
            str: The server address portion ("host,port") or "unknown" if not found.
        """
        return self._extract_connection_value("Server") or "unknown"

    def _extract_database_name(self) -> Optional[str]:
        """
        Extract the database name from the connection string.

        Returns:
            Optional[str]: Database name if found, else None.
        """
        return self._extract_connection_value("Database")

    def _build_connection_string_for_master(self) -> str:
        """
        Modify the connection string to point to the 'master' database.

        Every non-empty segment is kept as-is except the ``Database=`` segment,
        which is replaced with ``Database=master``.

        Returns:
            str: Modified connection string, terminated with ``;``.
        """
        segments = [p for p in self.connection_string.split(";") if p.strip()]
        updated_parts = [
            "Database=master"
            if re.match(r"\s*database\s*=", p, re.IGNORECASE)
            else p
            for p in segments
        ]
        return ";".join(updated_parts) + ";"

    def _ensure_database_exists(self) -> None:
        """
        Connect to the SQL Server master database and create the target database if missing.

        Raises:
            RuntimeError: If the connection string has no database name, or the
                database cannot be created or accessed.
        """
        database = self._extract_database_name()
        if not database:
            raise RuntimeError("No database name found in connection string.")

        # CREATE DATABASE cannot take the name as a bound parameter, so the name
        # is escaped for both contexts it appears in: a string literal ('' for ')
        # and a bracket-delimited identifier (]] for ]).
        literal = database.replace("'", "''")
        identifier = database.replace("]", "]]")
        statement = (
            f"IF DB_ID(N'{literal}') IS NULL CREATE DATABASE [{identifier}]"
        )

        master_conn_str = self._build_connection_string_for_master()
        try:
            # autocommit is required: CREATE DATABASE cannot run inside a transaction.
            conn = pyodbc.connect(master_conn_str, autocommit=True)
            try:
                cursor = conn.cursor()
                try:
                    cursor.execute(statement)
                finally:
                    cursor.close()
            finally:
                # pyodbc's connection context manager does not close the
                # connection, so close it explicitly.
                conn.close()
        except Exception as e:
            logger.exception("Failed to ensure database '%s' exists", database)
            raise RuntimeError(
                f"Failed to ensure database '{database}' exists: {e}"
            ) from e

    def add_documents(self, docs: List[Document]) -> None:
        """
        Add documents to the SQL Server table in small batches.

        Documents are passed to the underlying vector store in slices of
        ``BATCH_SIZE``. An empty list is a no-op.

        Args:
            docs (List[Document]): LangChain document chunks to embed and store.

        Raises:
            Exception: If a batch insert operation fails. The failure is logged with
                the starting index of the batch and re-raised; batches inserted
                before the failure are not rolled back by this method.
        """
        batch_size = self.BATCH_SIZE
        for start in range(0, len(docs), batch_size):
            batch = docs[start : start + batch_size]
            try:
                self.db.add_documents(batch)
            except Exception:
                logger.exception(
                    "Failed to insert batch starting at index %s", start
                )
                raise

vector_db/sqlserver_provider.py

Lines changed: 0 additions & 144 deletions
This file was deleted.

0 commit comments

Comments
 (0)