Skip to content

Commit 45d1ff5

Browse files
feat: Add Custom Dataset support for Streaming Tests (#660)
1 parent f56df66 commit 45d1ff5

File tree

7 files changed

+286
-3
lines changed

7 files changed

+286
-3
lines changed

vectordb_bench/backend/cases.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class CaseType(Enum):
5151
PerformanceCustomDataset = 101
5252

5353
StreamingPerformanceCase = 200
54+
StreamingCustomDataset = 201
5455

5556
LabelFilterPerformanceCase = 300
5657

@@ -474,6 +475,85 @@ def __init__(
474475
)
475476

476477

478+
class StreamingCustomDataset(Case):
479+
case_id: CaseType = CaseType.StreamingCustomDataset
480+
label: CaseLabel = CaseLabel.Streaming
481+
name: str = "Streaming Performance With Custom Dataset"
482+
description: str = ""
483+
dataset: DatasetManager
484+
insert_rate: int
485+
search_stages: list[float]
486+
concurrencies: list[int]
487+
optimize_after_write: bool = True
488+
read_dur_after_write: int = 30
489+
490+
def __init__(
491+
self,
492+
description: str,
493+
dataset_config: dict,
494+
insert_rate: int = 500,
495+
search_stages: list[float] | str = (0.5, 0.8),
496+
concurrencies: list[int] | str = (5, 10),
497+
optimize_after_write: bool = True,
498+
read_dur_after_write: int = 30,
499+
**kwargs,
500+
):
501+
num_per_batch = config.NUM_PER_BATCH
502+
if insert_rate % config.NUM_PER_BATCH != 0:
503+
_insert_rate = max(
504+
num_per_batch,
505+
insert_rate // num_per_batch * num_per_batch,
506+
)
507+
log.warning(
508+
f"[streaming_case init] insert_rate(={insert_rate}) should be "
509+
f"divisible by NUM_PER_BATCH={num_per_batch}), reset to {_insert_rate}",
510+
)
511+
insert_rate = _insert_rate
512+
513+
dataset_config = CustomDatasetConfig(**dataset_config)
514+
dataset = CustomDataset(
515+
name=dataset_config.name,
516+
size=dataset_config.size,
517+
dim=dataset_config.dim,
518+
metric_type=metric_type_map(dataset_config.metric_type),
519+
use_shuffled=dataset_config.use_shuffled,
520+
with_gt=dataset_config.with_gt,
521+
dir=dataset_config.dir,
522+
file_num=dataset_config.file_count,
523+
train_file=dataset_config.train_name,
524+
test_file=f"{dataset_config.test_name}.parquet",
525+
train_id_field=dataset_config.train_id_name,
526+
train_vector_field=dataset_config.train_col_name,
527+
test_vector_field=dataset_config.test_col_name,
528+
gt_neighbors_field=dataset_config.gt_col_name,
529+
scalar_labels_file=f"{dataset_config.scalar_labels_name}.parquet",
530+
)
531+
name = f"Streaming-Perf - Custom - {dataset_config.name}, {insert_rate} rows/s"
532+
description = (
533+
description
534+
if description
535+
else f"This case tests the search performance of vector database while maintaining "
536+
f"a fixed insertion speed. (dataset: Custom - {dataset_config.name})"
537+
)
538+
539+
if isinstance(search_stages, str):
540+
search_stages = json.loads(search_stages)
541+
if isinstance(concurrencies, str):
542+
concurrencies = json.loads(concurrencies)
543+
544+
super().__init__(
545+
name=name,
546+
description=description,
547+
dataset=DatasetManager(data=dataset),
548+
insert_rate=insert_rate,
549+
search_stages=search_stages,
550+
concurrencies=concurrencies,
551+
optimize_after_write=optimize_after_write,
552+
read_dur_after_write=read_dur_after_write,
553+
**kwargs,
554+
)
555+
556+
477557
class NewIntFilterPerformanceCase(PerformanceCase):
478558
case_id: CaseType = CaseType.NewIntFilterPerformanceCase
479559
dataset_with_size_type: DatasetWithSizeType
@@ -572,6 +652,7 @@ def filters(self) -> Filter:
572652
CaseType.Performance1536D50K: Performance1536D50K,
573653
CaseType.PerformanceCustomDataset: PerformanceCustomDataset,
574654
CaseType.StreamingPerformanceCase: StreamingPerformanceCase,
655+
CaseType.StreamingCustomDataset: StreamingCustomDataset,
575656
CaseType.NewIntFilterPerformanceCase: NewIntFilterPerformanceCase,
576657
CaseType.LabelFilterPerformanceCase: LabelFilterPerformanceCase,
577658
}

