@@ -11,7 +11,8 @@ class TestVectorIndex(RollingUpgradeAndDowngradeFixture):
11
11
def setup (self ):
12
12
if min (self .versions ) < (25 , 1 ):
13
13
pytest .skip ("Only available since 25-1" )
14
- self .rows_count = 5
14
+ self .rows_count = 9
15
+ self .rows_per_user = 3
15
16
self .index_name = "vector_idx"
16
17
self .vector_dimension = 3
17
18
self .vector_types = {
@@ -40,31 +41,21 @@ def get_vector(self, type, numb):
40
41
values .append (numb )
41
42
return "," .join (str (val ) for val in values )
42
43
43
- def _create_index (self , vector_type , table_name , distance = None , similarity = None ):
44
- if distance is not None :
45
- create_index_sql = f"""
46
- ALTER TABLE { table_name }
47
- ADD INDEX `{ self .index_name } ` GLOBAL USING vector_kmeans_tree
48
- ON (vec)
49
- WITH (distance={ distance } ,
50
- vector_type={ vector_type } ,
51
- vector_dimension={ self .vector_dimension } ,
52
- levels=2,
53
- clusters=10
54
- );
55
- """
56
- else :
57
- create_index_sql = f"""
58
- ALTER TABLE { table_name }
59
- ADD INDEX `{ self .index_name } ` GLOBAL USING vector_kmeans_tree
60
- ON (vec)
61
- WITH (similarity={ similarity } ,
62
- vector_type={ vector_type } ,
63
- vector_dimension={ self .vector_dimension } ,
64
- levels=2,
65
- clusters=10
66
- );
67
- """
44
+ def _create_index (self , vector_type , table_name , target , prefixed ):
45
+ prefix = ""
46
+ if prefixed :
47
+ prefix = "user, "
48
+ create_index_sql = f"""
49
+ ALTER TABLE { table_name }
50
+ ADD INDEX `{ self .index_name } ` GLOBAL USING vector_kmeans_tree
51
+ ON ({ prefix } vec)
52
+ WITH ({ target } ,
53
+ vector_type={ vector_type } ,
54
+ vector_dimension={ self .vector_dimension } ,
55
+ levels=2,
56
+ clusters=10
57
+ );
58
+ """
68
59
with ydb .QuerySessionPool (self .driver ) as session_pool :
69
60
session_pool .execute_with_retries (create_index_sql )
70
61
@@ -80,74 +71,99 @@ def predicate():
80
71
81
72
assert wait_for (predicate , timeout_seconds = 100 , step_seconds = 1 ), "Error getting index status"
82
73
83
- def write_data (self , name , vector_type , table_name ):
74
+ def _write_data (self , name , vector_type , table_name ):
84
75
values = []
85
76
for key in range (self .rows_count ):
86
77
vector = self .get_vector (vector_type , key + 1 )
87
- values .append (f'({ key } , Untag({ name } ([{ vector } ]), "{ vector_type } "))' )
78
+ user = 1 + (key % self .rows_per_user )
79
+ values .append (f'({ key } , { user } , Untag({ name } ([{ vector } ]), "{ vector_type } Vector"))' )
88
80
89
81
sql_upsert = f"""
90
- UPSERT INTO `{ table_name } ` (key, vec)
82
+ UPSERT INTO `{ table_name } ` (key, user, vec)
91
83
VALUES { "," .join (values )} ;
92
84
"""
93
85
with ydb .QuerySessionPool (self .driver ) as session_pool :
94
86
session_pool .execute_with_retries (sql_upsert )
95
87
96
- def select_from_index (self ):
97
- querys = []
98
- for vector_type in self .vector_types .keys ():
99
- for distance in self .targets .keys ():
100
- for distance_func in self .targets [distance ].keys ():
101
- order = "ASC" if distance != "similarity" else "DESC"
102
- vector = self .get_vector (f"{ vector_type } Vector" , 1 )
103
- querys .append (
104
- f"""
105
- $Target = { self .vector_types [vector_type ]} (Cast([{ vector } ] AS List<{ vector_type } >));
106
- SELECT key, vec, { self .targets [distance ][distance_func ]} (vec, $Target) as target
107
- FROM { vector_type } _{ distance } _{ distance_func }
108
- VIEW `{ self .index_name } `
109
- ORDER BY { self .targets [distance ][distance_func ]} (vec, $Target) { order }
110
- LIMIT { self .rows_count } ;
111
- """
112
- )
113
- for _ in self .roll ():
114
- with ydb .QuerySessionPool (self .driver ) as session_pool :
115
- for qyery in querys :
116
- result_sets = session_pool .execute_with_retries (qyery )
88
+ def _get_queries (self ):
89
+ queries = []
90
+ for prefixed in ['' , '_pfx' ]:
91
+ for vector_type in self .vector_types .keys ():
92
+ for distance in self .targets .keys ():
93
+ for distance_func in self .targets [distance ].keys ():
94
+ table_name = f"{ vector_type } _{ distance } _{ distance_func } { prefixed } "
95
+ order = "ASC" if distance != "similarity" else "DESC"
96
+ vector = self .get_vector (f"{ vector_type } Vector" , 1 )
97
+ where = ""
98
+ if prefixed :
99
+ where = "WHERE user=1"
100
+ queries .append ([
101
+ True , f"""
102
+ $Target = { self .vector_types [vector_type ]} (Cast([{ vector } ] AS List<{ vector_type } >));
103
+ SELECT key, vec, { self .targets [distance ][distance_func ]} (vec, $Target) as target
104
+ FROM `{ table_name } `
105
+ VIEW `{ self .index_name } `
106
+ { where }
107
+ ORDER BY { self .targets [distance ][distance_func ]} (vec, $Target) { order }
108
+ LIMIT { self .rows_count } ;"""
109
+ ])
110
+ # Insert, update, upsert, delete
111
+ key = self .rows_count + 1
112
+ vector = self .get_vector (vector_type , key + 1 )
113
+ queries .append ([
114
+ False , f"""
115
+ INSERT INTO `{ table_name } ` (key, user, vec)
116
+ VALUES ({ key } , { 1 + (key ) % self .rows_per_user } ,
117
+ Untag({ self .vector_types [vector_type ]} ([{ vector } ]), "{ vector_type } Vector"))
118
+ """
119
+ ])
120
+ vector = self .get_vector (vector_type , key + 2 )
121
+ queries .append ([
122
+ False , f"""
123
+ UPDATE `{ table_name } ` SET user=user+1,
124
+ vec=Untag({ self .vector_types [vector_type ]} ([{ vector } ]), "{ vector_type } Vector")
125
+ WHERE key={ key }
126
+ """
127
+ ])
128
+ vector = self .get_vector (vector_type , key + 3 )
129
+ queries .append ([
130
+ False , f"""
131
+ UPSERT INTO `{ table_name } ` (key, user, vec)
132
+ VALUES ({ key } , { 1 + (key ) % self .rows_per_user } ,
133
+ Untag({ self .vector_types [vector_type ]} ([{ vector } ]), "{ vector_type } Vector"))
134
+ """
135
+ ])
136
+ queries .append ([
137
+ False , f"""
138
+ DELETE FROM `{ table_name } ` WHERE key={ key }
139
+ """
140
+ ])
141
+ return queries
142
+
143
+ def _do_queries (self , queries ):
144
+ with ydb .QuerySessionPool (self .driver ) as session_pool :
145
+ for [is_select , query ] in queries :
146
+ result_sets = session_pool .execute_with_retries (query )
147
+ if is_select :
117
148
assert len (result_sets [0 ].rows ) > 0 , "Query returned an empty set"
118
149
rows = result_sets [0 ].rows
119
150
for row in rows :
120
151
assert row ['target' ] is not None , "the distance is None"
121
152
153
+ def select_from_index (self ):
154
+ queries = self ._get_queries ()
155
+ for _ in self .roll ():
156
+ self ._do_queries (queries )
157
+
122
158
def select_from_index_without_roll (self ):
123
- querys = []
124
- for vector_type in self .vector_types .keys ():
125
- for distance in self .targets .keys ():
126
- for distance_func in self .targets [distance ].keys ():
127
- order = "ASC" if distance != "similarity" else "DESC"
128
- vector = self .get_vector (f"{ vector_type } Vector" , 1 )
129
- querys .append (
130
- f"""
131
- $Target = { self .vector_types [vector_type ]} (Cast([{ vector } ] AS List<{ vector_type } >));
132
- SELECT key, vec, { self .targets [distance ][distance_func ]} (vec, $Target) as target
133
- FROM { vector_type } _{ distance } _{ distance_func }
134
- VIEW `{ self .index_name } `
135
- ORDER BY { self .targets [distance ][distance_func ]} (vec, $Target) { order }
136
- LIMIT { self .rows_count } ;
137
- """
138
- )
139
- with ydb .QuerySessionPool (self .driver ) as session_pool :
140
- for query in querys :
141
- result_sets = session_pool .execute_with_retries (query )
142
- assert len (result_sets [0 ].rows ) > 0 , "Query returned an empty set"
143
- rows = result_sets [0 ].rows
144
- for row in rows :
145
- assert row ['target' ] is not None , "the distance is None"
159
+ queries = self ._get_queries ()
160
+ self ._do_queries (queries )
146
161
147
162
def create_table (self , table_name ):
148
163
query = f"""
149
164
CREATE TABLE { table_name } (
150
165
key Int64 NOT NULL,
166
+ user Uint64 NOT NULL,
151
167
vec String NOT NULL,
152
168
PRIMARY KEY (key)
153
169
)
@@ -156,26 +172,22 @@ def create_table(self, table_name):
156
172
session_pool .execute_with_retries (query )
157
173
158
174
def test_vector_index (self ):
159
- for vector_type in self .vector_types .keys ():
160
- for distance in self .targets .keys ():
161
- for distance_func in self .targets [distance ].keys ():
162
- self .create_table (table_name = f"{ vector_type } _{ distance } _{ distance_func } " )
163
- self .write_data (
164
- name = self .vector_types [vector_type ],
165
- vector_type = f"{ vector_type } Vector" ,
166
- table_name = f"{ vector_type } _{ distance } _{ distance_func } " ,
167
- )
168
- if distance == "similarity" :
169
- self ._create_index (
170
- table_name = f"{ vector_type } _{ distance } _{ distance_func } " ,
175
+ for prefixed in ['' , '_pfx' ]:
176
+ for vector_type in self .vector_types .keys ():
177
+ for distance in self .targets .keys ():
178
+ for distance_func in self .targets [distance ].keys ():
179
+ table_name = f"{ vector_type } _{ distance } _{ distance_func } { prefixed } "
180
+ self .create_table (table_name )
181
+ self ._write_data (
182
+ name = self .vector_types [vector_type ],
171
183
vector_type = vector_type ,
172
- similarity = distance_func ,
184
+ table_name = table_name ,
173
185
)
174
- else :
175
186
self ._create_index (
176
- table_name = f" { vector_type } _ { distance } _ { distance_func } " ,
187
+ table_name = table_name ,
177
188
vector_type = vector_type ,
178
- distance = distance_func ,
189
+ target = f"{ distance } ={ distance_func } " ,
190
+ prefixed = prefixed ,
179
191
)
180
192
self .wait_index_ready ()
181
193
self .select_from_index ()
0 commit comments