Skip to content

Commit ae71bdb

Browse files
add qps&recall line (#603)
1 parent 245410f commit ae71bdb

File tree

4 files changed

+235
-1
lines changed

4 files changed

+235
-1
lines changed

vectordb_bench/frontend/components/check_results/stPageConfig.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ def initResultsPageConfig(st):
55
st.set_page_config(
66
page_title=PAGE_TITLE,
77
page_icon=FAVICON,
8-
# layout="wide",
8+
layout="wide",
99
# initial_sidebar_state="collapsed",
1010
)
1111

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
from vectordb_bench.frontend.components.check_results.expanderStyle import (
2+
initMainExpanderStyle,
3+
)
4+
from vectordb_bench.metric import metric_order, isLowerIsBetterMetric, metric_unit_map
5+
from vectordb_bench.frontend.config.styles import *
6+
import plotly.express as px
7+
import pandas as pd
8+
import plotly.graph_objects as go
9+
import matplotlib.pyplot as plt
10+
11+
12+
def drawCharts(st, allData, caseNames: list[str]):
13+
initMainExpanderStyle(st)
14+
for caseName in caseNames:
15+
chartContainer = st.expander(caseName, True)
16+
data = [data for data in allData if data["case_name"] == caseName]
17+
drawChart(data, chartContainer, key_prefix=caseName)
18+
19+
20+
def drawChart(data, st, key_prefix: str):
21+
metricsSet = set()
22+
for d in data:
23+
metricsSet = metricsSet.union(d["metricsSet"])
24+
showlineMetrics = [metric for metric in metric_order[:2] if metric in metricsSet]
25+
26+
if showlineMetrics:
27+
metric = showlineMetrics[0]
28+
key = f"{key_prefix}-{metric}"
29+
drawlinechart(st, data, metric, key=key)
30+
31+
32+
def drawBestperformance(data, y, group):
33+
all_filter_points = []
34+
data = pd.DataFrame(data)
35+
grouped = data.groupby(group)
36+
for name, group_df in grouped:
37+
filter_points = []
38+
current_start = 0
39+
for _ in range(len(group_df)):
40+
if current_start >= len(group_df):
41+
break
42+
max_index = group_df[y].iloc[current_start:].idxmax()
43+
filter_points.append(group_df.loc[max_index])
44+
45+
current_start = group_df.index.get_loc(max_index) + 1
46+
all_filter_points.extend(filter_points)
47+
48+
all_filter_df = pd.DataFrame(all_filter_points)
49+
remaining_df = data[~data.isin(all_filter_df).any(axis=1)]
50+
new_data = all_filter_df.to_dict(orient="records")
51+
remain_data = remaining_df.to_dict(orient="records")
52+
return new_data, remain_data
53+
54+
55+
def drawlinechart(st, data: list[object], metric, key: str):
56+
unit = metric_unit_map.get(metric, "")
57+
minV = min([d.get(metric, 0) for d in data])
58+
maxV = max([d.get(metric, 0) for d in data])
59+
padding = maxV - minV
60+
rangeV = [
61+
minV - padding * 0.1,
62+
maxV + padding * 0.1,
63+
]
64+
x = "recall"
65+
xrange = [0.8, 1.01]
66+
y = "qps"
67+
yrange = rangeV
68+
data.sort(key=lambda a: a[x])
69+
group = "db_name"
70+
new_data, new_remain_data = drawBestperformance(data, y, group)
71+
unique_db_names = list(set(item["db_name"] for item in new_data + new_remain_data))
72+
73+
colors = plt.cm.get_cmap("tab10", len(unique_db_names))
74+
75+
color_map = {
76+
db: f"rgb({int(colors(i)[0] * 255)}, {int(colors(i)[1] * 255)}, {int(colors(i)[2] * 255)})"
77+
for i, db in enumerate(unique_db_names)
78+
}
79+
80+
fig = go.Figure()
81+
82+
new_data_df = pd.DataFrame(new_data)
83+
84+
for db in unique_db_names:
85+
db_data = new_data_df[new_data_df["db_name"] == db]
86+
fig.add_trace(
87+
go.Scatter(
88+
x=db_data["recall"],
89+
y=db_data["qps"],
90+
mode="lines+markers",
91+
name=db,
92+
line=dict(color=color_map[db]),
93+
marker=dict(color=color_map[db]),
94+
showlegend=True,
95+
)
96+
)
97+
98+
for item in new_remain_data:
99+
fig.add_trace(
100+
go.Scatter(
101+
x=[item["recall"]],
102+
y=[item["qps"]],
103+
mode="markers",
104+
name=item["db_name"],
105+
marker=dict(color=color_map[item["db_name"]]),
106+
showlegend=False,
107+
)
108+
)
109+
110+
fig.update_xaxes(range=xrange)
111+
fig.update_yaxes(range=yrange)
112+
fig.update_traces(textposition="bottom right", texttemplate="%{y:,.4~r}" + unit)
113+
fig.update_layout(
114+
margin=dict(l=0, r=0, t=40, b=0, pad=8),
115+
legend=dict(orientation="h", yanchor="bottom", y=1, xanchor="right", x=1, title=""),
116+
)
117+
st.plotly_chart(fig, use_container_width=True, key=key)
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from collections import defaultdict
2+
from dataclasses import asdict
3+
from vectordb_bench.backend.filter import FilterOp
4+
from vectordb_bench.frontend.components.check_results.data import getFilterTasks
5+
from vectordb_bench.frontend.components.check_results.filters import getShowDbsAndCases, getshownResults
6+
from vectordb_bench.models import CaseResult, ResultLabel, TestResult
7+
8+
9+
def getshownData(st, results: list[TestResult], filter_type: FilterOp = FilterOp.NonFilter, **kwargs):
10+
# hide the nav
11+
st.markdown(
12+
"<style> div[data-testid='stSidebarNav'] {display: none;} </style>",
13+
unsafe_allow_html=True,
14+
)
15+
st.header("Filters")
16+
shownResults = getshownResults(st, results, **kwargs)
17+
showDBNames, showCaseNames = getShowDbsAndCases(st, shownResults, filter_type)
18+
shownData, failedTasks = getChartData(shownResults, showDBNames, showCaseNames)
19+
return shownData, failedTasks, showCaseNames
20+
21+
22+
def getChartData(
23+
tasks: list[CaseResult],
24+
dbNames: list[str],
25+
caseNames: list[str],
26+
):
27+
filterTasks = getFilterTasks(tasks, dbNames, caseNames)
28+
failedTasks = defaultdict(lambda: defaultdict(str))
29+
nonemergedTasks = []
30+
for task in filterTasks:
31+
db_name = task.task_config.db_name
32+
db = task.task_config.db.value
33+
db_label = task.task_config.db_config.db_label or ""
34+
version = task.task_config.db_config.version or ""
35+
case = task.task_config.case_config.case
36+
case_name = case.name
37+
dataset_name = case.dataset.data.full_name
38+
filter_rate = case.filter_rate
39+
metrics = asdict(task.metrics)
40+
label = task.label
41+
if label == ResultLabel.NORMAL:
42+
nonemergedTasks.append(
43+
{
44+
"db_name": db_name,
45+
"db": db,
46+
"db_label": db_label,
47+
"dataset_name": dataset_name,
48+
"filter_rate": filter_rate,
49+
"version": version,
50+
"case_name": case_name,
51+
"metricsSet": set(metrics.keys()),
52+
**metrics,
53+
}
54+
)
55+
else:
56+
failedTasks[case_name][db_name] = label
57+
58+
return nonemergedTasks, failedTasks
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import streamlit as st
2+
from vectordb_bench.frontend.components.check_results.footer import footer
3+
from vectordb_bench.frontend.components.check_results.headerIcon import drawHeaderIcon
4+
from vectordb_bench.frontend.components.check_results.nav import (
5+
NavToQuriesPerDollar,
6+
NavToRunTest,
7+
NavToPages,
8+
)
9+
from vectordb_bench.frontend.components.qps_recall.charts import drawCharts
10+
from vectordb_bench.frontend.components.qps_recall.data import getshownData
11+
from vectordb_bench.frontend.components.get_results.saveAsImage import getResults
12+
13+
from vectordb_bench.frontend.config.styles import FAVICON
14+
from vectordb_bench.interface import benchmark_runner
15+
16+
17+
def main():
18+
# set page config
19+
st.set_page_config(
20+
page_title="Label Filter",
21+
page_icon=FAVICON,
22+
layout="wide",
23+
# initial_sidebar_state="collapsed",
24+
)
25+
26+
# header
27+
drawHeaderIcon(st)
28+
29+
# navigate
30+
NavToPages(st)
31+
32+
allResults = benchmark_runner.get_results()
33+
34+
st.title("Vector Database Benchmark (Qps & Recall)")
35+
36+
# results selector and filter
37+
resultSelectorContainer = st.sidebar.container()
38+
shownData, failedTasks, showCaseNames = getshownData(resultSelectorContainer, allResults)
39+
40+
resultSelectorContainer.divider()
41+
42+
# nav
43+
navContainer = st.sidebar.container()
44+
NavToRunTest(navContainer)
45+
NavToQuriesPerDollar(navContainer)
46+
47+
# save or share
48+
resultesContainer = st.sidebar.container()
49+
getResults(resultesContainer, "vectordb_bench")
50+
51+
# charts
52+
drawCharts(st, shownData, showCaseNames)
53+
54+
# footer
55+
footer(st.container())
56+
57+
58+
if __name__ == "__main__":
59+
main()

0 commit comments

Comments
 (0)