vectordb_bench/custom/custom_case.json

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,18 @@
1414
"use_shuffled": false,
1515
"with_gt": true
1616
}
17+
},
18+
{
19+
"case_type": "streaming",
20+
"description": "This is a custom streaming dataset.",
21+
"dataset_config": {
22+
"name": "My Streaming Dataset",
23+
"dir": "/my_dataset_path",
24+
"size": 1000000,
25+
"dim": 1024,
26+
"file_count": 1,
27+
"train_name": "shuffle_train",
28+
"with_gt": true
29+
}
1730
}
1831
]
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from vectordb_bench.frontend.components.custom.getCustomConfig import CustomStreamingCaseConfig
2+
3+
4+
def displayCustomStreamingCase(streamingCase: CustomStreamingCaseConfig, st, key):
5+
6+
columns = st.columns([1, 2])
7+
streamingCase.dataset_config.name = columns[0].text_input(
8+
"Name", key=f"{key}_name", value=streamingCase.dataset_config.name
9+
)
10+
streamingCase.dataset_config.dir = columns[1].text_input(
11+
"Folder Path", key=f"{key}_dir", value=streamingCase.dataset_config.dir
12+
)
13+
14+
columns = st.columns(2)
15+
streamingCase.dataset_config.dim = columns[0].number_input(
16+
"dim", key=f"{key}_dim", value=streamingCase.dataset_config.dim
17+
)
18+
streamingCase.dataset_config.size = columns[1].number_input(
19+
"size", key=f"{key}_size", value=streamingCase.dataset_config.size
20+
)
21+
22+
columns = st.columns(3)
23+
streamingCase.dataset_config.train_name = columns[0].text_input(
24+
"train file name",
25+
key=f"{key}_train_name",
26+
value=streamingCase.dataset_config.train_name,
27+
)
28+
streamingCase.dataset_config.test_name = columns[1].text_input(
29+
"test file name", key=f"{key}_test_name", value=streamingCase.dataset_config.test_name
30+
)
31+
streamingCase.dataset_config.gt_name = columns[2].text_input(
32+
"ground truth file name", key=f"{key}_gt_name", value=streamingCase.dataset_config.gt_name
33+
)
34+
35+
columns = st.columns([1, 1, 2, 2])
36+
streamingCase.dataset_config.train_id_name = columns[0].text_input(
37+
"train id name", key=f"{key}_train_id_name", value=streamingCase.dataset_config.train_id_name
38+
)
39+
streamingCase.dataset_config.train_col_name = columns[1].text_input(
40+
"train emb name", key=f"{key}_train_col_name", value=streamingCase.dataset_config.train_col_name
41+
)
42+
streamingCase.dataset_config.test_col_name = columns[2].text_input(
43+
"test emb name", key=f"{key}_test_col_name", value=streamingCase.dataset_config.test_col_name
44+
)
45+
streamingCase.dataset_config.gt_col_name = columns[3].text_input(
46+
"ground truth emb name", key=f"{key}_gt_col_name", value=streamingCase.dataset_config.gt_col_name
47+
)
48+
49+
streamingCase.description = st.text_area("description", key=f"{key}_description", value=streamingCase.description)

vectordb_bench/frontend/components/custom/getCustomConfig.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,49 @@ class CustomCaseConfig(BaseModel):
3434
dataset_config: CustomDatasetConfig = CustomDatasetConfig()
3535

3636

37+
class CustomStreamingCaseConfig(BaseModel):
38+
case_type: str = "streaming"
39+
description: str = ""
40+
dataset_config: CustomDatasetConfig = CustomDatasetConfig()
41+
42+
3743
def get_custom_configs():
3844
with open(config.CUSTOM_CONFIG_DIR, "r") as f:
3945
custom_configs = json.load(f)
40-
return [CustomCaseConfig(**custom_config) for custom_config in custom_configs]
46+
return [
47+
CustomCaseConfig(**custom_config)
48+
for custom_config in custom_configs
49+
if custom_config.get("case_type") != "streaming"
50+
]
51+
52+
53+
def get_custom_streaming_configs():
54+
with open(config.CUSTOM_CONFIG_DIR, "r") as f:
55+
custom_configs = json.load(f)
56+
return [
57+
CustomStreamingCaseConfig(**custom_config)
58+
for custom_config in custom_configs
59+
if custom_config.get("case_type") == "streaming"
60+
]
4161

