Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions install/requirements_py3.11.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
grpcio==1.53.2
grpcio-tools==1.53.0
grpcio>=1.53.2
grpcio-tools>=1.53.0
qdrant-client
pinecone-client
weaviate-client
Expand All @@ -19,11 +19,13 @@ psutil
polars
plotly
environs
pydantic<v2
pydantic>=v2
scikit-learn
pymilvus
clickhouse_connect
pyvespa
mysql-connector-python
packaging
hdrhistogram>=0.10.1
ujson
numpy
hdrhistogram>=0.10.1
10 changes: 5 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ dependencies = [
"click",
"pytz",
"streamlit-autorefresh",
"streamlit<1.44,!=1.34.0", # There is a breaking change in 1.44 related to get_page https://discuss.streamlit.io/t/from-streamlit-source-util-import-get-pages-gone-in-v-1-44-0-need-urgent-help/98399
"streamlit>=1.52.2", # There is a breaking change in 1.44 related to get_page https://discuss.streamlit.io/t/from-streamlit-source-util-import-get-pages-gone-in-v-1-44-0-need-urgent-help/98399
"streamlit_extras",
"tqdm",
"s3fs",
Expand All @@ -37,7 +37,7 @@ dependencies = [
"polars",
"plotly",
"environs",
"pydantic<v2",
"pydantic>=2",
"scikit-learn",
"pymilvus", # with pandas, numpy, ujson
"ujson",
Expand All @@ -54,11 +54,11 @@ test = [
restful = [ "flask" ]

all = [
"grpcio==1.53.0", # for qdrant-client and pymilvus
"grpcio-tools==1.53.0", # for qdrant-client and pymilvus
"grpcio>=1.53.0", # for qdrant-client, pymilvus and weaviate
"grpcio-tools>=1.53.0", # for qdrant-client, pymilvus and weaviate
"qdrant-client",
"pinecone-client",
"weaviate-client",
"weaviate-client>=4.18.3",
"elasticsearch",
"sqlalchemy",
"redis",
Expand Down
19 changes: 10 additions & 9 deletions tests/test_bench_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
import logging
from vectordb_bench.interface import BenchMarkRunner
from vectordb_bench.models import (
DB, IndexType, CaseType, TaskConfig, CaseConfig,
DB, CaseType, TaskConfig, CaseConfig,
)
from vectordb_bench.backend.clients.api import IndexType

log = logging.getLogger(__name__)

Expand All @@ -19,9 +20,9 @@ def test_performance_case_whole(self):

task_config=TaskConfig(
db=DB.Milvus,
db_config=DB.Milvus.config(),
db_case_config=DB.Milvus.case_config_cls(index=IndexType.Flat)(),
case_config=CaseConfig(case_id=CaseType.PerformanceSZero),
db_config=DB.Milvus.config_cls(),
db_case_config=DB.Milvus.case_config_cls(index_type=IndexType.Flat)(),
case_config=CaseConfig(case_id=CaseType.Performance768D1M),
)

runner.run([task_config])
Expand All @@ -34,9 +35,9 @@ def test_performance_case_clean(self):

task_config=TaskConfig(
db=DB.Milvus,
db_config=DB.Milvus.config(),
db_case_config=DB.Milvus.case_config_cls(index=IndexType.Flat)(),
case_config=CaseConfig(case_id=CaseType.PerformanceSZero),
db_config=DB.Milvus.config_cls(),
db_case_config=DB.Milvus.case_config_cls(index_type=IndexType.Flat)(),
case_config=CaseConfig(case_id=CaseType.Performance768D1M),
)

runner.run([task_config])
Expand All @@ -46,9 +47,9 @@ def test_performance_case_clean(self):
def test_performance_case_no_error(self):
task_config=TaskConfig(
db=DB.ZillizCloud,
db_config=DB.ZillizCloud.config(uri="xxx", user="abc", password="1234"),
db_config=DB.ZillizCloud.config_cls(uri="xxx", user="abc", password="1234"),
db_case_config=DB.ZillizCloud.case_config_cls()(),
case_config=CaseConfig(case_id=CaseType.PerformanceSZero),
case_config=CaseConfig(case_id=CaseType.Performance768D1M),
)

t = task_config.copy()
Expand Down
16 changes: 8 additions & 8 deletions tests/test_elasticsearch_cloud.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import pytest
import logging
from vectordb_bench.models import (
DB,
MetricType,
ElasticsearchConfig,
)
from vectordb_bench.models import DB
from vectordb_bench.backend.clients.elastic_cloud.config import ElasticCloudConfig
from vectordb_bench.backend.clients import MetricType
import numpy as np


Expand All @@ -13,10 +12,11 @@
password = ""


class TestModels:
class TestElasticsearchCloud:
@pytest.mark.skip(reason="Needs elastic cloud credentials")
def test_insert_and_search(self):
assert DB.ElasticCloud.value == "Elasticsearch"
assert DB.ElasticCloud.config == ElasticsearchConfig
assert DB.ElasticCloud.value == "ElasticCloud"
assert DB.ElasticCloud.config_cls == ElasticCloudConfig

dbcls = DB.ElasticCloud.init_cls
dbConfig = DB.ElasticCloud.config_cls(cloud_id=cloud_id, password=password)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def test_test_result(self):
result = CaseResult(
task_config=TaskConfig(
db=DB.Milvus,
db_config=DB.Milvus.config(),
db_case_config=DB.Milvus.case_config_cls(index=IndexType.Flat)(),
db_config=DB.Milvus.config_cls(),
db_case_config=DB.Milvus.case_config_cls(index_type=IndexType.Flat)(),
case_config=CaseConfig(case_id=CaseType.Performance10M),
),
metrics=Metric(),
Expand Down
23 changes: 18 additions & 5 deletions vectordb_bench/backend/clients/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from contextlib import contextmanager
from enum import Enum

from pydantic import BaseModel, SecretStr, validator
from pydantic import BaseModel, SecretStr, field_validator

from vectordb_bench.backend.filter import Filter, FilterOp

Expand Down Expand Up @@ -88,11 +88,24 @@ def common_long_configs() -> list[str]:
def to_dict(self) -> dict:
raise NotImplementedError

@validator("*")
def not_empty_field(cls, v: any, field: any):
if field.name in cls.common_short_configs() or field.name in cls.common_long_configs():
@field_validator("*", mode="before")
def not_empty_field(cls, v: any, info): # noqa: ANN001
# Allow empty for known short/long config fields
field_name = getattr(info, "field_name", None)
if field_name in cls.common_short_configs() or field_name in cls.common_long_configs():
return v
if not v and isinstance(v, str | SecretStr):

# For strings and SecretStr, reject empty values
try:
# If it's a SecretStr, check the underlying value
if isinstance(v, SecretStr):
if v.get_secret_value() == "":
raise ValueError("Empty string!")
return v
except Exception: # pragma: no cover - defensive
pass

if isinstance(v, str) and v == "":
raise ValueError("Empty string!")
return v

Expand Down
6 changes: 3 additions & 3 deletions vectordb_bench/backend/clients/doris/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

from pydantic import BaseModel, SecretStr, validator
from pydantic import BaseModel, SecretStr, field_validator

from ..api import DBCaseConfig, DBConfig, MetricType

Expand All @@ -17,8 +17,8 @@ class DorisConfig(DBConfig):
db_name: str = "test"
ssl: bool = False

@validator("*")
def not_empty_field(cls, v: any, field: any):
@field_validator("*", mode="before")
def not_empty_field(cls, v: any): # noqa: ANN001
return v

def to_dict(self) -> dict:
Expand Down
6 changes: 3 additions & 3 deletions vectordb_bench/backend/clients/mariadb/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ def parse_metric(self) -> str:


class MariaDBHNSWConfig(MariaDBIndexConfig, DBCaseConfig):
M: int | None
ef_search: int | None
M: int | None = None
ef_search: int | None = None
index: IndexType = IndexType.HNSW
storage_engine: str = "InnoDB"
max_cache_size: int | None
max_cache_size: int | None = None

def index_param(self) -> dict:
return {
Expand Down
18 changes: 6 additions & 12 deletions vectordb_bench/backend/clients/mariadb/mariadb.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,10 @@ def _create_connection(**kwargs) -> tuple[mariadb.Connection, mariadb.Cursor]:
def _drop_db(self):
assert self.conn is not None, "Connection is not initialized"
assert self.cursor is not None, "Cursor is not initialized"
log.info(f"{self.name} client drop db : {self.db_name}")

# flush tables before dropping database to avoid some locking issue
self.cursor.execute("FLUSH TABLES")
self.cursor.execute(f"DROP DATABASE IF EXISTS {self.db_name}")
log.info(f"{self.name} client drop table : {self.table_name}")
self.cursor.execute(f"USE {self.db_name}")
self.cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}")
self.cursor.execute("COMMIT")
self.cursor.execute("FLUSH TABLES")

def _create_db_table(self, dim: int):
assert self.conn is not None, "Connection is not initialized"
Expand All @@ -67,9 +64,6 @@ def _create_db_table(self, dim: int):
index_param = self.case_config.index_param()

try:
log.info(f"{self.name} client create database : {self.db_name}")
self.cursor.execute(f"CREATE DATABASE {self.db_name}")

log.info(f"{self.name} client create table : {self.table_name}")
self.cursor.execute(f"USE {self.db_name}")

Expand Down Expand Up @@ -112,7 +106,7 @@ def init(self):

self.insert_sql = f"INSERT INTO {self.db_name}.{self.table_name} (id, v) VALUES (%s, %s)" # noqa: S608
self.select_sql = (
f"SELECT id FROM {self.db_name}.{self.table_name}" # noqa: S608
f"SELECT id FROM {self.db_name}.{self.table_name} " # noqa: S608
f"ORDER by vec_distance_{search_param['metric_type']}(v, %s) LIMIT %d"
)
self.select_sql_with_filter = (
Expand All @@ -131,7 +125,7 @@ def init(self):
def ready_to_load(self) -> bool:
pass

def optimize(self) -> None:
def optimize(self, data_size: int | None = None) -> None:
assert self.conn is not None, "Connection is not initialized"
assert self.cursor is not None, "Cursor is not initialized"

Expand Down Expand Up @@ -180,7 +174,7 @@ def insert_embeddings(

self.cursor.executemany(self.insert_sql, batch_data)
self.cursor.execute("COMMIT")
self.cursor.execute("FLUSH TABLES")
# self.cursor.execute("FLUSH TABLES")

return len(metadata), None
except Exception as e:
Expand Down
20 changes: 13 additions & 7 deletions vectordb_bench/backend/clients/milvus/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pydantic import BaseModel, SecretStr, validator
from pydantic import BaseModel, SecretStr, field_validator

from ..api import DBCaseConfig, DBConfig, IndexType, MetricType, SQType

Expand All @@ -19,15 +19,21 @@ def to_dict(self) -> dict:
"replica_number": self.replica_number,
}

@validator("*")
def not_empty_field(cls, v: any, field: any):
@field_validator("*", mode="before")
def not_empty_field(cls, v: any, info): # noqa: ANN001
field_name = getattr(info, "field_name", None)
if (
field.name in cls.common_short_configs()
or field.name in cls.common_long_configs()
or field.name in ["user", "password"]
field_name in cls.common_short_configs()
or field_name in cls.common_long_configs()
or field_name in ["user", "password"]
):
return v
if isinstance(v, str | SecretStr) and len(v) == 0:
# SecretStr empty check
if isinstance(v, SecretStr):
if v.get_secret_value() == "":
raise ValueError("Empty string!")
return v
if isinstance(v, str) and v == "":
raise ValueError("Empty string!")
return v

Expand Down
15 changes: 14 additions & 1 deletion vectordb_bench/backend/clients/oceanbase/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import TypedDict

from pydantic import BaseModel, SecretStr
from pydantic import BaseModel, SecretStr, field_validator

from ..api import DBCaseConfig, DBConfig, IndexType, MetricType

Expand Down Expand Up @@ -31,6 +31,19 @@ def to_dict(self) -> OceanBaseConfigDict:
"database": self.database,
}

@field_validator("*", mode="before")
def not_empty_field(cls, v: any, info): # noqa: ANN001
field_name = getattr(info, "field_name", None)
if field_name in ["password", "host", "db_label"]:
return v
if isinstance(v, SecretStr):
if v.get_secret_value() == "":
raise ValueError("Empty string!")
return v
if isinstance(v, str) and v == "":
raise ValueError("Empty string!")
return v


class OceanBaseIndexConfig(BaseModel):
index: IndexType
Expand Down
35 changes: 20 additions & 15 deletions vectordb_bench/backend/clients/oss_opensearch/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from enum import Enum

from pydantic import BaseModel, SecretStr, root_validator, validator
from pydantic import BaseModel, SecretStr, field_validator, model_validator

from ..api import DBCaseConfig, DBConfig, MetricType

Expand Down Expand Up @@ -32,15 +32,20 @@ def to_dict(self) -> dict:
"timeout": 600,
}

@validator("*")
def not_empty_field(cls, v: any, field: any):
@field_validator("*", mode="before")
def not_empty_field(cls, v: any, info): # noqa: ANN001
field_name = getattr(info, "field_name", None)
if (
field.name in cls.common_short_configs()
or field.name in cls.common_long_configs()
or field.name in ["user", "password", "host"]
field_name in cls.common_short_configs()
or field_name in cls.common_long_configs()
or field_name in ["user", "password", "host"]
):
return v
if isinstance(v, str | SecretStr) and len(v) == 0:
if isinstance(v, SecretStr):
if v.get_secret_value() == "":
raise ValueError("Empty string!")
return v
if isinstance(v, str) and v == "":
raise ValueError("Empty string!")
return v

Expand Down Expand Up @@ -128,19 +133,19 @@ def validate_quantization_type(cls, value: any):

return mapping.get(value, OSSOpenSearchQuantization.NONE)

@root_validator
def validate_engine_name(cls, values: dict):
@model_validator(mode="after")
def validate_engine_name(self): # noqa: D401
"""Map engine_name string from UI to engine enum"""
if values.get("engine_name"):
engine_name = values["engine_name"].lower()
if self.engine_name:
engine_name = self.engine_name.lower()
if engine_name == "faiss":
values["engine"] = OSSOS_Engine.faiss
self.engine = OSSOS_Engine.faiss
elif engine_name == "lucene":
values["engine"] = OSSOS_Engine.lucene
self.engine = OSSOS_Engine.lucene
else:
log.warning(f"Unknown engine_name: {engine_name}, defaulting to faiss")
values["engine"] = OSSOS_Engine.faiss
return values
self.engine = OSSOS_Engine.faiss
return self

def __eq__(self, obj: any):
return (
Expand Down
Loading
Loading