Skip to content

Commit 6de90c5

Browse files
support restful (#690)
Signed-off-by: min.tian <min.tian.cn@gmail.com>
1 parent 2c5b26b commit 6de90c5

File tree

4 files changed

+179
-0
lines changed

4 files changed

+179
-0
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ test = [
5151
"ruff",
5252
"pytest",
5353
]
54+
restful = [ "flask" ]
5455

5556
all = [
5657
"grpcio==1.53.0", # for qdrant-client and pymilvus
@@ -111,6 +112,7 @@ turbopuffer = [ "turbopuffer" ]
111112

112113
[project.scripts]
113114
init_bench = "vectordb_bench.__main__:main"
115+
init_bench_rest = "vectordb_bench.restful.app:main"
114116
vectordbbench = "vectordb_bench.cli.vectordbbench:cli"
115117

116118
[tool.setuptools_scm]

vectordb_bench/restful/__init__.py

Whitespace-only changes.

vectordb_bench/restful/app.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
from flask import Flask, jsonify, request
2+
3+
from vectordb_bench.backend.clients import DB
4+
from vectordb_bench.interface import benchmark_runner
5+
from vectordb_bench.models import ALL_TASK_STAGES, CaseConfig, TaskConfig, TaskStage
6+
from vectordb_bench.restful.format_res import format_results
7+
8+
app = Flask(__name__)
9+
10+
11+
def res_wrapper(code: int = 0, message: str = "", data: any = None): # noqa: RUF013
12+
return jsonify({"code": code, "message": message, "data": data}), 200
13+
14+
15+
def success_res(data: any = None, message: str = "success"): # noqa: RUF013
16+
return res_wrapper(code=0, message=message, data=data)
17+
18+
19+
def failed_res(data: any = None, message: str = "failed"): # noqa: RUF013
20+
return res_wrapper(code=1, message=message, data=data)
21+
22+
23+
@app.route("/get_res", methods=["GET"])
24+
def get_res():
25+
"""task label -> res"""
26+
task_label = request.args.get("task_label", "standard")
27+
all_results = benchmark_runner.get_results()
28+
res = format_results(all_results, task_label=task_label)
29+
30+
return success_res(res)
31+
32+
33+
@app.route("/get_status", methods=["GET"])
34+
def get_status():
35+
"running 5/18, not running"
36+
is_running = benchmark_runner.has_running()
37+
tasks_count = benchmark_runner.get_tasks_count()
38+
if is_running:
39+
tasks_count = benchmark_runner.get_tasks_count()
40+
cur_task_idx = benchmark_runner.get_current_task_id()
41+
return success_res(
42+
data={
43+
"is_running": is_running,
44+
"tasks_count": tasks_count,
45+
"cur_task_idx": cur_task_idx,
46+
}
47+
)
48+
return success_res(data={"is_running": is_running})
49+
50+
51+
@app.route("/stop", methods=["GET"])
52+
def stop():
53+
benchmark_runner.stop_running()
54+
return success_res(message="stopped")
55+
56+
57+
@app.route("/run", methods=["post"])
58+
def run():
59+
if benchmark_runner.has_running():
60+
return failed_res(message="There are already running tasks.")
61+
data = request.get_json()
62+
task_label = data.get("task_label", "test")
63+
use_aliyun = data.get("use_aliyun", False)
64+
task_configs: list[TaskConfig] = []
65+
try:
66+
tasks = data.get("tasks", [])
67+
if len(tasks) == 0:
68+
return failed_res(message="empty tasks")
69+
for task in tasks:
70+
db = DB(task["db"])
71+
db_config = db.config_cls(**task["db_config"])
72+
case_config = CaseConfig(**task["case_config"])
73+
print(case_config) # noqa: T201
74+
db_case_config = db.case_config_cls(index_type=task["db_case_config"].get("index", None))(
75+
**task["db_case_config"]
76+
)
77+
stages = [TaskStage(stage) for stage in task.get("stages", ALL_TASK_STAGES)]
78+
print(stages) # noqa: T201
79+
task_config = TaskConfig(
80+
db=db,
81+
db_config=db_config,
82+
case_config=case_config,
83+
db_case_config=db_case_config,
84+
stages=stages,
85+
)
86+
task_configs.append(task_config)
87+
except Exception as e:
88+
return failed_res(message=f"invalid tasks: {e}")
89+
90+
benchmark_runner.set_download_address(use_aliyun)
91+
benchmark_runner.run(task_configs, task_label)
92+
93+
return success_res(message="start")
94+
95+
96+
def main():
97+
app.run(host="0.0.0.0", port=5000, debug=False) # noqa: S104
98+
99+
100+
if __name__ == "__main__":
101+
main()
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from dataclasses import asdict
2+
3+
from pydantic import BaseModel
4+
5+
from vectordb_bench.backend.cases import CaseLabel
6+
from vectordb_bench.models import TestResult
7+
8+
9+
class FormatResult(BaseModel):
10+
# db_config
11+
task_label: str = ""
12+
timestamp: int = 0
13+
db: str = ""
14+
db_label: str = "" # perf-x86
15+
version: str = ""
16+
note: str = ""
17+
18+
# params
19+
params: dict = {}
20+
21+
# case_config
22+
case_name: str = ""
23+
dataset: str = ""
24+
dim: int = 0
25+
filter_type: str = "" # FilterType(Enum).value
26+
filter_rate: float = 0
27+
k: int = 100
28+
29+
# metrics
30+
max_load_count: int = 0
31+
load_duration: int = 0
32+
qps: float = 0
33+
serial_latency_p99: float = 0
34+
recall: float = 0
35+
ndcg: float = 0
36+
conc_num_list: list[int] = []
37+
conc_qps_list: list[float] = []
38+
conc_latency_p99_list: list[float] = []
39+
conc_latency_avg_list: list[float] = []
40+
41+
42+
def format_results(test_results: list[TestResult], task_label: str) -> list[dict]:
43+
results = []
44+
for test_result in test_results:
45+
if test_result.task_label == task_label:
46+
for case_result in test_result.results:
47+
task_config = case_result.task_config
48+
case_config = task_config.case_config
49+
case = case_config.case
50+
if case.label == CaseLabel.Load:
51+
continue
52+
dataset = case.dataset.data
53+
filter_ = case.filters
54+
metrics = asdict(case_result.metrics)
55+
for k, v in metrics.items():
56+
if isinstance(v, list) and len(v) > 0:
57+
metrics[k] = [round(d, 6) if isinstance(d, float) else d for d in v]
58+
results.append(
59+
FormatResult(
60+
task_label=test_result.task_label,
61+
timestamp=int(test_result.timestamp),
62+
db=task_config.db.value,
63+
db_label=task_config.db_config.db_label,
64+
version=task_config.db_config.version,
65+
note=task_config.db_config.note,
66+
params=task_config.db_case_config.dict(),
67+
case_name=case.name,
68+
dataset=dataset.full_name,
69+
dim=dataset.dim,
70+
filter_type=filter_.type.name,
71+
filter_rate=filter_.filter_rate,
72+
k=task_config.case_config.k,
73+
**metrics,
74+
).dict()
75+
)
76+
return results

0 commit comments

Comments
 (0)