diff --git a/README.md b/README.md index bd7568da1..64701f353 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,7 @@ All the database client supported | awsopensearch | `pip install vectordb-bench[opensearch]` | | aliyun_opensearch | `pip install vectordb-bench[aliyun_opensearch]` | | mongodb | `pip install vectordb-bench[mongodb]` | +| astradb | `pip install vectordb-bench[astradb]` | | tidb | `pip install vectordb-bench[tidb]` | | vespa | `pip install vectordb-bench[vespa]` | | oceanbase | `pip install vectordb-bench[oceanbase]` | diff --git a/pyproject.toml b/pyproject.toml index 63c585c55..281f8d543 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,6 +96,7 @@ chromadb = [ "chromadb" ] opensearch = [ "opensearch-py" ] aliyun_opensearch = [ "alibabacloud_ha3engine_vector" ] mongodb = [ "pymongo" ] +astradb = [ "astrapy" ] mariadb = [ "mariadb" ] tidb = [ "PyMySQL" ] cockroachdb = [ "psycopg[binary,pool]", "pgvector" ] diff --git a/vectordb_bench/backend/clients/__init__.py b/vectordb_bench/backend/clients/__init__.py index d69c54504..6a3b52d9f 100644 --- a/vectordb_bench/backend/clients/__init__.py +++ b/vectordb_bench/backend/clients/__init__.py @@ -44,6 +44,7 @@ class DB(Enum): Test = "test" AliyunOpenSearch = "AliyunOpenSearch" MongoDB = "MongoDB" + AstraDB = "AstraDB" TiDB = "TiDB" CockroachDB = "CockroachDB" Clickhouse = "Clickhouse" @@ -165,6 +166,11 @@ def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901, PLR0915 return MongoDB + if self == DB.AstraDB: + from .astradb.astradb import AstraDB + + return AstraDB + if self == DB.OceanBase: from .oceanbase.oceanbase import OceanBase @@ -339,6 +345,11 @@ def config_cls(self) -> type[DBConfig]: # noqa: PLR0911, PLR0912, C901, PLR0915 return MongoDBConfig + if self == DB.AstraDB: + from .astradb.config import AstraDBConfig + + return AstraDBConfig + if self == DB.OceanBase: from .oceanbase.config import OceanBaseConfig @@ -494,6 +505,11 @@ def case_config_cls( # noqa: C901, PLR0911, 
class AstraDBError(Exception):
    """Custom exception class for AstraDB client errors."""