4262

4363
def save_custom_configs(custom_configs: list[CustomDatasetConfig]):
4464
with open(config.CUSTOM_CONFIG_DIR, "w") as f:
4565
json.dump([custom_config.dict() for custom_config in custom_configs], f, indent=4)
4666

4767

68+
def save_all_custom_configs(
69+
performance_configs: list[CustomCaseConfig], streaming_configs: list[CustomStreamingCaseConfig]
70+
):
71+
"""Save both performance and streaming configs to the same JSON file"""
72+
all_configs = [config.dict() for config in performance_configs] + [config.dict() for config in streaming_configs]
73+
with open(config.CUSTOM_CONFIG_DIR, "w") as f:
74+
json.dump(all_configs, f, indent=4)
75+
76+
4877
def generate_custom_case():
4978
return CustomCaseConfig()
79+
80+
81+
def generate_custom_streaming_case():
82+
return CustomStreamingCaseConfig()

vectordb_bench/frontend/components/run_test/caseSelector.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
UICaseItemCluster,
88
get_case_config_inputs,
99
get_custom_case_cluter,
10+
get_custom_streaming_case_cluster,
1011
)
1112
from vectordb_bench.frontend.config.styles import (
1213
CASE_CONFIG_SETTING_COLUMNS,
@@ -32,7 +33,7 @@ def caseSelector(st, activedDbList: list[DB]):
3233
activedCaseList: list[CaseConfig] = []
3334
dbToCaseClusterConfigs = defaultdict(lambda: defaultdict(dict))
3435
dbToCaseConfigs = defaultdict(lambda: defaultdict(dict))
35-
caseClusters = UI_CASE_CLUSTERS + [get_custom_case_cluter()]
36+
caseClusters = UI_CASE_CLUSTERS + [get_custom_case_cluter(), get_custom_streaming_case_cluster()]
3637
for caseCluster in caseClusters:
3738
activedCaseList += caseClusterExpander(st, caseCluster, dbToCaseClusterConfigs, activedDbList)
3839
for db in dbToCaseClusterConfigs:

vectordb_bench/frontend/config/dbCaseConfigs.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,34 @@ def get_custom_case_cluter() -> UICaseItemCluster:
161161
return UICaseItemCluster(label="Custom Search Performance Test", uiCaseItems=get_custom_case_items())
162162

163163

164+
def get_custom_streaming_case_items() -> list[UICaseItem]:
165+
from vectordb_bench.frontend.components.custom.getCustomConfig import get_custom_streaming_configs
166+
167+
custom_streaming_configs = get_custom_streaming_configs()
168+
return [
169+
UICaseItem(
170+
label=f"{custom_config.dataset_config.name} - Streaming",
171+
description=f"Streaming test with custom dataset: {custom_config.dataset_config.name}",
172+
cases=[
173+
CaseConfig(
174+
case_id=CaseType.StreamingCustomDataset,
175+
custom_case={
176+
"description": custom_config.description,
177+
"dataset_config": custom_config.dataset_config.dict(),
178+
},
179+
)
180+
],
181+
caseLabel=CaseLabel.Streaming,
182+
extra_custom_case_config_inputs=custom_streaming_config_with_custom_dataset,
183+
)
184+
for custom_config in custom_streaming_configs
185+
]
186+
187+
188+
def get_custom_streaming_case_cluster() -> UICaseItemCluster:
189+
return UICaseItemCluster(label="Custom Streaming Test", uiCaseItems=get_custom_streaming_case_items())
190+
191+
164192
def generate_custom_streaming_case() -> CaseConfig:
165193
return CaseConfig(
166194
case_id=CaseType.StreamingPerformanceCase,
@@ -207,6 +235,12 @@ def generate_custom_streaming_case() -> CaseConfig:
207235
),
208236
]
209237

238+
# Config for custom streaming tests (with custom dataset from JSON)
239+
# Filter out the dataset_with_size_type from the existing config
240+
custom_streaming_config_with_custom_dataset: list[ConfigInput] = [
241+
config for config in custom_streaming_config if config.label != CaseConfigParamType.dataset_with_size_type
242+
]
243+
210244

