1 change: 1 addition & 0 deletions README.md
@@ -56,6 +56,7 @@ All the database client supported
| awsopensearch | `pip install vectordb-bench[opensearch]` |
| aliyun_opensearch | `pip install vectordb-bench[aliyun_opensearch]` |
| mongodb | `pip install vectordb-bench[mongodb]` |
| astradb | `pip install vectordb-bench[astradb]` |
| tidb | `pip install vectordb-bench[tidb]` |
| vespa | `pip install vectordb-bench[vespa]` |
| oceanbase | `pip install vectordb-bench[oceanbase]` |
1 change: 1 addition & 0 deletions pyproject.toml
@@ -96,6 +96,7 @@ chromadb = [ "chromadb" ]
opensearch = [ "opensearch-py" ]
aliyun_opensearch = [ "alibabacloud_ha3engine_vector" ]
mongodb = [ "pymongo" ]
astradb = [ "astrapy" ]
mariadb = [ "mariadb" ]
tidb = [ "PyMySQL" ]
cockroachdb = [ "psycopg[binary,pool]", "pgvector" ]
16 changes: 16 additions & 0 deletions vectordb_bench/backend/clients/__init__.py
@@ -44,6 +44,7 @@ class DB(Enum):
Test = "test"
AliyunOpenSearch = "AliyunOpenSearch"
MongoDB = "MongoDB"
AstraDB = "AstraDB"
TiDB = "TiDB"
CockroachDB = "CockroachDB"
Clickhouse = "Clickhouse"
@@ -165,6 +166,11 @@ def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901, PLR0915

return MongoDB

if self == DB.AstraDB:
from .astradb.astradb import AstraDB

return AstraDB

if self == DB.OceanBase:
from .oceanbase.oceanbase import OceanBase

@@ -339,6 +345,11 @@ def config_cls(self) -> type[DBConfig]: # noqa: PLR0911, PLR0912, C901, PLR0915

return MongoDBConfig

if self == DB.AstraDB:
from .astradb.config import AstraDBConfig

return AstraDBConfig

if self == DB.OceanBase:
from .oceanbase.config import OceanBaseConfig

@@ -494,6 +505,11 @@ def case_config_cls( # noqa: C901, PLR0911, PLR0912, PLR0915

return MongoDBIndexConfig

if self == DB.AstraDB:
from .astradb.config import AstraDBIndexConfig

return AstraDBIndexConfig

if self == DB.OceanBase:
from .oceanbase.config import _oceanbase_case_config

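For orientation, a minimal sketch of how the new enum value is expected to resolve through the dispatch methods above. This assumes `init_cls` and `config_cls` are exposed as properties on the enum and that `case_config_cls` accepts an optional index type, as for the other engines in this module; the classes it points at are the new files below.

    # Assumed lookup path for DB.AstraDB; mirrors how the runner resolves the
    # other engines registered in this module.
    from vectordb_bench.backend.clients import DB

    client_cls = DB.AstraDB.init_cls              # lazily imports .astradb.astradb.AstraDB
    config_cls = DB.AstraDB.config_cls            # lazily imports .astradb.config.AstraDBConfig
    case_config_cls = DB.AstraDB.case_config_cls()  # lazily imports .astradb.config.AstraDBIndexConfig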
169 changes: 169 additions & 0 deletions vectordb_bench/backend/clients/astradb/astradb.py
@@ -0,0 +1,169 @@
import logging
import time
from contextlib import contextmanager

from astrapy import DataAPIClient
from astrapy.constants import VectorMetric
from astrapy.info import CollectionDefinition

from ..api import VectorDB
from .config import AstraDBIndexConfig

log = logging.getLogger(__name__)


class AstraDBError(Exception):
"""Custom exception class for AstraDB client errors."""


class AstraDB(VectorDB):
def __init__(
self,
dim: int,
db_config: dict,
db_case_config: AstraDBIndexConfig,
collection_name: str = "vdb_bench_collection",
id_field: str = "id",
vector_field: str = "vector",
drop_old: bool = False,
**kwargs,
):
self.dim = dim
self.db_config = db_config
self.case_config = db_case_config
self.collection_name = collection_name
self.id_field = id_field
self.vector_field = vector_field
self.drop_old = drop_old

# Get index parameters
index_params = self.case_config.index_param()
log.info(f"index params: {index_params}")
self.index_params = index_params

# Initialize client - will be properly set in init()
self.client = None
self.db = None
self.collection = None

# Initialize and drop collection if needed
temp_client = DataAPIClient(self.db_config["token"])
temp_db = temp_client.get_database(
api_endpoint=self.db_config["api_endpoint"],
keyspace=self.db_config["namespace"]
)

if self.drop_old:
try:
temp_db.drop_collection(self.collection_name)
log.info(f"AstraDB client dropped old collection: {self.collection_name}")
except Exception:
log.info(f"Collection {self.collection_name} does not exist, skipping drop")

@contextmanager
def init(self):
"""Initialize AstraDB client and cleanup when done"""
try:
self.client = DataAPIClient(self.db_config["token"])
self.db = self.client.get_database(
api_endpoint=self.db_config["api_endpoint"],
keyspace=self.db_config["namespace"]
)

# Create or get collection with vector configuration
metric_str = self.case_config.parse_metric()

# Map metric string to VectorMetric constant
metric_map = {
"euclidean": VectorMetric.EUCLIDEAN,
"dot_product": VectorMetric.DOT_PRODUCT,
"cosine": VectorMetric.COSINE,
}
metric = metric_map.get(metric_str, VectorMetric.COSINE)

# Create collection with new API
# Note: check_exists is no longer needed - API handles conflicts automatically
self.collection = self.db.create_collection(
name=self.collection_name,
definition=(
CollectionDefinition.builder()
.set_vector_dimension(self.dim)
.set_vector_metric(metric)
.build()
),
)
log.info(f"Created/accessed collection: {self.collection_name} with metric: {metric_str}")

yield
finally:
if self.client is not None:
self.client = None
self.db = None
self.collection = None

def need_normalize_cosine(self) -> bool:
return False

def insert_embeddings(
self,
embeddings: list[list[float]],
metadata: list[int],
**kwargs,
) -> (int, Exception | None):
"""Insert embeddings into AstraDB"""

# Prepare documents in bulk
documents = [
{
"_id": str(id_),
"$vector": embedding,
}
for id_, embedding in zip(metadata, embeddings, strict=False)
]

# Insert documents in batches
try:
result = self.collection.insert_many(documents, ordered=False)
return len(result.inserted_ids), None
except Exception as e:
log.exception("Error inserting embeddings")
return 0, e

def search_embedding(
self,
query: list[float],
k: int = 100,
filters: dict | None = None,
**kwargs,
) -> list[int]:
"""Search for similar vectors"""

# Build filter if specified
search_filter = None
if filters:
log.info(f"Applying filter: {filters}")
search_filter = {
self.id_field: {"$gte": filters["id"]},
}

# Perform vector search
try:
results = self.collection.find(
filter=search_filter,
sort={"$vector": query},
limit=k,
include_similarity=True,
)

# Extract IDs from results
return [int(doc["_id"]) for doc in results]
except Exception:
log.exception("Error searching embeddings")
return []

def optimize(self, data_size: int | None = None) -> None:
"""AstraDB vector indexes are automatically managed"""
log.info("optimize for search - AstraDB manages indexes automatically")

def ready_to_load(self) -> None:
"""AstraDB is always ready to load"""
86 changes: 86 additions & 0 deletions vectordb_bench/backend/clients/astradb/cli.py
@@ -0,0 +1,86 @@
from typing import Annotated, TypedDict, Unpack

import click
from pydantic import SecretStr

from ....cli.cli import (
CommonTypedDict,
cli,
click_parameter_decorators_from_typed_dict,
run,
)
from .. import DB
from ..api import MetricType
from .config import AstraDBIndexConfig


class AstraDBTypedDict(TypedDict):
api_endpoint: Annotated[
str,
click.option(
"--api-endpoint",
type=str,
help="AstraDB API endpoint (e.g., https://<database-id>-<region>.apps.astra.datastax.com)",
required=True,
),
]
token: Annotated[
str,
click.option(
"--token",
type=str,
help="AstraDB authentication token",
required=True,
),
]
namespace: Annotated[
str,
click.option(
"--namespace",
type=str,
help="AstraDB namespace (keyspace)",
default="default_keyspace",
show_default=True,
),
]
metric: Annotated[
str,
click.option(
"--metric",
type=click.Choice(["cosine", "euclidean", "dot_product"], case_sensitive=False),
help="Distance metric for vector similarity",
default="cosine",
show_default=True,
),
]


class AstraDBIndexTypedDict(CommonTypedDict, AstraDBTypedDict): ...


@cli.command()
@click_parameter_decorators_from_typed_dict(AstraDBIndexTypedDict)
def AstraDB(**parameters: Unpack[AstraDBIndexTypedDict]):
from .config import AstraDBConfig

# Convert metric string to MetricType enum
metric_map = {
"cosine": MetricType.COSINE,
"euclidean": MetricType.L2,
"dot_product": MetricType.IP,
}
metric_type = metric_map.get(parameters["metric"].lower(), MetricType.COSINE)

run(
db=DB.AstraDB,
db_config=AstraDBConfig(
db_label=parameters["db_label"],
api_endpoint=parameters["api_endpoint"],
token=SecretStr(parameters["token"]),
namespace=parameters["namespace"],
),
db_case_config=AstraDBIndexConfig(
metric_type=metric_type,
),
**parameters,
)
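For reference, once registered in `vectordb_bench/cli/vectordbbench.py` (see below), the command above would presumably be invoked as `vectordbbench astradb --api-endpoint https://<database-id>-<region>.apps.astra.datastax.com --token <your-astra-token> --namespace default_keyspace --metric cosine`, together with the common benchmark options; the lower-cased command name `astradb` is an assumption based on click deriving it from the function name.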
38 changes: 38 additions & 0 deletions vectordb_bench/backend/clients/astradb/config.py
@@ -0,0 +1,38 @@
from enum import Enum

from pydantic import BaseModel, SecretStr

from ..api import DBCaseConfig, DBConfig, IndexType, MetricType


class AstraDBConfig(DBConfig, BaseModel):
api_endpoint: str = "https://<database-id>-<region>.apps.astra.datastax.com"
token: SecretStr = "<your-astra-token>"
namespace: str = "default_keyspace"

def to_dict(self) -> dict:
return {
"api_endpoint": self.api_endpoint,
"token": self.token.get_secret_value(),
"namespace": self.namespace,
}


class AstraDBIndexConfig(BaseModel, DBCaseConfig):
index: IndexType = IndexType.HNSW # AstraDB uses vector search
metric_type: MetricType = MetricType.COSINE

def parse_metric(self) -> str:
if self.metric_type == MetricType.L2:
return "euclidean"
if self.metric_type == MetricType.IP:
return "dot_product"
return "cosine" # Default to cosine similarity

def index_param(self) -> dict:
return {
"metric": self.parse_metric(),
}

def search_param(self) -> dict:
return {}
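A small sketch of how these two config classes would be constructed directly; the token is a placeholder, and the `MetricType.IP` to `dot_product` mapping follows `parse_metric` above.

    # Illustrative construction of the config objects defined above.
    from pydantic import SecretStr

    from vectordb_bench.backend.clients.api import MetricType
    from vectordb_bench.backend.clients.astradb.config import AstraDBConfig, AstraDBIndexConfig

    db_config = AstraDBConfig(
        api_endpoint="https://<database-id>-<region>.apps.astra.datastax.com",
        token=SecretStr("<your-astra-token>"),
        namespace="default_keyspace",
    )
    case_config = AstraDBIndexConfig(metric_type=MetricType.IP)

    assert case_config.parse_metric() == "dot_product"
    connection_kwargs = db_config.to_dict()  # what the AstraDB client receives as db_config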
2 changes: 2 additions & 0 deletions vectordb_bench/cli/vectordbbench.py
@@ -1,5 +1,6 @@
from ..backend.clients.alisql.cli import AliSQLHNSW
from ..backend.clients.alloydb.cli import AlloyDBScaNN
from ..backend.clients.astradb.cli import AstraDB
from ..backend.clients.aws_opensearch.cli import AWSOpenSearch
from ..backend.clients.chroma.cli import Chroma
from ..backend.clients.clickhouse.cli import Clickhouse
@@ -50,6 +51,7 @@
cli.add_command(PgVectorScaleDiskAnn)
cli.add_command(PgDiskAnn)
cli.add_command(AlloyDBScaNN)
cli.add_command(AstraDB)
cli.add_command(OceanBaseHNSW)
cli.add_command(OceanBaseIVF)
cli.add_command(MariaDBHNSW)
7 changes: 7 additions & 0 deletions vectordb_bench/frontend/config/dbCaseConfigs.py
@@ -2373,6 +2373,9 @@ class CaseConfigInput(BaseModel):
CaseConfigParamInput_MongoDBNumCandidatesRatio,
]

AstraDBLoadingConfig = []
AstraDBPerformanceConfig = []

CockroachDBLoadingConfig = [
CaseConfigParamInput_IndexType_CockroachDB,
CaseConfigParamInput_MinPartitionSize_CockroachDB,
@@ -2691,6 +2694,10 @@ class CaseConfigInput(BaseModel):
CaseLabel.Load: MongoDBLoadingConfig,
CaseLabel.Performance: MongoDBPerformanceConfig,
},
DB.AstraDB: {
CaseLabel.Load: AstraDBLoadingConfig,
CaseLabel.Performance: AstraDBPerformanceConfig,
},
DB.MariaDB: {
CaseLabel.Load: MariaDBLoadingConfig,
CaseLabel.Performance: MariaDBPerformanceConfig,