class AstraDB(VectorDB):
    """VectorDB client that stores and searches embeddings in an Astra DB collection.

    Each embedding is written as one document carrying:

    * ``_id``            - the benchmark id as a string,
    * ``<id_field>``     - the same id as an integer, so range filters can use it,
    * ``$vector``        - the embedding itself.

    Live ``astrapy`` handles (``client``, ``db``, ``collection``) only exist
    inside the ``init()`` context manager; outside of it they are ``None``.
    """

    def __init__(
        self,
        dim: int,
        db_config: dict,
        db_case_config: AstraDBIndexConfig,
        collection_name: str = "vdb_bench_collection",
        id_field: str = "id",
        vector_field: str = "vector",
        drop_old: bool = False,
        **kwargs,
    ):
        """Store the configuration and optionally drop a pre-existing collection.

        Args:
            dim: Dimension of the vectors stored in the collection.
            db_config: Connection settings; the keys ``"token"``,
                ``"api_endpoint"`` and ``"namespace"`` are read.
            db_case_config: Case configuration providing ``index_param()`` and
                ``parse_metric()``.
            collection_name: Name of the collection to create / use.
            id_field: Document field that holds the numeric id used by filters.
            vector_field: Stored on the instance; not read by any method of
                this class (vectors are written under ``$vector``).
            drop_old: When true, try to drop ``collection_name`` first.
            **kwargs: Accepted and ignored.
        """
        self.dim = dim
        self.db_config = db_config
        self.case_config = db_case_config
        self.collection_name = collection_name
        self.id_field = id_field
        self.vector_field = vector_field
        self.drop_old = drop_old

        # Get index parameters
        index_params = self.case_config.index_param()
        log.info(f"index params: {index_params}")
        self.index_params = index_params

        # Initialize client - will be properly set in init()
        self.client = None
        self.db = None
        self.collection = None

        # A throw-away client is used here so no live handle is kept on the
        # instance outside of init().
        temp_client = DataAPIClient(self.db_config["token"])
        temp_db = temp_client.get_database(
            api_endpoint=self.db_config["api_endpoint"],
            keyspace=self.db_config["namespace"],
        )

        if self.drop_old:
            try:
                temp_db.drop_collection(self.collection_name)
            except Exception as e:
                # Best effort: keep going, but report the real cause instead of
                # claiming the collection does not exist (the failure may just
                # as well be an authentication or network problem).
                log.warning(f"Failed to drop collection {self.collection_name}, continuing: {e}")
            else:
                log.info(f"AstraDB client dropped old collection: {self.collection_name}")

    @contextmanager
    def init(self):
        """Initialize AstraDB client and cleanup when done.

        Creates the client and database handle, creates (or re-uses) the
        collection with the configured dimension and metric, yields, and
        finally resets all handles to ``None``.
        """
        try:
            self.client = DataAPIClient(self.db_config["token"])
            self.db = self.client.get_database(
                api_endpoint=self.db_config["api_endpoint"],
                keyspace=self.db_config["namespace"],
            )

            # Create or get collection with vector configuration
            metric_str = self.case_config.parse_metric()

            # Map metric string to VectorMetric constant; unknown strings fall
            # back to cosine.
            metric_map = {
                "euclidean": VectorMetric.EUCLIDEAN,
                "dot_product": VectorMetric.DOT_PRODUCT,
                "cosine": VectorMetric.COSINE,
            }
            metric = metric_map.get(metric_str, VectorMetric.COSINE)

            # NOTE(review): this relies on create_collection being safe to call
            # for a collection that already exists with the same definition -
            # confirm against the astrapy documentation.
            self.collection = self.db.create_collection(
                name=self.collection_name,
                definition=(
                    CollectionDefinition.builder()
                    .set_vector_dimension(self.dim)
                    .set_vector_metric(metric)
                    .build()
                ),
            )
            log.info(f"Created/accessed collection: {self.collection_name} with metric: {metric_str}")

            yield
        finally:
            # Drop every handle so none survives outside the context.
            self.client = None
            self.db = None
            self.collection = None

    def need_normalize_cosine(self) -> bool:
        """Return ``False``: vectors are inserted as given, without normalization."""
        return False

    def insert_embeddings(
        self,
        embeddings: list[list[float]],
        metadata: list[int],
        **kwargs,
    ) -> tuple[int, Exception | None]:
        """Insert embeddings into AstraDB.

        Args:
            embeddings: Vectors to insert.
            metadata: Integer ids, paired positionally with ``embeddings``.
            **kwargs: Accepted and ignored.

        Returns:
            ``(inserted_count, None)`` on success, ``(0, error)`` if the insert
            raised.
        """
        # The numeric id is stored under `id_field` in addition to the string
        # `_id`: search_embedding() filters with `$gte` on `id_field`, which
        # can only match when that field is actually present in the documents.
        # `_id` is listed last so it stays a string even if id_field == "_id".
        documents = [
            {
                self.id_field: int(id_),
                "$vector": embedding,
                "_id": str(id_),
            }
            for id_, embedding in zip(metadata, embeddings, strict=False)
        ]

        try:
            result = self.collection.insert_many(documents, ordered=False)
            return len(result.inserted_ids), None
        except Exception as e:
            log.exception("Error inserting embeddings")
            return 0, e

    def search_embedding(
        self,
        query: list[float],
        k: int = 100,
        filters: dict | None = None,
        **kwargs,
    ) -> list[int]:
        """Search for similar vectors.

        Args:
            query: Query vector.
            k: Maximum number of results.
            filters: Optional filter; when given, ``filters["id"]`` is used as
                an inclusive lower bound on the ``id_field`` value.
            **kwargs: Accepted and ignored.

        Returns:
            Ids of the matching documents in the order returned by the
            database, or an empty list if the search raised.
        """
        # Build filter if specified
        search_filter = None
        if filters:
            log.info(f"Applying filter: {filters}")
            search_filter = {
                self.id_field: {"$gte": filters["id"]},
            }

        # Perform vector search
        try:
            results = self.collection.find(
                filter=search_filter,
                sort={"$vector": query},
                limit=k,
                include_similarity=True,
            )

            # `_id` was written as str(id) on insert; convert back to int.
            return [int(doc["_id"]) for doc in results]
        except Exception:
            log.exception("Error searching embeddings")
            return []

    def optimize(self, data_size: int | None = None) -> None:
        """AstraDB vector indexes are automatically managed; only logs a message."""
        log.info("optimize for search - AstraDB manages indexes automatically")

    def ready_to_load(self) -> None:
        """AstraDB is always ready to load"""
