Skip to content

Commit a5095bf

Browse files
committed
batch cli
1 parent eefd398 commit a5095bf

File tree

2 files changed

+152
-23
lines changed

2 files changed

+152
-23
lines changed

vectordb_bench/cli/batch_cli.py

Lines changed: 127 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import logging
2-
import time
32
from collections.abc import MutableMapping
43
from concurrent.futures import wait
54
from pathlib import Path
@@ -14,6 +13,7 @@
1413
cli,
1514
click_parameter_decorators_from_typed_dict,
1615
)
16+
from ..models import TaskConfig
1717

1818
log = logging.getLogger(__name__)
1919

@@ -95,27 +95,136 @@ def format_bool_option(opt_name: str, value: Any, skip: bool = False, raw_key: s
9595
return args_arr
9696

9797

98-
@cli.command()
99-
@click_parameter_decorators_from_typed_dict(BatchCliTypedDict)
100-
def BatchCli():
101-
ctx = click.get_current_context()
102-
batch_config = ctx.default_map
98+
def build_task_from_config(cmd_name: str, config_dict: dict[str, Any]) -> TaskConfig | None:
10399

104-
runner = CliRunner()
100+
collected_tasks = []
101+
original_run = None
105102

106-
args_arr = build_sub_cmd_args(batch_config)
103+
try:
104+
from ..interface import benchmark_runner
105+
106+
original_run = benchmark_runner.run
107+
108+
def collect_task_wrapper(tasks: list[TaskConfig], task_label: str | None = None): # noqa: ARG001
109+
collected_tasks.extend(tasks)
110+
return True
111+
112+
benchmark_runner.run = collect_task_wrapper
113+
114+
# build CLI parameters
115+
args = [cmd_name]
116+
bool_options = {
117+
"drop_old": True,
118+
"load": True,
119+
"search_serial": True,
120+
"search_concurrent": True,
121+
"dry_run": False,
122+
"custom_dataset_use_shuffled": True,
123+
"custom_dataset_with_gt": True,
124+
}
125+
126+
def format_option(key: str, value: Any):
127+
opt_name = key.replace("_", "-")
128+
129+
if key in bool_options:
130+
return format_bool_option(opt_name, value, skip=False)
131+
132+
if key.startswith("skip_"):
133+
raw_key = key[5:]
134+
raw_opt = raw_key.replace("_", "-")
135+
return format_bool_option(raw_opt, value, skip=True, raw_key=raw_key)
136+
137+
return [f"--{opt_name}", str(value)]
138+
139+
def format_bool_option(opt_name: str, value: Any, skip: bool = False, raw_key: str | None = None):
140+
if isinstance(value, bool):
141+
if skip:
142+
if bool_options.get(raw_key, False):
143+
return [f"--skip-{opt_name}"] if value else [f"--{opt_name}"]
144+
return [f"--{opt_name}", str(value)]
145+
if value:
146+
return [f"--{opt_name}"]
147+
if bool_options.get(opt_name.replace("-", "_"), False):
148+
return [f"--skip-{opt_name}"]
149+
return []
150+
return [f"--{opt_name}", str(value)]
151+
152+
for k, v in config_dict.items():
153+
args.extend(format_option(k, v))
154+
155+
# call CLI command (this will trigger collect_task_wrapper)
156+
runner = CliRunner()
157+
result = runner.invoke(cli, args, catch_exceptions=False)
107158

108-
for args in args_arr:
109-
log.info(f"got batch config: {' '.join(args)}")
159+
if result.exception:
160+
log.error(f"Failed to build task for {cmd_name}: {result.exception}")
161+
return None
110162

111-
for args in args_arr:
112-
result = runner.invoke(cli, args)
113-
time.sleep(5)
163+
if collected_tasks:
164+
return collected_tasks[0]
165+
return None # noqa: TRY300
114166

115-
from ..interface import global_result_future
167+
except Exception:
168+
log.exception("Error building task from config")
169+
return None
170+
finally:
171+
if original_run is not None:
172+
from ..interface import benchmark_runner
116173

117-
if global_result_future:
118-
wait([global_result_future])
174+
benchmark_runner.run = original_run
119175

120-
if result.exception:
121-
log.exception(f"failed to run sub command: {args[0]}", exc_info=result.exception)
176+
177+
@cli.command()
178+
@click_parameter_decorators_from_typed_dict(BatchCliTypedDict)
179+
def BatchCli():
180+
ctx = click.get_current_context()
181+
batch_config = ctx.default_map
182+
183+
from ..interface import benchmark_runner, global_result_future
184+
185+
# collect all tasks
186+
all_tasks: list[TaskConfig] = []
187+
task_labels: set[str] = set()
188+
189+
for cmd_name, cmd_config_list in batch_config.items():
190+
for config_dict in cmd_config_list:
191+
log.info(f"Building task for {cmd_name} with config: {config_dict.get('task_label', 'N/A')}")
192+
193+
# collect task_label from config
194+
if "task_label" in config_dict:
195+
task_labels.add(config_dict["task_label"])
196+
197+
# TaskConfig
198+
task = build_task_from_config(cmd_name, config_dict)
199+
if task:
200+
all_tasks.append(task)
201+
log.info(f"Successfully built task: {task.db.value} - {task.case_config.case_id.name}")
202+
else:
203+
log.warning(f"Failed to build task for {cmd_name}")
204+
205+
if not all_tasks:
206+
log.error("No tasks were built from the batch config file")
207+
return
208+
209+
if len(task_labels) == 1:
210+
task_label = task_labels.pop()
211+
log.info(f"Using shared task_label from config: {task_label}")
212+
elif len(task_labels) > 1:
213+
task_label = next(iter(task_labels))
214+
log.warning(f"Multiple task_labels found in config, using the first one: {task_label}")
215+
else:
216+
from datetime import datetime
217+
218+
task_label = f"batch_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
219+
log.info(f"No task_label found in config, using generated one: {task_label}")
220+
221+
log.info(f"Running {len(all_tasks)} tasks with shared task_label: {task_label}")
222+
223+
benchmark_runner.run(all_tasks, task_label)
224+
225+
if global_result_future:
226+
log.info("Waiting for all tasks to complete...")
227+
wait([global_result_future])
228+
log.info("All tasks completed successfully")
229+
else:
230+
log.warning("No global_result_future found, tasks may be running in background")

vectordb_bench/cli/cli.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -608,22 +608,24 @@ class OceanBaseIVFTypedDict(TypedDict):
608608
def cli(): ...
609609

610610

611-
def run(
611+
def build_task(
612612
db: DB,
613613
db_config: DBConfig,
614614
db_case_config: DBCaseConfig,
615615
**parameters: Unpack[CommonTypedDict],
616-
):
617-
"""Builds a single VectorDBBench Task and runs it, awaiting the task until finished.
616+
) -> TaskConfig:
617+
"""Builds a single VectorDBBench Task without running it.
618618
619619
Args:
620620
db (DB)
621621
db_config (DBConfig)
622622
db_case_config (DBCaseConfig)
623623
**parameters: expects keys from CommonTypedDict
624-
"""
625624
626-
task = TaskConfig(
625+
Returns:
626+
TaskConfig: The created task configuration
627+
"""
628+
return TaskConfig(
627629
db=db,
628630
db_config=db_config,
629631
db_case_config=db_case_config,
@@ -644,6 +646,24 @@ def run(
644646
parameters["search_concurrent"],
645647
),
646648
)
649+
650+
651+
def run(
652+
db: DB,
653+
db_config: DBConfig,
654+
db_case_config: DBCaseConfig,
655+
**parameters: Unpack[CommonTypedDict],
656+
):
657+
"""Builds a single VectorDBBench Task and runs it, awaiting the task until finished.
658+
659+
Args:
660+
db (DB)
661+
db_config (DBConfig)
662+
db_case_config (DBCaseConfig)
663+
**parameters: expects keys from CommonTypedDict
664+
"""
665+
666+
task = build_task(db, db_config, db_case_config, **parameters)
647667
task_label = parameters["task_label"]
648668

649669
log.info(f"Task:\n{pformat(task)}\n")

0 commit comments

Comments
 (0)