|
| 1 | +import logging |
| 2 | +import re |
| 3 | +from typing import List, Optional |
| 4 | + |
| 5 | +import pyodbc |
| 6 | +from langchain_core.documents import Document |
| 7 | +from langchain_sqlserver import SQLServer_VectorStore |
| 8 | + |
| 9 | +from vector_db.db_provider import DBProvider |
| 10 | + |
| 11 | +logger = logging.getLogger(__name__) |
| 12 | + |
| 13 | + |
class MSSQLProvider(DBProvider):
    """
    SQL Server-based vector DB provider using LangChain's SQLServer_VectorStore integration.

    This provider connects to a Microsoft SQL Server instance using a full ODBC connection string,
    and stores document embeddings in a specified table. If the target database does not exist,
    it will be created automatically.

    Attributes:
        db (SQLServer_VectorStore): Underlying LangChain-compatible vector store.
        connection_string (str): Full ODBC connection string to the SQL Server instance.
        table (str): Table name used to store the vector embeddings.

    Args:
        embedding_model (str): HuggingFace-compatible embedding model to use.
        connection_string (str): Full ODBC connection string (including target DB).
        table (str): Table name to store vector embeddings.
        embedding_length (int): Dimensionality of the embeddings (e.g., 768 for all-mpnet-base-v2,
            1024 for BAAI/bge-large-en-v1.5). It must match the output size of the embedding model.

    Example:
        >>> provider = MSSQLProvider(
        ...     embedding_model="BAAI/bge-large-en-v1.5",
        ...     connection_string="Driver={ODBC Driver 18 for SQL Server};Server=localhost,1433;Database=docs;UID=sa;PWD=StrongPassword!;TrustServerCertificate=yes;Encrypt=no;",
        ...     table="embedded_docs",
        ...     embedding_length=1024,
        ... )
        >>> provider.add_documents(docs)
    """

    # Number of documents sent to the vector store per insert call.
    _BATCH_SIZE = 50

    # Keyword patterns are anchored to the start of the string or a ';' so that
    # a keyword is only matched as a whole key, and tolerate spaces around '='.
    # NOTE: brace-quoted values containing ';' (e.g. PWD={a;b}) are not parsed.
    _SERVER_RE = re.compile(r"(?:^|;)\s*Server\s*=\s*([^;]+)", re.IGNORECASE)
    _DATABASE_RE = re.compile(r"(?:^|;)\s*Database\s*=\s*([^;]+)", re.IGNORECASE)

    def __init__(
        self,
        embedding_model: str,
        connection_string: str,
        table: str,
        embedding_length: int,
    ) -> None:
        """
        Initialize the MSSQLProvider.

        Args:
            embedding_model (str): HuggingFace-compatible embedding model to use for generating embeddings.
            connection_string (str): Full ODBC connection string including target database name.
            table (str): Table name to store document embeddings.
            embedding_length (int): Size of the embeddings (number of dimensions).

        Raises:
            RuntimeError: If the database specified in the connection string cannot be found or created.
        """
        super().__init__(embedding_model)

        self.connection_string = connection_string
        self.table = table

        # The target database must exist before the vector store connects to it.
        self._ensure_database_exists()

        server = self._extract_server_address()

        logger.info(
            "Connected to MSSQL instance at %s (table: %s)",
            server,
            self.table,
        )

        self.db = SQLServer_VectorStore(
            connection_string=self.connection_string,
            embedding_function=self.embeddings,
            table_name=self.table,
            embedding_length=embedding_length,
        )

    def _extract_server_address(self) -> str:
        """
        Extract the server address (host,port) from the connection string.

        Returns:
            str: The server address portion ("host,port") or "unknown" if not found.
        """
        match = self._SERVER_RE.search(self.connection_string)
        if match is None:
            return "unknown"
        return match.group(1).strip() or "unknown"

    def _extract_database_name(self) -> Optional[str]:
        """
        Extract the database name from the connection string.

        Returns:
            Optional[str]: Database name if found (surrounding whitespace removed), else None.
        """
        match = self._DATABASE_RE.search(self.connection_string)
        if match is None:
            return None
        return match.group(1).strip() or None

    def _build_connection_string_for_master(self) -> str:
        """
        Modify the connection string to point to the 'master' database.

        The ``Database`` key is matched case-insensitively and with the same
        whitespace tolerance as ``_extract_database_name`` so both helpers
        always agree on which segment names the database.

        Returns:
            str: Modified connection string, terminated with ';'.
        """
        updated_parts = []
        for part in self.connection_string.split(";"):
            if not part.strip():
                # Skip empty segments produced by leading/trailing/double ';'.
                continue
            key = part.partition("=")[0].strip().lower()
            updated_parts.append("Database=master" if key == "database" else part)
        return ";".join(updated_parts) + ";"

    def _ensure_database_exists(self) -> None:
        """
        Connect to the SQL Server master database and create the target database if missing.

        Raises:
            RuntimeError: If no database name is present in the connection string,
                or if the database cannot be created or accessed.
        """
        database = self._extract_database_name()
        if not database:
            raise RuntimeError("No database name found in connection string.")

        master_conn_str = self._build_connection_string_for_master()

        # An identifier cannot be bound as a parameter in CREATE DATABASE, so it
        # is bracket-delimited with any ']' doubled, which is how SQL Server
        # escapes that character inside a delimited identifier.
        quoted_name = "[" + database.replace("]", "]]") + "]"

        try:
            # autocommit is required: CREATE DATABASE cannot run inside a transaction.
            conn = pyodbc.connect(master_conn_str, autocommit=True)
            try:
                cursor = conn.cursor()
                try:
                    # The existence check binds the name as a parameter instead
                    # of splicing it into the SQL text.
                    cursor.execute("SELECT DB_ID(?)", database)
                    row = cursor.fetchone()
                    if row is None or row[0] is None:
                        cursor.execute(f"CREATE DATABASE {quoted_name}")
                finally:
                    cursor.close()
            finally:
                # pyodbc's connection context manager does not close the
                # connection, so close it explicitly.
                conn.close()
        except Exception as e:
            logger.exception("Failed to ensure database '%s' exists", database)
            raise RuntimeError(
                f"Failed to ensure database '{database}' exists: {e}"
            ) from e

    def add_documents(self, docs: List[Document]) -> None:
        """
        Add documents to the SQL Server table in small batches.

        Args:
            docs (List[Document]): LangChain document chunks to embed and store.

        Raises:
            Exception: If a batch insert operation fails. Batches inserted before
                the failing one are not rolled back by this method.
        """
        batch_size = self._BATCH_SIZE
        for start in range(0, len(docs), batch_size):
            batch = docs[start : start + batch_size]
            try:
                self.db.add_documents(batch)
            except Exception:
                logger.exception("Failed to insert batch starting at index %s", start)
                raise
0 commit comments