diff --git a/README.md b/README.md
index 9b69e95a0..1f4cefa5c 100644
--- a/README.md
+++ b/README.md
@@ -59,6 +59,7 @@ All the database client supported
 | vespa | `pip install 'vectordb-bench[vespa]'` |
 | oceanbase | `pip install 'vectordb-bench[oceanbase]'` |
 | hologres | `pip install 'vectordb-bench[hologres]'` |
+| databend | `pip install 'vectordb-bench[databend]'` |
 
 ### Run
 
diff --git a/install/requirements_py3.11.txt b/install/requirements_py3.11.txt
index 0ae328a6f..782fb8583 100644
--- a/install/requirements_py3.11.txt
+++ b/install/requirements_py3.11.txt
@@ -24,4 +24,5 @@ scikit-learn
 pymilvus
 clickhouse_connect
 pyvespa
-mysql-connector-python
\ No newline at end of file
+mysql-connector-python
+databend-driver
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index a922c1fb5..aaa68a6b7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -74,6 +74,7 @@ all = [
     "pyvespa",
     "lancedb",
     "mysql-connector-python",
+    "databend-driver",
 ]
 
 qdrant = [ "qdrant-client" ]
@@ -98,6 +99,7 @@ clickhouse = [ "clickhouse-connect" ]
 vespa = [ "pyvespa" ]
 lancedb = [ "lancedb" ]
 oceanbase = [ "mysql-connector-python" ]
+databend = [ "databend-driver" ]
 
 [project.urls]
 "repository" = "https://github.com/zilliztech/VectorDBBench"
diff --git a/vectordb_bench/backend/clients/__init__.py b/vectordb_bench/backend/clients/__init__.py
index 79a6f964a..069b21674 100644
--- a/vectordb_bench/backend/clients/__init__.py
+++ b/vectordb_bench/backend/clients/__init__.py
@@ -51,6 +51,7 @@ class DB(Enum):
     OceanBase = "OceanBase"
     S3Vectors = "S3Vectors"
     Hologres = "Alibaba Cloud Hologres"
+    Databend = "Databend"
 
     @property
     def init_cls(self) -> type[VectorDB]:  # noqa: PLR0911, PLR0912, C901, PLR0915
@@ -200,6 +201,11 @@ def init_cls(self) -> type[VectorDB]:  # noqa: PLR0911, PLR0912, C901, PLR0915
 
             return Hologres
 
+        if self == DB.Databend:
+            from .databend.databend import Databend
+
+            return Databend
+
         msg = f"Unknown DB: {self.name}"
         raise ValueError(msg)
 
@@ -351,6 +357,11 @@ def config_cls(self) -> type[DBConfig]:  # noqa: PLR0911, PLR0912, C901, PLR0915
 
             return HologresConfig
 
+        if self == DB.Databend:
+            from .databend.config import DatabendConfig
+
+            return DatabendConfig
+
         msg = f"Unknown DB: {self.name}"
         raise ValueError(msg)
 