211245
def generate_label_filter_cases(dataset_with_size_type: DatasetWithSizeType) -> list[CaseConfig]:
212246
label_percentages = dataset_with_size_type.get_manager().data.scalar_label_percentages

vectordb_bench/frontend/pages/custom.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,19 @@
55
from vectordb_bench.frontend.components.custom.displayCustomCase import (
66
displayCustomCase,
77
)
8+
from vectordb_bench.frontend.components.custom.displayCustomStreamingCase import (
9+
displayCustomStreamingCase,
10+
)
811
from vectordb_bench.frontend.components.custom.displaypPrams import displayParams
912
from vectordb_bench.frontend.components.custom.getCustomConfig import (
1013
CustomCaseConfig,
14+
CustomStreamingCaseConfig,
1115
generate_custom_case,
16+
generate_custom_streaming_case,
1217
get_custom_configs,
18+
get_custom_streaming_configs,
1319
save_custom_configs,
20+
save_all_custom_configs,
1421
)
1522
from vectordb_bench.frontend.components.custom.initStyle import initStyle
1623
from vectordb_bench.frontend.config.styles import FAVICON, PAGE_TITLE
@@ -33,7 +40,33 @@ def deleteCase(self, idx: int):
3340
self.save()
3441

3542
def save(self):
36-
save_custom_configs(self.customCaseItems)
43+
# Save performance configs along with existing streaming configs
44+
streaming_configs = get_custom_streaming_configs()
45+
save_all_custom_configs(self.customCaseItems, streaming_configs)
46+
47+
48+
class StreamingCaseManager:
49+
streamingCaseItems: list[CustomStreamingCaseConfig]
50+
51+
def __init__(self):
52+
self.streamingCaseItems = get_custom_streaming_configs()
53+
54+
def addCase(self):
55+
new_streaming_case = generate_custom_streaming_case()
56+
new_streaming_case.dataset_config.name = (
57+
f"{new_streaming_case.dataset_config.name} {len(self.streamingCaseItems)}"
58+
)
59+
self.streamingCaseItems += [new_streaming_case]
60+
self.save()
61+
62+
def deleteCase(self, idx: int):
63+
self.streamingCaseItems.pop(idx)
64+
self.save()
65+
66+
def save(self):
67+
# Save streaming configs along with existing performance configs
68+
performance_configs = get_custom_configs()
69+
save_all_custom_configs(performance_configs, self.streamingCaseItems)
3770

3871

3972
def main():
@@ -55,6 +88,11 @@ def main():
5588

5689
st.title("Custom Dataset")
5790
displayParams(st)
91+
92+
# Performance Test Datasets Section
93+
st.subheader("Performance Test Datasets")
94+
st.markdown("These datasets are used for search performance tests.")
95+
5896
customCaseManager = CustomCaseManager()
5997

6098
for idx, customCase in enumerate(customCaseManager.customCaseItems):
@@ -84,6 +122,40 @@ def main():
84122
on_click=lambda: customCaseManager.addCase(),
85123
)
86124

125+
st.divider()
126+
127+
# Streaming Test Datasets Section
128+
st.subheader("Streaming Test Datasets")
129+
st.markdown("These datasets are used for streaming performance tests (insertion + search).")
130+
131+
streamingCaseManager = StreamingCaseManager()
132+
133+
for idx, streamingCase in enumerate(streamingCaseManager.streamingCaseItems):
134+
expander = st.expander(streamingCase.dataset_config.name, expanded=True)
135+
key = f"streaming_case_{idx}"
136+
displayCustomStreamingCase(streamingCase, expander, key=key)
137+
138+
columns = expander.columns(8)
139+
columns[0].button(
140+
"Save",
141+
key=f"{key}_save",
142+
type="secondary",
143+
on_click=lambda: streamingCaseManager.save(),
144+
)
145+
columns[1].button(
146+
":red[Delete]",
147+
key=f"{key}_delete",
148+
type="secondary",
149+
on_click=partial(lambda idx: streamingCaseManager.deleteCase(idx), idx=idx),
150+
)
151+
152+
st.button(
153+
"+ New Streaming Dataset",
154+
key="add_streaming_config",
155+
type="primary",
156+
on_click=lambda: streamingCaseManager.addCase(),
157+
)
158+
87159

88160
if __name__ == "__main__":
89161
main()

0 commit comments

Comments
 (0)