import DB +from ..api import MetricType +from .config import AstraDBIndexConfig + + +class AstraDBTypedDict(TypedDict): + api_endpoint: Annotated[ + str, + click.option( + "--api-endpoint", + type=str, + help="AstraDB API endpoint (e.g., https://-.apps.astra.datastax.com)", + required=True, + ), + ] + token: Annotated[ + str, + click.option( + "--token", + type=str, + help="AstraDB authentication token", + required=True, + ), + ] + namespace: Annotated[ + str, + click.option( + "--namespace", + type=str, + help="AstraDB namespace (keyspace)", + default="default_keyspace", + show_default=True, + ), + ] + metric: Annotated[ + str, + click.option( + "--metric", + type=click.Choice(["cosine", "euclidean", "dot_product"], case_sensitive=False), + help="Distance metric for vector similarity", + default="cosine", + show_default=True, + ), + ] + + +class AstraDBIndexTypedDict(CommonTypedDict, AstraDBTypedDict): ... + + +@cli.command() +@click_parameter_decorators_from_typed_dict(AstraDBIndexTypedDict) +def AstraDB(**parameters: Unpack[AstraDBIndexTypedDict]): + from .config import AstraDBConfig + + # Convert metric string to MetricType enum + metric_map = { + "cosine": MetricType.COSINE, + "euclidean": MetricType.L2, + "dot_product": MetricType.IP, + } + metric_type = metric_map.get(parameters["metric"].lower(), MetricType.COSINE) + + run( + db=DB.AstraDB, + db_config=AstraDBConfig( + db_label=parameters["db_label"], + api_endpoint=parameters["api_endpoint"], + token=SecretStr(parameters["token"]), + namespace=parameters["namespace"], + ), + db_case_config=AstraDBIndexConfig( + metric_type=metric_type, + ), + **parameters, + ) diff --git a/vectordb_bench/backend/clients/astradb/config.py b/vectordb_bench/backend/clients/astradb/config.py new file mode 100644 index 000000000..01efa6dd9 --- /dev/null +++ b/vectordb_bench/backend/clients/astradb/config.py @@ -0,0 +1,38 @@ +from enum import Enum + +from pydantic import BaseModel, SecretStr + +from ..api import DBCaseConfig, 
class AstraDBConfig(DBConfig, BaseModel):
    """Connection settings for an Astra DB database."""

    api_endpoint: str = "https://-.apps.astra.datastax.com"
    # The default must itself be a SecretStr: pydantic does not validate field
    # defaults, so a plain "" would stay a `str` and to_dict() would fail on
    # `.get_secret_value()` for a default-constructed config.
    token: SecretStr = SecretStr("")
    namespace: str = "default_keyspace"

    def to_dict(self) -> dict:
        """Return the settings as a plain dict with the token unwrapped.

        Returns:
            A dict with the keys ``"api_endpoint"``, ``"token"`` (the raw
            secret string) and ``"namespace"``.
        """
        return {
            "api_endpoint": self.api_endpoint,
            "token": self.token.get_secret_value(),
            "namespace": self.namespace,
        }


class AstraDBIndexConfig(BaseModel, DBCaseConfig):
    """Case configuration for AstraDB: only the distance metric is configurable."""

    index: IndexType = IndexType.HNSW  # AstraDB uses vector search
    metric_type: MetricType = MetricType.COSINE

    def parse_metric(self) -> str:
        """Return the metric name for ``metric_type``.

        ``L2`` maps to ``"euclidean"``, ``IP`` to ``"dot_product"``; every
        other metric type maps to ``"cosine"``.
        """
        if self.metric_type == MetricType.L2:
            return "euclidean"
        if self.metric_type == MetricType.IP:
            return "dot_product"
        return "cosine"  # Default to cosine similarity

    def index_param(self) -> dict:
        """Return the index parameters: a dict holding the parsed metric name."""
        return {
            "metric": self.parse_metric(),
        }

    def search_param(self) -> dict:
        """Return the search parameters; there are none, so this is an empty dict."""
        return {}
+AstraDBLoadingConfig = [] +AstraDBPerformanceConfig = [] + CockroachDBLoadingConfig = [ CaseConfigParamInput_IndexType_CockroachDB, CaseConfigParamInput_MinPartitionSize_CockroachDB, @@ -2691,6 +2694,10 @@ class CaseConfigInput(BaseModel): CaseLabel.Load: MongoDBLoadingConfig, CaseLabel.Performance: MongoDBPerformanceConfig, }, + DB.AstraDB: { + CaseLabel.Load: AstraDBLoadingConfig, + CaseLabel.Performance: AstraDBPerformanceConfig, + }, DB.MariaDB: { CaseLabel.Load: MariaDBLoadingConfig, CaseLabel.Performance: MariaDBPerformanceConfig, diff --git a/vectordb_bench/frontend/config/styles.py b/vectordb_bench/frontend/config/styles.py index bce4561fd..8cd0f3c9f 100644 --- a/vectordb_bench/frontend/config/styles.py +++ b/vectordb_bench/frontend/config/styles.py @@ -60,6 +60,7 @@ def getPatternShape(i): DB.AliyunOpenSearch: "", DB.AWSOpenSearch: "https://assets.zilliz.com/opensearch_1eee37584e.jpeg", DB.OSSOpenSearch: "https://images.seeklogo.com/logo-png/50/1/opensearch-icon-logo-png_seeklogo-500356.png", + DB.AstraDB: "", DB.MongoDB: "", DB.TiDB: "https://img2.pingcap.com/forms/3/d/3d7fd5f9767323d6f037795704211ac44b4923d6.png", DB.Clickhouse: "",