Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 127 additions & 18 deletions vectordb_bench/cli/batch_cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import time
from collections.abc import MutableMapping
from concurrent.futures import wait
from pathlib import Path
Expand All @@ -14,6 +13,7 @@
cli,
click_parameter_decorators_from_typed_dict,
)
from ..models import TaskConfig

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -95,27 +95,136 @@ def format_bool_option(opt_name: str, value: Any, skip: bool = False, raw_key: s
return args_arr


# Default value of each boolean flag handled specially when a config dict is
# turned into CLI arguments. A flag whose default is True is switched off
# with ``--skip-<name>``; a flag whose default is False is a plain
# ``--<name>`` switch that is simply omitted when false.
_BOOL_OPTION_DEFAULTS: dict[str, bool] = {
    "drop_old": True,
    "load": True,
    "search_serial": True,
    "search_concurrent": True,
    "dry_run": False,
    "custom_dataset_use_shuffled": True,
    "custom_dataset_with_gt": True,
}


def _format_cli_bool(opt_name: str, value: Any, skip: bool = False, raw_key: str | None = None) -> list[str]:
    """Render one boolean-style config entry as CLI tokens.

    Args:
        opt_name: Dash-separated option name, without the leading ``--``.
        value: Configured value; non-bool values are passed as ``--opt value``.
        skip: True when the config key was written as ``skip_<name>``.
        raw_key: Underscore form of the option name, used to look up the
            default when ``skip`` is set.

    Returns:
        The tokens to append to the argument list (possibly empty).
    """
    if not isinstance(value, bool):
        return [f"--{opt_name}", str(value)]

    if skip:
        if _BOOL_OPTION_DEFAULTS.get(raw_key, False):
            # ``skip_x: true`` disables x, ``skip_x: false`` enables it.
            return [f"--skip-{opt_name}"] if value else [f"--{opt_name}"]
        return [f"--{opt_name}", str(value)]

    if value:
        return [f"--{opt_name}"]
    if _BOOL_OPTION_DEFAULTS.get(opt_name.replace("-", "_"), False):
        # Default-on flag explicitly set to False.
        return [f"--skip-{opt_name}"]
    return []


def _config_to_cli_args(cmd_name: str, config_dict: dict[str, Any]) -> list[str]:
    """Translate one batch-config entry into the argv list for ``cmd_name``."""
    args = [cmd_name]
    for key, value in config_dict.items():
        if value is None:
            # An empty config value would otherwise reach the sub-command as
            # the literal string "None"; omit it so the option default applies.
            continue

        opt_name = key.replace("_", "-")
        if key in _BOOL_OPTION_DEFAULTS:
            args.extend(_format_cli_bool(opt_name, value))
        elif key.startswith("skip_"):
            raw_key = key[len("skip_") :]
            args.extend(_format_cli_bool(raw_key.replace("_", "-"), value, skip=True, raw_key=raw_key))
        else:
            args.extend([f"--{opt_name}", str(value)])
    return args


def build_task_from_config(cmd_name: str, config_dict: dict[str, Any]) -> TaskConfig | None:
    """Build the ``TaskConfig`` a sub-command would run, without running it.

    ``benchmark_runner.run`` is temporarily replaced by a collector, the
    sub-command is invoked through click's ``CliRunner`` with arguments derived
    from ``config_dict``, and the first task handed to the collector is
    returned. The original ``run`` is always restored afterwards.

    NOTE: the patch is applied to the shared ``benchmark_runner`` object, so
    nothing else should call ``benchmark_runner.run`` while this executes.

    Args:
        cmd_name: Name of the CLI sub-command to invoke.
        config_dict: Option name -> value mapping for that sub-command.

    Returns:
        The collected task, or ``None`` when the invocation failed or the
        sub-command did not submit any task.
    """
    collected_tasks: list[TaskConfig] = []
    benchmark_runner = None
    original_run = None

    try:
        from ..interface import benchmark_runner

        original_run = benchmark_runner.run

        def collect_task_wrapper(tasks: list[TaskConfig], task_label: str | None = None):  # noqa: ARG001
            collected_tasks.extend(tasks)
            return True

        benchmark_runner.run = collect_task_wrapper

        # Invoking the sub-command triggers collect_task_wrapper instead of a real run.
        args = _config_to_cli_args(cmd_name, config_dict)
        result = CliRunner().invoke(cli, args, catch_exceptions=False)

        if result.exception:
            log.error(f"Failed to build task for {cmd_name}: {result.exception}")
            if result.output:
                # For usage errors the exception is only SystemExit; the
                # readable message is in the captured output.
                log.error(f"Output of {cmd_name}:\n{result.output}")
            return None

        if len(collected_tasks) > 1:
            log.warning(f"{cmd_name} produced {len(collected_tasks)} tasks, only the first one is used")

        return collected_tasks[0] if collected_tasks else None  # noqa: TRY300

    except Exception:
        log.exception("Error building task from config")
        return None
    finally:
        # original_run is only set once benchmark_runner has been imported.
        if original_run is not None:
            benchmark_runner.run = original_run

@cli.command()
@click_parameter_decorators_from_typed_dict(BatchCliTypedDict)
def BatchCli():
    """Build every task listed in the batch config and run them as one batch.

    The batch config is read from the click context's ``default_map``; it maps
    a sub-command name to a list of option dicts. Each entry is converted to a
    ``TaskConfig`` via ``build_task_from_config``; entries that fail to build
    are logged and skipped. All built tasks are submitted in a single
    ``benchmark_runner.run`` call under one task label:

    * the label from the config when exactly one distinct label is present,
    * the first label in config order when several are present,
    * a generated ``batch_<timestamp>`` label otherwise.
    """
    ctx = click.get_current_context()
    # default_map is None when no config was supplied; treat it as empty.
    batch_config = ctx.default_map or {}

    from ..interface import benchmark_runner

    all_tasks: list[TaskConfig] = []
    # A dict is used as an ordered set so that "the first one" below really is
    # the first label encountered in the config.
    task_labels: dict[str, None] = {}

    for cmd_name, cmd_config_list in batch_config.items():
        for config_dict in cmd_config_list:
            log.info(f"Building task for {cmd_name} with config: {config_dict.get('task_label', 'N/A')}")

            if "task_label" in config_dict:
                task_labels.setdefault(config_dict["task_label"], None)

            task = build_task_from_config(cmd_name, config_dict)
            if task:
                all_tasks.append(task)
                log.info(f"Successfully built task: {task.db.value} - {task.case_config.case_id.name}")
            else:
                log.warning(f"Failed to build task for {cmd_name}")

    if not all_tasks:
        log.error("No tasks were built from the batch config file")
        return

    if len(task_labels) == 1:
        task_label = next(iter(task_labels))
        log.info(f"Using shared task_label from config: {task_label}")
    elif len(task_labels) > 1:
        task_label = next(iter(task_labels))
        log.warning(f"Multiple task_labels found in config, using the first one: {task_label}")
    else:
        from datetime import datetime

        task_label = f"batch_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        log.info(f"No task_label found in config, using generated one: {task_label}")

    log.info(f"Running {len(all_tasks)} tasks with shared task_label: {task_label}")

    benchmark_runner.run(all_tasks, task_label)

    # Import only now: a from-import copies the module attribute's current
    # value, so importing before run() would keep the pre-run value.
    # NOTE(review): presumably run() rebinds interface.global_result_future - confirm.
    from ..interface import global_result_future

    if global_result_future:
        log.info("Waiting for all tasks to complete...")
        wait([global_result_future])
        if global_result_future.cancelled():
            log.warning("Batch run was cancelled before completion")
        elif (error := global_result_future.exception()) is not None:
            log.error(f"Batch run failed: {error}", exc_info=error)
        else:
            log.info("All tasks completed successfully")
    else:
        log.warning("No global_result_future found, tasks may be running in background")
30 changes: 25 additions & 5 deletions vectordb_bench/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,22 +608,24 @@ class OceanBaseIVFTypedDict(TypedDict):
def cli(): ...


def run(
def build_task(
db: DB,
db_config: DBConfig,
db_case_config: DBCaseConfig,
**parameters: Unpack[CommonTypedDict],
):
"""Builds a single VectorDBBench Task and runs it, awaiting the task until finished.
) -> TaskConfig:
"""Builds a single VectorDBBench Task without running it.

Args:
db (DB)
db_config (DBConfig)
db_case_config (DBCaseConfig)
**parameters: expects keys from CommonTypedDict
"""

task = TaskConfig(
Returns:
TaskConfig: The created task configuration
"""
return TaskConfig(
db=db,
db_config=db_config,
db_case_config=db_case_config,
Expand All @@ -644,6 +646,24 @@ def run(
parameters["search_concurrent"],
),
)


def run(
db: DB,
db_config: DBConfig,
db_case_config: DBCaseConfig,
**parameters: Unpack[CommonTypedDict],
):
"""Builds a single VectorDBBench Task and runs it, awaiting the task until finished.

Args:
db (DB)
db_config (DBConfig)
db_case_config (DBCaseConfig)
**parameters: expects keys from CommonTypedDict
"""

task = build_task(db, db_config, db_case_config, **parameters)
task_label = parameters["task_label"]

log.info(f"Task:\n{pformat(task)}\n")
Expand Down