diff --git a/install/requirements_py3.11.txt b/install/requirements_py3.11.txt index c6b5d1cc3..5a7e969e9 100644 --- a/install/requirements_py3.11.txt +++ b/install/requirements_py3.11.txt @@ -1,5 +1,5 @@ -grpcio==1.53.2 -grpcio-tools==1.53.0 +grpcio>=1.53.2 +grpcio-tools>=1.53.0 qdrant-client pinecone-client weaviate-client @@ -19,11 +19,13 @@ psutil polars plotly environs -pydantic=v2 scikit-learn pymilvus clickhouse_connect pyvespa mysql-connector-python packaging -hdrhistogram>=0.10.1 \ No newline at end of file +ujson +numpy +hdrhistogram>=0.10.1 diff --git a/pyproject.toml b/pyproject.toml index ef4792bee..b7cfc6c1a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ dependencies = [ "click", "pytz", "streamlit-autorefresh", - "streamlit<1.44,!=1.34.0", # There is a breaking change in 1.44 related to get_page https://discuss.streamlit.io/t/from-streamlit-source-util-import-get-pages-gone-in-v-1-44-0-need-urgent-help/98399 + "streamlit>=1.52.2", # There is a breaking change in 1.44 related to get_page https://discuss.streamlit.io/t/from-streamlit-source-util-import-get-pages-gone-in-v-1-44-0-need-urgent-help/98399 "streamlit_extras", "tqdm", "s3fs", @@ -37,7 +37,7 @@ dependencies = [ "polars", "plotly", "environs", - "pydantic=2", "scikit-learn", "pymilvus", # with pandas, numpy, ujson "ujson", @@ -54,11 +54,11 @@ test = [ restful = [ "flask" ] all = [ - "grpcio==1.53.0", # for qdrant-client and pymilvus - "grpcio-tools==1.53.0", # for qdrant-client and pymilvus + "grpcio>=1.53.0", # for qdrant-client, pymilvus and weaviate + "grpcio-tools>=1.53.0", # for qdrant-client, pymilvus and weaviate "qdrant-client", "pinecone-client", - "weaviate-client", + "weaviate-client>=4.18.3", "elasticsearch", "sqlalchemy", "redis", diff --git a/tests/test_bench_runner.py b/tests/test_bench_runner.py index 5fab91067..f06fbe2a5 100644 --- a/tests/test_bench_runner.py +++ b/tests/test_bench_runner.py @@ -2,8 +2,9 @@ import logging from vectordb_bench.interface import BenchMarkRunner from vectordb_bench.models import ( - DB, IndexType, CaseType, TaskConfig, CaseConfig, + DB, CaseType, TaskConfig, CaseConfig, ) +from vectordb_bench.backend.clients.api import IndexType log = logging.getLogger(__name__) @@ -19,9 +20,9 @@ def test_performance_case_whole(self): task_config=TaskConfig( db=DB.Milvus, - db_config=DB.Milvus.config(), - db_case_config=DB.Milvus.case_config_cls(index=IndexType.Flat)(), - case_config=CaseConfig(case_id=CaseType.PerformanceSZero), + db_config=DB.Milvus.config_cls(), + db_case_config=DB.Milvus.case_config_cls(index_type=IndexType.Flat)(), + case_config=CaseConfig(case_id=CaseType.Performance768D1M), ) runner.run([task_config]) @@ -34,9 +35,9 @@ def test_performance_case_clean(self): task_config=TaskConfig( db=DB.Milvus, - db_config=DB.Milvus.config(), - db_case_config=DB.Milvus.case_config_cls(index=IndexType.Flat)(), - case_config=CaseConfig(case_id=CaseType.PerformanceSZero), + db_config=DB.Milvus.config_cls(), + db_case_config=DB.Milvus.case_config_cls(index_type=IndexType.Flat)(), + case_config=CaseConfig(case_id=CaseType.Performance768D1M), ) runner.run([task_config]) @@ -46,9 +47,9 @@ def test_performance_case_clean(self): def test_performance_case_no_error(self): task_config=TaskConfig( db=DB.ZillizCloud, - db_config=DB.ZillizCloud.config(uri="xxx", user="abc", password="1234"), + db_config=DB.ZillizCloud.config_cls(uri="xxx", user="abc", password="1234"), db_case_config=DB.ZillizCloud.case_config_cls()(), - case_config=CaseConfig(case_id=CaseType.PerformanceSZero), + case_config=CaseConfig(case_id=CaseType.Performance768D1M), ) t = task_config.copy() diff --git a/tests/test_elasticsearch_cloud.py b/tests/test_elasticsearch_cloud.py index f161ab6c2..919a96e8c 100644 --- a/tests/test_elasticsearch_cloud.py +++ b/tests/test_elasticsearch_cloud.py @@ -1,9 +1,8 @@ +import pytest import logging -from vectordb_bench.models import ( - DB, - MetricType, - ElasticsearchConfig, -) +from vectordb_bench.models import DB +from vectordb_bench.backend.clients.elastic_cloud.config import ElasticCloudConfig +from vectordb_bench.backend.clients import MetricType import numpy as np @@ -13,10 +12,11 @@ password = "" -class TestModels: +class TestElasticsearchCloud: + @pytest.mark.skip(reason="Needs elastic cloud credentials") def test_insert_and_search(self): - assert DB.ElasticCloud.value == "Elasticsearch" - assert DB.ElasticCloud.config == ElasticsearchConfig + assert DB.ElasticCloud.value == "ElasticCloud" + assert DB.ElasticCloud.config_cls == ElasticCloudConfig dbcls = DB.ElasticCloud.init_cls dbConfig = DB.ElasticCloud.config_cls(cloud_id=cloud_id, password=password) diff --git a/tests/test_models.py b/tests/test_models.py index d68dd6afb..0357a6201 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -22,8 +22,8 @@ def test_test_result(self): result = CaseResult( task_config=TaskConfig( db=DB.Milvus, - db_config=DB.Milvus.config(), - db_case_config=DB.Milvus.case_config_cls(index=IndexType.Flat)(), + db_config=DB.Milvus.config_cls(), + db_case_config=DB.Milvus.case_config_cls(index_type=IndexType.Flat)(), case_config=CaseConfig(case_id=CaseType.Performance10M), ), metrics=Metric(), diff --git a/vectordb_bench/backend/clients/api.py b/vectordb_bench/backend/clients/api.py index 822821b34..f9647e345 100644 --- a/vectordb_bench/backend/clients/api.py +++ b/vectordb_bench/backend/clients/api.py @@ -2,7 +2,7 @@ from contextlib import contextmanager from enum import Enum -from pydantic import BaseModel, SecretStr, validator +from pydantic import BaseModel, SecretStr, field_validator from vectordb_bench.backend.filter import Filter, FilterOp @@ -88,11 +88,24 @@ def common_long_configs() -> list[str]: def to_dict(self) -> dict: raise NotImplementedError - @validator("*") - def not_empty_field(cls, v: any, field: any): - if field.name in cls.common_short_configs() or field.name in cls.common_long_configs(): + @field_validator("*", mode="before") + def not_empty_field(cls, v: any, info): # noqa: ANN001 + # Allow empty for known short/long config fields + field_name = getattr(info, "field_name", None) + if field_name in cls.common_short_configs() or field_name in cls.common_long_configs(): return v - if not v and isinstance(v, str | SecretStr): + + # For strings and SecretStr, reject empty values + try: + # If it's a SecretStr, check the underlying value + if isinstance(v, SecretStr): + if v.get_secret_value() == "": + raise ValueError("Empty string!") + return v + except Exception: # pragma: no cover - defensive + pass + + if isinstance(v, str) and v == "": raise ValueError("Empty string!") return v diff --git a/vectordb_bench/backend/clients/doris/config.py b/vectordb_bench/backend/clients/doris/config.py index a15309922..0b5a7310c 100644 --- a/vectordb_bench/backend/clients/doris/config.py +++ b/vectordb_bench/backend/clients/doris/config.py @@ -1,6 +1,6 @@ import logging -from pydantic import BaseModel, SecretStr, validator +from pydantic import BaseModel, SecretStr, field_validator from ..api import DBCaseConfig, DBConfig, MetricType @@ -17,8 +17,8 @@ class DorisConfig(DBConfig): db_name: str = "test" ssl: bool = False - @validator("*") - def not_empty_field(cls, v: any, field: any): + @field_validator("*", mode="before") + def not_empty_field(cls, v: any): # noqa: ANN001 return v def to_dict(self) -> dict: diff --git a/vectordb_bench/backend/clients/mariadb/config.py b/vectordb_bench/backend/clients/mariadb/config.py index d183adc76..21ea9ac2e 100644 --- a/vectordb_bench/backend/clients/mariadb/config.py +++ b/vectordb_bench/backend/clients/mariadb/config.py @@ -46,11 +46,11 @@ def parse_metric(self) -> str: class MariaDBHNSWConfig(MariaDBIndexConfig, DBCaseConfig): - M: int | None - ef_search: int | None + M: int | None = None + ef_search: int | None = None index: IndexType = IndexType.HNSW storage_engine: str = "InnoDB" - max_cache_size: int | None + max_cache_size: int | None = None def index_param(self) -> dict: return { diff --git a/vectordb_bench/backend/clients/mariadb/mariadb.py b/vectordb_bench/backend/clients/mariadb/mariadb.py index 5ccddfe7a..104e4314d 100644 --- a/vectordb_bench/backend/clients/mariadb/mariadb.py +++ b/vectordb_bench/backend/clients/mariadb/mariadb.py @@ -52,13 +52,10 @@ def _create_connection(**kwargs) -> tuple[mariadb.Connection, mariadb.Cursor]: def _drop_db(self): assert self.conn is not None, "Connection is not initialized" assert self.cursor is not None, "Cursor is not initialized" - log.info(f"{self.name} client drop db : {self.db_name}") - - # flush tables before dropping database to avoid some locking issue - self.cursor.execute("FLUSH TABLES") - self.cursor.execute(f"DROP DATABASE IF EXISTS {self.db_name}") + log.info(f"{self.name} client drop table : {self.table_name}") + self.cursor.execute(f"USE {self.db_name}") + self.cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}") self.cursor.execute("COMMIT") - self.cursor.execute("FLUSH TABLES") def _create_db_table(self, dim: int): assert self.conn is not None, "Connection is not initialized" @@ -67,9 +64,6 @@ def _create_db_table(self, dim: int): index_param = self.case_config.index_param() try: - log.info(f"{self.name} client create database : {self.db_name}") - self.cursor.execute(f"CREATE DATABASE {self.db_name}") - log.info(f"{self.name} client create table : {self.table_name}") self.cursor.execute(f"USE {self.db_name}") @@ -112,7 +106,7 @@ def init(self): self.insert_sql = f"INSERT INTO {self.db_name}.{self.table_name} (id, v) VALUES (%s, %s)" # noqa: S608 self.select_sql = ( - f"SELECT id FROM {self.db_name}.{self.table_name}" # noqa: S608 + f"SELECT id FROM {self.db_name}.{self.table_name} " # noqa: S608 f"ORDER by vec_distance_{search_param['metric_type']}(v, %s) LIMIT %d" ) self.select_sql_with_filter = ( @@ -131,7 +125,7 @@ def init(self): def ready_to_load(self) -> bool: pass - def optimize(self) -> None: + def optimize(self, data_size: int | None = None) -> None: assert self.conn is not None, "Connection is not initialized" assert self.cursor is not None, "Cursor is not initialized" @@ -180,7 +174,7 @@ def insert_embeddings( self.cursor.executemany(self.insert_sql, batch_data) self.cursor.execute("COMMIT") - self.cursor.execute("FLUSH TABLES") + # self.cursor.execute("FLUSH TABLES") return len(metadata), None except Exception as e: diff --git a/vectordb_bench/backend/clients/milvus/config.py b/vectordb_bench/backend/clients/milvus/config.py index 9ffbdcece..2bf364b2d 100644 --- a/vectordb_bench/backend/clients/milvus/config.py +++ b/vectordb_bench/backend/clients/milvus/config.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, SecretStr, validator +from pydantic import BaseModel, SecretStr, field_validator from ..api import DBCaseConfig, DBConfig, IndexType, MetricType, SQType @@ -19,15 +19,21 @@ def to_dict(self) -> dict: "replica_number": self.replica_number, } - @validator("*") - def not_empty_field(cls, v: any, field: any): + @field_validator("*", mode="before") + def not_empty_field(cls, v: any, info): # noqa: ANN001 + field_name = getattr(info, "field_name", None) if ( - field.name in cls.common_short_configs() - or field.name in cls.common_long_configs() - or field.name in ["user", "password"] + field_name in cls.common_short_configs() + or field_name in cls.common_long_configs() + or field_name in ["user", "password"] ): return v - if isinstance(v, str | SecretStr) and len(v) == 0: + # SecretStr empty check + if isinstance(v, SecretStr): + if v.get_secret_value() == "": + raise ValueError("Empty string!") + return v + if isinstance(v, str) and v == "": raise ValueError("Empty string!") return v diff --git a/vectordb_bench/backend/clients/oceanbase/config.py b/vectordb_bench/backend/clients/oceanbase/config.py index 1f37cfc75..0bcf315fe 100644 --- a/vectordb_bench/backend/clients/oceanbase/config.py +++ b/vectordb_bench/backend/clients/oceanbase/config.py @@ -1,6 +1,6 @@ from typing import TypedDict -from pydantic import BaseModel, SecretStr +from pydantic import BaseModel, SecretStr, field_validator from ..api import DBCaseConfig, DBConfig, IndexType, MetricType @@ -31,6 +31,19 @@ def to_dict(self) -> OceanBaseConfigDict: "database": self.database, } + @field_validator("*", mode="before") + def not_empty_field(cls, v: any, info): # noqa: ANN001 + field_name = getattr(info, "field_name", None) + if field_name in ["password", "host", "db_label"]: + return v + if isinstance(v, SecretStr): + if v.get_secret_value() == "": + raise ValueError("Empty string!") + return v + if isinstance(v, str) and v == "": + raise ValueError("Empty string!") + return v + class OceanBaseIndexConfig(BaseModel): index: IndexType diff --git a/vectordb_bench/backend/clients/oss_opensearch/config.py b/vectordb_bench/backend/clients/oss_opensearch/config.py index 83fed3d58..5314fee29 100644 --- a/vectordb_bench/backend/clients/oss_opensearch/config.py +++ b/vectordb_bench/backend/clients/oss_opensearch/config.py @@ -1,7 +1,7 @@ import logging from enum import Enum -from pydantic import BaseModel, SecretStr, root_validator, validator +from pydantic import BaseModel, SecretStr, field_validator, model_validator from ..api import DBCaseConfig, DBConfig, MetricType @@ -32,15 +32,20 @@ def to_dict(self) -> dict: "timeout": 600, } - @validator("*") - def not_empty_field(cls, v: any, field: any): + @field_validator("*", mode="before") + def not_empty_field(cls, v: any, info): # noqa: ANN001 + field_name = getattr(info, "field_name", None) if ( - field.name in cls.common_short_configs() - or field.name in cls.common_long_configs() - or field.name in ["user", "password", "host"] + field_name in cls.common_short_configs() + or field_name in cls.common_long_configs() + or field_name in ["user", "password", "host"] ): return v - if isinstance(v, str | SecretStr) and len(v) == 0: + if isinstance(v, SecretStr): + if v.get_secret_value() == "": + raise ValueError("Empty string!") + return v + if isinstance(v, str) and v == "": raise ValueError("Empty string!") return v @@ -128,19 +133,19 @@ def validate_quantization_type(cls, value: any): return mapping.get(value, OSSOpenSearchQuantization.NONE) - @root_validator - def validate_engine_name(cls, values: dict): + @model_validator(mode="after") + def validate_engine_name(self): # noqa: D401 """Map engine_name string from UI to engine enum""" - if values.get("engine_name"): - engine_name = values["engine_name"].lower() + if self.engine_name: + engine_name = self.engine_name.lower() if engine_name == "faiss": - values["engine"] = OSSOS_Engine.faiss + self.engine = OSSOS_Engine.faiss elif engine_name == "lucene": - values["engine"] = OSSOS_Engine.lucene + self.engine = OSSOS_Engine.lucene else: log.warning(f"Unknown engine_name: {engine_name}, defaulting to faiss") - values["engine"] = OSSOS_Engine.faiss - return values + self.engine = OSSOS_Engine.faiss + return self def __eq__(self, obj: any): return ( diff --git a/vectordb_bench/backend/clients/pgvector/cli.py b/vectordb_bench/backend/clients/pgvector/cli.py index f8c138802..036f0ea6c 100644 --- a/vectordb_bench/backend/clients/pgvector/cli.py +++ b/vectordb_bench/backend/clients/pgvector/cli.py @@ -164,7 +164,26 @@ def PgVectorIVFFlat( ) -class PgVectorHNSWTypedDict(PgVectorTypedDict, HNSWFlavor1): ... +class PgVectorHNSWTypedDict(PgVectorTypedDict, HNSWFlavor1): + create_index_before_load: Annotated[ + bool | None, + click.option( + "--create-index-before-load/--no-create-index-before-load", + type=bool, + help="Create HNSW index before loading data (overrides config.yml if provided)", + required=False, + ), + ] + + create_index_after_load: Annotated[ + bool | None, + click.option( + "--create-index-after-load/--no-create-index-after-load", + type=bool, + help="Create HNSW index after loading data (overrides config.yml if provided)", + required=False, + ), + ] @cli.command() @@ -196,6 +215,8 @@ def PgVectorHNSW( reranking=parameters["reranking"], reranking_metric=parameters["reranking_metric"], quantized_fetch_limit=parameters["quantized_fetch_limit"], + create_index_before_load=parameters.get("create_index_before_load"), + create_index_after_load=parameters.get("create_index_after_load"), ), **parameters, ) diff --git a/vectordb_bench/backend/clients/pgvector/config.py b/vectordb_bench/backend/clients/pgvector/config.py index 98e82f1c2..ae6511907 100644 --- a/vectordb_bench/backend/clients/pgvector/config.py +++ b/vectordb_bench/backend/clients/pgvector/config.py @@ -175,13 +175,13 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig): a good place to start is sqrt(lists) """ - lists: int | None - probes: int | None + lists: int | None = None + probes: int | None = None index: IndexType = IndexType.ES_IVFFlat maintenance_work_mem: str | None = None max_parallel_workers: int | None = None quantization_type: str | None = None - table_quantization_type: str | None + table_quantization_type: str | None = None reranking: bool | None = None quantized_fetch_limit: int | None = None reranking_metric: str | None = None @@ -224,14 +224,14 @@ class PgVectorHNSWConfig(PgVectorIndexConfig): created without any data in the table since there isn't a training step like IVFFlat. """ - m: int | None # DETAIL: Valid values are between "2" and "100". - ef_construction: int | None # ef_construction must be greater than or equal to 2 * m - ef_search: int | None + m: int | None = None # DETAIL: Valid values are between "2" and "100". + ef_construction: int | None = None # ef_construction must be greater than or equal to 2 * m + ef_search: int | None = None index: IndexType = IndexType.ES_HNSW maintenance_work_mem: str | None = None max_parallel_workers: int | None = None quantization_type: str | None = None - table_quantization_type: str | None + table_quantization_type: str | None = None reranking: bool | None = None quantized_fetch_limit: int | None = None reranking_metric: str | None = None diff --git a/vectordb_bench/backend/clients/qdrant_cloud/config.py b/vectordb_bench/backend/clients/qdrant_cloud/config.py index b2eeb2ce6..8bfcc8aed 100644 --- a/vectordb_bench/backend/clients/qdrant_cloud/config.py +++ b/vectordb_bench/backend/clients/qdrant_cloud/config.py @@ -1,6 +1,6 @@ from typing import TypeVar -from pydantic import BaseModel, SecretStr, validator +from pydantic import BaseModel, SecretStr, field_validator from ..api import DBCaseConfig, DBConfig, MetricType @@ -25,11 +25,11 @@ def to_dict(self) -> dict: "url": self.url.get_secret_value(), } - @validator("*") - def not_empty_field(cls, v: any, field: any): - if field.name in ["api_key"]: + @field_validator("*", mode="before") + def not_empty_field(cls, v: any, info): # noqa: ANN001 + if getattr(info, "field_name", None) in ["api_key"]: return v - return super().not_empty_field(v, field) + return super().not_empty_field(v, info) class QdrantIndexConfig(BaseModel, DBCaseConfig): diff --git a/vectordb_bench/backend/clients/tidb/config.py b/vectordb_bench/backend/clients/tidb/config.py index 71fdbad66..5d0f7de8a 100644 --- a/vectordb_bench/backend/clients/tidb/config.py +++ b/vectordb_bench/backend/clients/tidb/config.py @@ -1,6 +1,6 @@ from typing import TypedDict -from pydantic import BaseModel, SecretStr, validator +from pydantic import BaseModel, SecretStr, field_validator from ..api import DBCaseConfig, DBConfig, MetricType @@ -35,11 +35,16 @@ def to_dict(self) -> TiDBConfigDict: "ssl_verify_identity": self.ssl, } - @validator("*") - def not_empty_field(cls, v: any, field: any): - if field.name in ["password", "db_label"]: + @field_validator("*", mode="before") + def not_empty_field(cls, v: any, info): # noqa: ANN001 + field_name = getattr(info, "field_name", None) + if field_name in ["password", "db_label"]: return v - if isinstance(v, str | SecretStr) and len(v) == 0: + if isinstance(v, SecretStr): + if v.get_secret_value() == "": + raise ValueError("Empty string!") + return v + if isinstance(v, str) and v == "": raise ValueError("Empty string!") return v diff --git a/vectordb_bench/backend/clients/weaviate_cloud/cli.py b/vectordb_bench/backend/clients/weaviate_cloud/cli.py index cba3c2377..46b54c995 100644 --- a/vectordb_bench/backend/clients/weaviate_cloud/cli.py +++ b/vectordb_bench/backend/clients/weaviate_cloud/cli.py @@ -19,7 +19,19 @@ class WeaviateTypedDict(CommonTypedDict): ] url: Annotated[ str, - click.option("--url", type=str, help="Weaviate url", required=True), + # HTTP endpoint for Weaviate (required). Do not pass the gRPC port here. + click.option("--url", type=str, help="Weaviate HTTP url (e.g. http://localhost:8080)", required=True), + ] + grpc_url: Annotated[ + str, + # Optional: allow providing a gRPC address separately; it is currently not used by the Python client. + click.option( + "--grpc-url", + type=str, + required=False, + default="", + help="Optional Weaviate gRPC address (e.g. localhost:50051). Not used by this runner.", + ), ] no_auth: Annotated[ bool, @@ -49,13 +61,19 @@ class WeaviateTypedDict(CommonTypedDict): def Weaviate(**parameters: Unpack[WeaviateTypedDict]): from .config import WeaviateConfig, WeaviateIndexConfig + # Guard: ensure the HTTP URL includes a scheme to avoid requests InvalidSchema errors + http_url = parameters["url"] + if not (http_url.startswith("http://") or http_url.startswith("https://")): + http_url = f"http://{http_url}" + run( db=DB.WeaviateCloud, db_config=WeaviateConfig( db_label=parameters["db_label"], api_key=SecretStr(parameters["api_key"]) if parameters["api_key"] != "" else SecretStr("-"), - url=SecretStr(parameters["url"]), + url=SecretStr(http_url), no_auth=parameters["no_auth"], + grpc_url=SecretStr(parameters["grpc_url"]) if parameters.get("grpc_url") else None, ), db_case_config=WeaviateIndexConfig( efConstruction=parameters["ef_construction"], diff --git a/vectordb_bench/backend/clients/weaviate_cloud/config.py b/vectordb_bench/backend/clients/weaviate_cloud/config.py index f29a307a3..7a18b4ff3 100644 --- a/vectordb_bench/backend/clients/weaviate_cloud/config.py +++ b/vectordb_bench/backend/clients/weaviate_cloud/config.py @@ -7,14 +7,45 @@ class WeaviateConfig(DBConfig): url: SecretStr api_key: SecretStr no_auth: bool | None = False + # optional gRPC endpoint like "localhost:50051" + grpc_url: SecretStr | None = None + # Backward-compat method used by older code paths; keep for compatibility def to_dict(self) -> dict: return { "url": self.url.get_secret_value(), "auth_client_secret": self.api_key.get_secret_value(), "no_auth": self.no_auth, + "grpc_url": self.grpc_url.get_secret_value() if self.grpc_url else None, } + # Helpers for v4 client wiring + def host_port(self) -> tuple[str, int]: + from urllib.parse import urlparse + + url = self.url.get_secret_value() + u = urlparse(url) + host = u.hostname or "localhost" + if u.port: + port = u.port + else: + port = 443 if (u.scheme or "http").lower() == "https" else 80 + return host, port + + def grpc_host_port(self) -> tuple[str, int] | None: + if not self.grpc_url: + return None + value = self.grpc_url.get_secret_value() + if not value: + return None + if ":" in value: + host, port_str = value.split(":", 1) + try: + return host, int(port_str) + except ValueError: + return host, 50051 + return value, 50051 + class WeaviateIndexConfig(BaseModel, DBCaseConfig): metric_type: MetricType | None = None diff --git a/vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py b/vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py index d6111c8da..746cc8db1 100644 --- a/vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +++ b/vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py @@ -6,8 +6,11 @@ import weaviate from weaviate.exceptions import WeaviateBaseError +from weaviate.classes.config import Configure, DataType, Property, VectorDistances, Reconfigure from ..api import DBCaseConfig, VectorDB +from .config import WeaviateConfig +from pydantic import SecretStr log = logging.getLogger(__name__) @@ -16,21 +19,15 @@ class WeaviateCloud(VectorDB): def __init__( self, dim: int, - db_config: dict, + db_config, db_case_config: DBCaseConfig, collection_name: str = "VectorDBBenchCollection", drop_old: bool = False, **kwargs, ): - """Initialize wrapper around the weaviate vector database.""" - db_config.update( - { - "auth_client_secret": weaviate.AuthApiKey( - api_key=db_config.get("auth_client_secret"), - ), - }, - ) - self.db_config = db_config + """Initialize wrapper around the Weaviate vector database (v4 client).""" + # Normalize config to WeaviateConfig model (accept dict for backward compatibility) + self.cfg = self._ensure_cfg(db_config) self.case_config = db_case_config self.collection_name = collection_name @@ -38,24 +35,26 @@ def __init__( self._vector_field = "vector" self._index_name = "vector_idx" - # If local setup is used, we - if db_config["no_auth"]: - del db_config["auth_client_secret"] - del db_config["no_auth"] - - from weaviate import Client - - client = Client(**db_config) - if drop_old: - try: - if client.schema.exists(self.collection_name): - log.info(f"weaviate client drop_old collection: {self.collection_name}") - client.schema.delete_class(self.collection_name) - except WeaviateBaseError as e: - log.warning(f"Failed to drop collection: {self.collection_name} error: {e!s}") - raise e from None - self._create_collection(client) - client = None + # Open a short-lived admin connection to ensure collection exists + http_host, http_port = self.cfg.host_port() + grpc_pair = getattr(self.cfg, "grpc_host_port", None) + grpc_host, grpc_port = self.cfg.grpc_host_port() if grpc_pair else (None, None) + + client = weaviate.connect_to_custom( + http_host=http_host, + http_port=http_port, + grpc_host=grpc_host, + grpc_port=grpc_port, + http_secure=False, + grpc_secure=False, + ) + try: + if drop_old and client.collections.exists(self.collection_name): + log.info(f"weaviate client drop_old collection: {self.collection_name}") + client.collections.delete(self.collection_name) + self._create_collection(client) + finally: + client.close() @contextmanager def init(self) -> None: @@ -65,67 +64,138 @@ def init(self) -> None: >>> self.insert_embeddings() >>> self.search_embedding() """ - from weaviate import Client + http_host, http_port = self.cfg.host_port() + grpc_pair = getattr(self.cfg, "grpc_host_port", None) + grpc_host, grpc_port = self.cfg.grpc_host_port() if grpc_pair else (None, None) + + self.client = weaviate.connect_to_custom( + http_host=http_host, + http_port=http_port, + grpc_host=grpc_host, + grpc_port=grpc_port, + http_secure=False, + grpc_secure=False, + ) + try: + yield + finally: + self.client.close() + self.client = None - self.client = Client(**self.db_config) - yield - self.client = None - del self.client + @staticmethod + def _ensure_cfg(db_config) -> WeaviateConfig: + """Accept either a WeaviateConfig instance or a plain dict (legacy path). + + When a dict is provided, reconstruct WeaviateConfig and normalize fields. + """ + if isinstance(db_config, WeaviateConfig): + return db_config + if isinstance(db_config, dict): + # Support both keys: 'api_key' and legacy 'auth_client_secret' + api_key_val = db_config.get("api_key") or db_config.get("auth_client_secret") or "-" + url_val = db_config.get("url", "") + grpc_val = db_config.get("grpc_url") + no_auth_val = bool(db_config.get("no_auth", False)) + + # Ensure HTTP scheme for url + if isinstance(url_val, str) and not (url_val.startswith("http://") or url_val.startswith("https://")): + url_val = f"http://{url_val}" if url_val else "http://localhost:8080" + + return WeaviateConfig( + db_label=db_config.get("db_label", "weaviate"), + api_key=SecretStr(str(api_key_val) if api_key_val is not None else "-"), + url=SecretStr(url_val), + no_auth=no_auth_val, + grpc_url=SecretStr(str(grpc_val)) if grpc_val else None, + ) + # Unexpected type; raise for clarity + raise TypeError(f"Unsupported db_config type for WeaviateCloud: {type(db_config)}") def optimize(self, data_size: int | None = None): - assert self.client.schema.exists(self.collection_name) - self.client.schema.update_config( - self.collection_name, - {"vectorIndexConfig": self.case_config.search_param()}, - ) + col = self.client.collections.get(self.collection_name) + # Update search ef when provided using v4 Configure helper + try: + ef_val = self.case_config.search_param().get("ef") + except Exception: + ef_val = None + if ef_val is not None: + col.config.update( + vector_config=Reconfigure.Vectors.update( + name="default", + vector_index_config=Reconfigure.VectorIndex.hnsw( + ef=ef_val, + ), + ), + ) - def _create_collection(self, client: weaviate.Client) -> None: - if not client.schema.exists(self.collection_name): + def _create_collection(self, client) -> None: + if not client.collections.exists(self.collection_name): log.info(f"Create collection: {self.collection_name}") - class_obj = { - "class": self.collection_name, - "vectorizer": "none", - "properties": [ - { - "dataType": ["int"], - "name": self._scalar_field, - }, - ], - } - class_obj["vectorIndexConfig"] = self.case_config.index_param() + + # Map metric to Weaviate v4 distance enum try: - client.schema.create_class(class_obj) - except WeaviateBaseError as e: - log.warning(f"Failed to create collection: {self.collection_name} error: {e!s}") - raise e from None + metric = self.case_config.parse_metric() + except Exception: + metric = "cosine" + distance_enum = { + "cosine": VectorDistances.COSINE, + "dot": VectorDistances.DOT, + "l2-squared": VectorDistances.L2_SQUARED, + }.get(metric, VectorDistances.COSINE) + + # Optional HNSW params + ef_construction = getattr(self.case_config, "efConstruction", None) + max_connections = getattr(self.case_config, "maxConnections", None) + + vector_index_cfg = Configure.VectorIndex.hnsw( + distance_metric=distance_enum, + **({"ef_construction": ef_construction} if ef_construction is not None else {}), + **({"max_connections": max_connections} if max_connections is not None else {}), + ) + + # Build v4.18.3-compliant vector_config using helper for self-provided vectors + vector_config = Configure.Vectors.self_provided(vector_index_config=vector_index_cfg) + + client.collections.create( + name=self.collection_name, + properties=[Property(name=self._scalar_field, data_type=DataType.INT)], + vector_config=vector_config, + ) def insert_embeddings( self, embeddings: Iterable[list[float]], metadata: list[int], **kwargs, - ) -> tuple[int, Exception]: + ) -> tuple[int, Exception | None]: """Insert embeddings into Weaviate""" - assert self.client.schema.exists(self.collection_name) - insert_count = 0 + col = self.client.collections.get(self.collection_name) try: - with self.client.batch as batch: - batch.batch_size = len(metadata) - batch.dynamic = True - res = [] - for i in range(len(metadata)): - res.append( - batch.add_data_object( - {self._scalar_field: metadata[i]}, - class_name=self.collection_name, - vector=embeddings[i], - ), - ) - insert_count += 1 - return (len(res), None) + # v4.18 expects vectors to be supplied separately from properties. + props_list = [{self._scalar_field: metadata[i]} for i in range(len(metadata))] + vecs_list = [embeddings[i] for i in range(len(metadata))] + + # Stream with fixed-size batches to avoid large single RPCs and reduce backpressure + batch_size = kwargs.get("_weaviate_batch_size", 1000) + concurrent_requests = kwargs.get("_weaviate_concurrency", 2) + + inserted = 0 + idx = 0 + total = len(props_list) + while idx < total: + end = min(idx + batch_size, total) + with col.batch.fixed_size(batch_size=(end - idx), concurrent_requests=concurrent_requests) as batch: + for i in range(idx, end): + batch.add_object(properties=props_list[i], vector=vecs_list[i]) + inserted += 1 + # Lightweight progress log every 50k rows + if inserted % 50000 == 0: + log.info(f"Weaviate inserted {inserted}/{total} objects...") + idx = end + return (inserted, None) except WeaviateBaseError as e: log.warning(f"Failed to insert data, error: {e!s}") - return (insert_count, e) + return (0, e) def search_embedding( self, @@ -134,27 +204,19 @@ def search_embedding( filters: dict | None = None, timeout: int | None = None, ) -> list[int]: - """Perform a search on a query embedding and return results with distance. + """Perform a search on a query embedding and return result keys. Should call self.init() first. """ - assert self.client.schema.exists(self.collection_name) + col = self.client.collections.get(self.collection_name) - query_obj = ( - self.client.query.get(self.collection_name, [self._scalar_field]) - .with_additional("distance") - .with_near_vector({"vector": query}) - .with_limit(k) - ) - if filters: - where_filter = { - "path": "key", + where = None + if filters and "id" in filters: + where = { "operator": "GreaterThanEqual", - "valueInt": filters.get("id"), + "path": [self._scalar_field], + "valueInt": int(filters["id"]), } - query_obj = query_obj.with_where(where_filter) - - # Perform the search. - res = query_obj.do() - # Organize results. - return [result[self._scalar_field] for result in res["data"]["Get"][self.collection_name]] + # weaviate-client v4.18.3: pass the query vector positionally + res = col.query.near_vector(query, limit=k, filters=where) + return [obj.properties.get(self._scalar_field) for obj in res.objects] diff --git a/vectordb_bench/backend/dataset.py b/vectordb_bench/backend/dataset.py index 552b73417..c2de9a3b6 100644 --- a/vectordb_bench/backend/dataset.py +++ b/vectordb_bench/backend/dataset.py @@ -12,7 +12,7 @@ import pandas as pd import polars as pl from pyarrow.parquet import ParquetFile -from pydantic import PrivateAttr, validator +from pydantic import PrivateAttr, field_validator from vectordb_bench import config from vectordb_bench.base import BaseModel @@ -57,10 +57,33 @@ class BaseDataset(BaseModel): gt_id_field: str = "id" gt_neighbors_field: str = "neighbors_id" - @validator("size") + @field_validator("size", mode="before") def verify_size(cls, v: int): - if v not in cls._size_label: - msg = f"Size {v} not supported for the dataset, expected: {cls._size_label.keys()}" + # In Pydantic v2, accessing a PrivateAttr on the class returns a ModelPrivateAttr, + # not the default dict. Resolve the actual mapping safely. + def _get_size_label_map() -> dict[int, SizeLabel]: + # If subclasses override `_size_label` with a dict directly, use it. + direct = getattr(cls, "_size_label", None) + if isinstance(direct, dict): + return direct + + # Otherwise, fetch the PrivateAttr default/default_factory. + pa = getattr(cls, "__private_attributes__", {}).get("_size_label") + if pa is not None: + default = getattr(pa, "default", None) + if isinstance(default, dict): + return default + default_factory = getattr(pa, "default_factory", None) + if callable(default_factory): + try: + return default_factory() + except Exception: # pragma: no cover — defensive + return {} + return {} + + size_map = _get_size_label_map() + if v not in size_map: + msg = f"Size {v} not supported for the dataset, expected: {list(size_map.keys())}" raise ValueError(msg) return v @@ -102,7 +125,7 @@ class CustomDataset(BaseDataset): scalar_labels_file: str = "scalar_labels.parquet" label_percentages: list[float] = [] - @validator("size") + @field_validator("size", mode="before") def verify_size(cls, v: int): return v diff --git a/vectordb_bench/backend/task_runner.py b/vectordb_bench/backend/task_runner.py index 8224a0415..35bde0de4 100644 --- a/vectordb_bench/backend/task_runner.py +++ b/vectordb_bench/backend/task_runner.py @@ -73,7 +73,7 @@ def __hash__(self) -> int: ) def display(self) -> dict: - c_dict = self.ca.dict( + c_dict = self.ca.model_dump( include={ "label": True, "name": True, @@ -122,9 +122,13 @@ def init_db(self, drop_old: bool = True) -> None: if "collection_name" in db_config_dict and not collection_name: collection_name = db_config_dict.pop("collection_name") + # Minimal fix for Weaviate v4 path: pass the Pydantic model instance directly, + # because the new client expects helper methods like host_port()/grpc_host_port(). + db_init_cfg = self.config.db_config if self.config.db == DB.WeaviateCloud else db_config_dict + self.db = db_cls( dim=self.ca.dataset.data.dim, - db_config=db_config_dict, + db_config=db_init_cfg, db_case_config=self.config.db_case_config, drop_old=drop_old, with_scalar_labels=self.ca.with_scalar_labels, diff --git a/vectordb_bench/cli/batch_cli.py b/vectordb_bench/cli/batch_cli.py index 5ac2b1cf1..71baf9e81 100644 --- a/vectordb_bench/cli/batch_cli.py +++ b/vectordb_bench/cli/batch_cli.py @@ -56,32 +56,67 @@ def build_sub_cmd_args(batch_config: MutableMapping[str, Any] | None): "dry_run": False, "custom_dataset_use_shuffled": True, "custom_dataset_with_gt": True, + # weaviate: --no-auth + "no_auth": False, } def format_option(key: str, value: Any): opt_name = key.replace("_", "-") - if key in bool_options: - return format_bool_option(opt_name, value, skip=False) + # Known boolean flags that have explicit negative counterparts + neg_flag_map: dict[str, tuple[str, str]] = { + # General stages + "drop_old": ("drop-old", "skip-drop-old"), + "load": ("load", "skip-load"), + "search_serial": ("search-serial", "skip-search-serial"), + "search_concurrent": ("search-concurrent", "skip-search-concurrent"), + # PgVector + "reranking": ("reranking", "skip-reranking"), + "create_index_before_load": ("create-index-before-load", "no-create-index-before-load"), + "create_index_after_load": ("create-index-after-load", "no-create-index-after-load"), + } + + # Special-case: boolean flags + if isinstance(value, bool): + # weaviate no_auth behaves as a simple positive flag (no negative counterpart) + if key == "no_auth": + return [f"--{opt_name}"] if value else [] + + if key in neg_flag_map: + pos, neg = neg_flag_map[key] + return [f"--{pos}"] if value else [f"--{neg}"] + + # Fallback: for known stage booleans handled above, or simple on/off flags without negative + if key in bool_options: + return format_bool_option(opt_name, value, skip=False) + + # Default fallback: True -> --flag, False -> omit + return [f"--{opt_name}"] if value else [] if key.startswith("skip_"): raw_key = key[5:] raw_opt = raw_key.replace("_", "-") return format_bool_option(raw_opt, value, skip=True, raw_key=raw_key) + # Non-boolean: pass as --key value return [f"--{opt_name}", str(value)] def format_bool_option(opt_name: str, value: Any, skip: bool = False, raw_key: str | None = None): + # Helper kept for backward compatibility with existing stage flags if isinstance(value, bool): if skip: if bool_options.get(raw_key, False): + # When skip_ is provided and the raw_key is a known stage flag, + # emit the proper --skip- or its positive counterpart without values return [f"--skip-{opt_name}"] if value else [f"--{opt_name}"] - return [f"--{opt_name}", str(value)] + # Unknown skip_ keys: do not append literal True/False + return [f"--skip-{opt_name}"] if value else [] if value: return [f"--{opt_name}"] if bool_options.get(opt_name.replace("-", "_"), False): return [f"--skip-{opt_name}"] return [] + # Non-boolean falls back to standard formatting elsewhere return [f"--{opt_name}", str(value)] args_arr = [] diff --git a/vectordb_bench/cli/cli.py b/vectordb_bench/cli/cli.py index 12bb4be9b..ba0c1b479 100644 --- a/vectordb_bench/cli/cli.py +++ b/vectordb_bench/cli/cli.py @@ -653,5 +653,14 @@ def run( if global_result_future: wait([global_result_future]) - while benchmark_runner.has_running(): - time.sleep(1) + try: + while benchmark_runner.has_running(): + time.sleep(1) + except Exception as e: + # Be resilient to IPC issues (e.g., BrokenPipe on Windows) so that + # batch runs can continue with subsequent subcommands. + log.warning(f"Encountered error while monitoring task progress: {e}") + try: + benchmark_runner.stop_running() + except Exception: # noqa: BLE001 + pass diff --git a/vectordb_bench/frontend/components/check_results/charts.py b/vectordb_bench/frontend/components/check_results/charts.py index 0e74d2752..d36cc5fc6 100644 --- a/vectordb_bench/frontend/components/check_results/charts.py +++ b/vectordb_bench/frontend/components/check_results/charts.py @@ -153,4 +153,4 @@ def drawMetricChart(data, metric, st, key: str): ), ) - chart.plotly_chart(fig, use_container_width=True, key=key) + chart.plotly_chart(fig, width="stretch", key=key) diff --git a/vectordb_bench/frontend/components/check_results/filters.py b/vectordb_bench/frontend/components/check_results/filters.py index 6016c0040..7fc38d585 100644 --- a/vectordb_bench/frontend/components/check_results/filters.py +++ b/vectordb_bench/frontend/components/check_results/filters.py @@ -44,11 +44,14 @@ def getshownResults( st.write("There are no results to display. Please wait for the task to complete or run a new task.") return [] + # Filter out default labels that are not in the options + filtered_defaults = [label for label in default_selected_task_labels if label in resultSelectOptions] + selectedResultSelectedOptions = st.multiselect( "Select the task results you need to analyze.", resultSelectOptions, # label_visibility="hidden", - default=default_selected_task_labels or resultSelectOptions, + default=filtered_defaults or resultSelectOptions, ) selectedResult: list[CaseResult] = [] for option in selectedResultSelectedOptions: diff --git a/vectordb_bench/frontend/components/check_results/priceTable.py b/vectordb_bench/frontend/components/check_results/priceTable.py index f2c0ae001..f34f70872 100644 --- a/vectordb_bench/frontend/components/check_results/priceTable.py +++ b/vectordb_bench/frontend/components/check_results/priceTable.py @@ -27,7 +27,7 @@ def priceTable(container, data): expander = container.expander("Price List (Editable).") editTable = expander.data_editor( table, - use_container_width=True, + width="stretch", hide_index=True, height=height, disabled=("DB", "Label"), diff --git a/vectordb_bench/frontend/components/concurrent/charts.py b/vectordb_bench/frontend/components/concurrent/charts.py index 004fcb261..5369d5912 100644 --- a/vectordb_bench/frontend/components/concurrent/charts.py +++ b/vectordb_bench/frontend/components/concurrent/charts.py @@ -94,4 +94,4 @@ def drawChart(data, st, key: str, x_metric: str = "latency_p99", y_metric: str = fig.update_yaxes(range=yrange, title_text=gen_title(y_metric)) fig.update_traces(textposition="bottom right", texttemplate="conc-%{text:,.4~r}") - st.plotly_chart(fig, use_container_width=True, key=key) + st.plotly_chart(fig, width="stretch", key=key) diff --git a/vectordb_bench/frontend/components/int_filter/charts.py b/vectordb_bench/frontend/components/int_filter/charts.py index 881681031..5c32a089e 100644 --- a/vectordb_bench/frontend/components/int_filter/charts.py +++ b/vectordb_bench/frontend/components/int_filter/charts.py @@ -57,4 +57,4 @@ def drawChart(st, data: list[object], metric): margin=dict(l=0, r=0, t=40, b=0, pad=8), legend=dict(orientation="h", yanchor="bottom", y=1, xanchor="right", x=1, title=""), ) - st.plotly_chart(fig, use_container_width=True) + st.plotly_chart(fig, width="stretch") diff --git a/vectordb_bench/frontend/components/label_filter/charts.py b/vectordb_bench/frontend/components/label_filter/charts.py index 881681031..5c32a089e 100644 --- a/vectordb_bench/frontend/components/label_filter/charts.py +++ b/vectordb_bench/frontend/components/label_filter/charts.py @@ -57,4 +57,4 @@ def drawChart(st, data: list[object], metric): margin=dict(l=0, r=0, t=40, b=0, pad=8), legend=dict(orientation="h", yanchor="bottom", y=1, xanchor="right", x=1, title=""), ) - st.plotly_chart(fig, use_container_width=True) + st.plotly_chart(fig, width="stretch") diff --git a/vectordb_bench/frontend/components/qps_recall/charts.py b/vectordb_bench/frontend/components/qps_recall/charts.py index ab57dd0ce..a44c51a36 100644 --- a/vectordb_bench/frontend/components/qps_recall/charts.py +++ b/vectordb_bench/frontend/components/qps_recall/charts.py @@ -115,4 +115,4 @@ def drawlinechart(st, data: list[object], metric, key: str): margin=dict(l=0, r=0, t=40, b=0, pad=8), legend=dict(orientation="h", yanchor="bottom", y=1, xanchor="right", x=1, title=""), ) - st.plotly_chart(fig, use_container_width=True, key=key) + st.plotly_chart(fig, width="stretch", key=key) diff --git a/vectordb_bench/frontend/components/streaming/charts.py b/vectordb_bench/frontend/components/streaming/charts.py index a05da9b25..09357cf44 100644 --- a/vectordb_bench/frontend/components/streaming/charts.py +++ b/vectordb_bench/frontend/components/streaming/charts.py @@ -119,7 +119,7 @@ def drawLineChart( if x_metric == DisplayedMetric.search_time: x_title = "Actual Time (s)" fig.update_layout(xaxis_title=x_title) - st.plotly_chart(fig, use_container_width=True, key=key) + st.plotly_chart(fig, width="stretch", key=key) def get_normal_scatter( @@ -234,7 +234,7 @@ def drawBarChart( fig.update_layout(xaxis_title="time (s)") fig.update_layout(barmode="stack") fig.update_traces(width=0.15) - st.plotly_chart(fig, use_container_width=True, key=key) + st.plotly_chart(fig, width="stretch", key=key) def get_bar( diff --git a/vectordb_bench/interface.py b/vectordb_bench/interface.py index 42dc876b0..f988ebb2b 100644 --- a/vectordb_bench/interface.py +++ b/vectordb_bench/interface.py @@ -106,21 +106,31 @@ def get_results(result_dir: pathlib.Path | None = None) -> list[TestResult]: return ResultCollector.collect(target_dir) def _try_get_signal(self): - while self.receive_conn and self.receive_conn.poll(): - sig, received = self.receive_conn.recv() - log.debug(f"Sigal received to process: {sig}, {received}") - if sig == SIGNAL.ERROR: - self.latest_error = received + try: + while self.receive_conn and self.receive_conn.poll(): + sig, received = self.receive_conn.recv() + log.debug(f"Sigal received to process: {sig}, {received}") + if sig == SIGNAL.ERROR: + self.latest_error = received + self._clear_running_task() + elif sig == SIGNAL.SUCCESS: + global global_result_future + global_result_future = None + self.running_task = None + self.receive_conn = None + elif sig == SIGNAL.WIP: + self.running_task.set_finished(received) + else: + self._clear_running_task() + except (BrokenPipeError, EOFError, OSError) as e: + # On Windows, the child process may terminate and close the pipe before + # the parent polls/receives. Treat this as end-of-run and clean up so + # the CLI can proceed instead of crashing with BrokenPipeError. + log.warning(f"Signal pipe broken while polling/receiving: {e}. Cleaning up current task.") + try: self._clear_running_task() - elif sig == SIGNAL.SUCCESS: - global global_result_future - global_result_future = None - self.running_task = None + finally: self.receive_conn = None - elif sig == SIGNAL.WIP: - self.running_task.set_finished(received) - else: - self._clear_running_task() def has_running(self) -> bool: """check if there're running benchmarks""" diff --git a/vectordb_bench/models.py b/vectordb_bench/models.py index 44aff6a79..ffade94cf 100644 --- a/vectordb_bench/models.py +++ b/vectordb_bench/models.py @@ -295,7 +295,8 @@ def write_db_file(self, result_dir: pathlib.Path, partial: Self, db: str): log.info(f"write results to disk {result_file}") with pathlib.Path(result_file).open("w") as f: - b = partial.json(exclude={"db_config": {"password", "api_key"}}) + # Pydantic v2: use model_dump_json instead of deprecated .json() + b = partial.model_dump_json(exclude={"db_config": {"password", "api_key"}}) f.write(b) def get_case_config(case_config: CaseConfig) -> dict[CaseConfig]: @@ -328,7 +329,21 @@ def read_file(cls, full_path: pathlib.Path, trans_unit: bool = False) -> Self: case_config = task_config.get("case_config") db = DB(task_config.get("db")) - task_config["db_config"] = db.config_cls(**task_config["db_config"]) + # Backward-compat patching for older result files missing required fields + raw_db_cfg = task_config.get("db_config", {}) + if db == DB.WeaviateCloud: + # Ensure minimal fields exist for Weaviate v4 config + raw_db_cfg = dict(raw_db_cfg or {}) + raw_db_cfg.setdefault("url", "http://localhost:8080") + raw_db_cfg.setdefault("api_key", "-") + raw_db_cfg.setdefault("no_auth", True) + elif db == DB.MariaDB: + raw_db_cfg = dict(raw_db_cfg or {}) + raw_db_cfg.setdefault("password", "-") + elif db == DB.PgVector: + raw_db_cfg = dict(raw_db_cfg or {}) + raw_db_cfg.setdefault("password", "-") + task_config["db_config"] = db.config_cls(**raw_db_cfg) # Safely instantiate DBCaseConfig (fallback to EmptyDBCaseConfig on None) raw_case_cfg = task_config.get("db_case_config") or {}