@@ -475,7 +486,12 @@ def case_config_cls(  # noqa: C901, PLR0911, PLR0912
         if self == DB.Hologres:
             from .hologres.config import HologresIndexConfig
 
             return HologresIndexConfig
 
+        if self == DB.Databend:
+            from .databend.config import DatabendIndexConfig
+
+            return DatabendIndexConfig
+
         # DB.Pinecone, DB.Chroma, DB.Redis
         return EmptyDBCaseConfig
diff --git a/vectordb_bench/backend/clients/databend/cli.py b/vectordb_bench/backend/clients/databend/cli.py
new file mode 100644
index 000000000..0cb00953a
--- /dev/null
+++ b/vectordb_bench/backend/clients/databend/cli.py
@@ -0,0 +1,47 @@
+from typing import Annotated, TypedDict, Unpack
+
+import click
+from pydantic import SecretStr
+
+from ....cli.cli import (
+    CommonTypedDict,
+    HNSWFlavor2,
+    cli,
+    click_parameter_decorators_from_typed_dict,
+    run,
+)
+from .. import DB
+from .config import DatabendIndexConfig
+
+
+class DatabendTypedDict(TypedDict):
+    password: Annotated[str, click.option("--password", type=str, help="DB password")]
+    host: Annotated[str, click.option("--host", type=str, help="DB host", required=True)]
+    port: Annotated[int, click.option("--port", type=int, default=8000, help="DB Port")]
+    user: Annotated[str, click.option("--user", type=str, default="root", help="DB user")]
+
+
+class DatabendHNSWTypedDict(CommonTypedDict, DatabendTypedDict, HNSWFlavor2): ...
+
+
+@cli.command()
+@click_parameter_decorators_from_typed_dict(DatabendHNSWTypedDict)
+def Databend(**parameters: Unpack[DatabendHNSWTypedDict]):
+    from .config import DatabendConfig
+
+    run(
+        db=DB.Databend,
+        db_config=DatabendConfig(
+            db_label=parameters["db_label"],
+            user=parameters["user"],
+            password=SecretStr(parameters["password"]) if parameters["password"] else None,
+            host=parameters["host"],
+            port=parameters["port"],
+        ),
+        db_case_config=DatabendIndexConfig(
+            metric_type=None,
+            m=parameters["m"],
+            ef_construct=parameters["ef_construction"],
+        ),
+        **parameters,
+    )
diff --git a/vectordb_bench/backend/clients/databend/config.py b/vectordb_bench/backend/clients/databend/config.py
new file mode 100644
index 000000000..be52ba71a
--- /dev/null
+++ b/vectordb_bench/backend/clients/databend/config.py
@@ -0,0 +1,67 @@
+from typing import TypedDict
+
+from pydantic import BaseModel, SecretStr
+
+from ..api import DBCaseConfig, DBConfig, MetricType
+
+
+class DatabendConfigDict(TypedDict):
+    user: str
+    password: str
+    host: str
+    port: int
+    database: str
+    secure: bool
+
+
+class DatabendConfig(DBConfig):
+    user: str = "root"
+    password: SecretStr
+    host: str = "localhost"
+    port: int = 8000
+    db_name: str = "default"
+    secure: bool = False
+
+    def to_dict(self) -> DatabendConfigDict:
+        pwd_str = self.password.get_secret_value()
+        return {
+            "host": self.host,
+            "port": self.port,
+            "database": self.db_name,
+            "user": self.user,
+            "password": pwd_str,
+            "secure": self.secure,
+        }
+
+
+class DatabendIndexConfig(BaseModel, DBCaseConfig):
+    metric_type: MetricType | None = None
+    create_index_before_load: bool = True
+    create_index_after_load: bool = False
+    m: int | None
+    ef_construct: int | None
+
+    def parse_metric(self) -> str:
+        if not self.metric_type:
+            return ""
+        return self.metric_type.value
+
+    def parse_metric_str(self) -> str:
+        if self.metric_type == MetricType.L2:
+            return "l2"
+        if self.metric_type == MetricType.COSINE:
+            return "cosine"
+        return "cosine"
+
+    def index_param(self) -> dict:
+        return {
+            "m": self.m,
+            "metric_type": self.parse_metric_str(),
+            "ef_construct": self.ef_construct,
+        }
+
+    def search_param(self) -> dict:
+        return {}
+
+    def session_param(self) -> dict:
+        return {}
diff --git a/vectordb_bench/backend/clients/databend/databend.py b/vectordb_bench/backend/clients/databend/databend.py
new file mode 100644
index 000000000..1ae08487d
--- /dev/null
+++ b/vectordb_bench/backend/clients/databend/databend.py
@@ -0,0 +1,199 @@
+"""Wrapper around the Databend vector database over VectorDB"""
+
+import logging
+from contextlib import contextmanager
+from typing import Any
+
+from databend_driver import (
+    BlockingDatabendClient,
+    BlockingDatabendConnection,
+)
+
+from ..api import MetricType, VectorDB
+from .config import DatabendConfigDict, DatabendIndexConfig
+
+log = logging.getLogger(__name__)
+
+
+class Databend(VectorDB):
+    """Databend client built on the blocking databend-driver API."""
+
+    def __init__(
+        self,
+        dim: int,
+        db_config: DatabendConfigDict,
+        db_case_config: DatabendIndexConfig,
+        collection_name: str = "DatabendVectorCollection",
+        drop_old: bool = False,
+        **kwargs,
+    ):
+        self.db_config = db_config
+        self.case_config = db_case_config
+        self.database_name = db_config["database"]
+        self.table_name = collection_name
+        self.dim = dim
+
+        self.index_param = self.case_config.index_param()
+        self.search_param = self.case_config.search_param()
+        self.session_param = self.case_config.session_param()
+
+        self._index_name = "databend_index"
+        self._primary_field = "id"
+        self._vector_field = "embedding"
+
+        # construct basic units
+        self.conn = self._create_connection(**self.db_config, settings=self.session_param)
+
+        if drop_old:
+            log.info(f"Databend client drop table : {self.table_name}")
+            self._drop_table()
+            self._create_table(dim)
+            if self.case_config.create_index_before_load:
+                self._create_index()
+
+        self.conn.close()
+        self.conn = None
+
+    @contextmanager
+    def init(self) -> None:
+        """
+        Examples:
+            >>> with self.init():
+            >>>     self.insert_embeddings()
+            >>>     self.search_embedding()
+        """
+
+        self.conn = self._create_connection(**self.db_config, settings=self.session_param)
+
+        try:
+            yield
+        finally:
+            self.conn.close()
+            self.conn = None
+
+    def _create_connection(self, settings: dict | None, **kwargs) -> BlockingDatabendConnection:
+        databend_client = BlockingDatabendClient(
+            f'databend://{self.db_config["user"]}:{self.db_config["password"]}@'
+            f'{self.db_config["host"]}:{self.db_config["port"]}/{self.db_config["database"]}?sslmode=disable'
+        )
+        return databend_client.get_conn()
+
+    def _drop_index(self):
+        assert self.conn is not None, "Connection is not initialized"
+        try:
+            self.conn.exec(
+                f"DROP VECTOR INDEX IF EXISTS {self._index_name} ON {self.database_name}.{self.table_name}"
+            )
+        except Exception as e:
+            log.warning(f"Failed to drop index on table {self.database_name}.{self.table_name}: {e}")
+            raise e from None
+
+    def _drop_table(self):
+        assert self.conn is not None, "Connection is not initialized"
+
+        try:
+            self.conn.exec(f"DROP TABLE IF EXISTS {self.database_name}.{self.table_name}")
+        except Exception as e:
+            log.warning(f"Failed to drop table {self.database_name}.{self.table_name}: {e}")
+            raise e from None
+
+    def _performance_tuning(self):
+        pass
+
+    def _create_index(self):
+        assert self.conn is not None, "Connection is not initialized"
+        try:
+            self.conn.exec(
+                f"CREATE VECTOR INDEX IF NOT EXISTS {self._index_name} "
+                f"ON {self.database_name}.{self.table_name} "
+                f'({self._vector_field}) m = {self.index_param["m"]} '
+                f'ef_construct = {self.index_param["ef_construct"]} '
+                f'distance = {self.index_param["metric_type"]}'
+            )
+        except Exception as e:
+            log.warning(f"Failed to create Databend vector index on table: {self.table_name} error: {e}")
+            raise e from None
+
+    def _create_table(self, dim: int):
+        assert self.conn is not None, "Connection is not initialized"
+
+        try:
+            # create table
+            self.conn.exec(
+                f"CREATE TABLE IF NOT EXISTS {self.database_name}.{self.table_name} "
+                f"({self._primary_field} UInt32, "
+                f"{self._vector_field} Vector({self.dim})) "
+                f"ENGINE = Fuse"
Fuse" + ) + + except Exception as e: + log.warning(f"Failed to create Databend table: {self.table_name} error: {e}") + raise e from None + + def optimize(self, data_size: int | None = None): + assert self.conn is not None, "Connection is not initialized" + + try: + self.conn.exec(f"OPTIMIZE TABLE {self.database_name}.{self.table_name} ALL") + + except Exception as e: + log.warning(f"Failed to optimize Databend table: {self.table_name} error: {e}") + raise e from None + + def _post_insert(self): + pass + + def insert_embeddings( + self, + embeddings: list[list[float]], + metadata: list[int], + **kwargs: Any, + ) -> (int, Exception): + assert self.conn is not None, "Connection is not initialized" + + try: + rows: List[List[Any]] = [] + for _id, embedding in zip(metadata, embeddings): + row: List[Any] = [ + str(_id), + str(embedding), + ] + rows.append(row) + + self.conn.stream_load( + f"INSERT INTO {self.database_name}.{self.table_name} VALUES", + rows, + ) + + return len(metadata), None + except Exception as e: + log.warning(f"Failed to insert data into Databend table ({self.table_name}), error: {e}") + return 0, e + + def search_embedding( + self, + query: list[float], + k: int = 100, + filters: dict | None = None, + timeout: int | None = None, + ) -> list[int]: + assert self.conn is not None, "Connection is not initialized" + if self.case_config.metric_type == "COSINE": + if filters: + _id = filters.get("id", 0) + result = self.conn.query_all( + f"SELECT {self._primary_field} " + f"FROM {self.database_name}.{self.table_name} " + f"WHERE {self._primary_field} > {_id} " + f"ORDER BY cosine_distance({self._vector_field}, {query}::Vector({self.dim})) " + f"LIMIT {k}", + ) + return [int(row.values()[0]) for row in result] + + result = self.conn.query_all( + f"SELECT {self._primary_field} " + f"FROM {self.database_name}.{self.table_name} " + f"ORDER BY cosine_distance({self._vector_field}, {query}::Vector({self.dim})) " + f"LIMIT {k}", + ) + return [int(row.values()[0]) for row in result] + if filters: + _id = filters.get("id", 0) + result = self.conn.query_all( + f"SELECT {self._primary_field} " + f"FROM {self.database_name}.{self.table_name} " + f"WHERE {self._primary_field} > {_id} " + f"ORDER BY l2_distance({self._vector_field}, {query}::Vector({self.dim})) " + f"LIMIT {k}", + ) + return [int(row.values()[0]) for row in result] + + result = self.conn.query_all( + f"SELECT {self._primary_field} " + f"FROM {self.database_name}.{self.table_name} " + f"ORDER BY l2_distance({self._vector_field}, {query}::Vector({self.dim})) " + f"LIMIT {k}", + ) + return [int(row.values()[0]) for row in result] diff --git a/vectordb_bench/cli/vectordbbench.py b/vectordb_bench/cli/vectordbbench.py index 83dab74f6..ea1a8a2c3 100644 --- a/vectordb_bench/cli/vectordbbench.py +++ b/vectordb_bench/cli/vectordbbench.py @@ -21,6 +21,7 @@ from ..backend.clients.vespa.cli import Vespa from ..backend.clients.weaviate_cloud.cli import Weaviate from ..backend.clients.zilliz_cloud.cli import ZillizAutoIndex +from ..backend.clients.databend.cli import Databend from .batch_cli import BatchCli from .cli import cli @@ -50,6 +51,7 @@ cli.add_command(QdrantLocal) cli.add_command(BatchCli) cli.add_command(S3Vectors) +cli.add_command(Databend) if __name__ == "__main__":