Skip to content

Commit 17ae681

Browse files
authored
Add prefixed vector index and vector index update to compatibility tests (#26481)
1 parent e7e56f0 commit 17ae681

File tree

1 file changed

+101
-89
lines changed

1 file changed

+101
-89
lines changed

ydb/tests/compatibility/test_vector_index.py

Lines changed: 101 additions & 89 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,8 @@ class TestVectorIndex(RollingUpgradeAndDowngradeFixture):
1111
def setup(self):
1212
if min(self.versions) < (25, 1):
1313
pytest.skip("Only available since 25-1")
14-
self.rows_count = 5
14+
self.rows_count = 9
15+
self.rows_per_user = 3
1516
self.index_name = "vector_idx"
1617
self.vector_dimension = 3
1718
self.vector_types = {
@@ -40,31 +41,21 @@ def get_vector(self, type, numb):
4041
values.append(numb)
4142
return ",".join(str(val) for val in values)
4243

43-
def _create_index(self, vector_type, table_name, distance=None, similarity=None):
44-
if distance is not None:
45-
create_index_sql = f"""
46-
ALTER TABLE {table_name}
47-
ADD INDEX `{self.index_name}` GLOBAL USING vector_kmeans_tree
48-
ON (vec)
49-
WITH (distance={distance},
50-
vector_type={vector_type},
51-
vector_dimension={self.vector_dimension},
52-
levels=2,
53-
clusters=10
54-
);
55-
"""
56-
else:
57-
create_index_sql = f"""
58-
ALTER TABLE {table_name}
59-
ADD INDEX `{self.index_name}` GLOBAL USING vector_kmeans_tree
60-
ON (vec)
61-
WITH (similarity={similarity},
62-
vector_type={vector_type},
63-
vector_dimension={self.vector_dimension},
64-
levels=2,
65-
clusters=10
66-
);
67-
"""
44+
def _create_index(self, vector_type, table_name, target, prefixed):
45+
prefix = ""
46+
if prefixed:
47+
prefix = "user, "
48+
create_index_sql = f"""
49+
ALTER TABLE {table_name}
50+
ADD INDEX `{self.index_name}` GLOBAL USING vector_kmeans_tree
51+
ON ({prefix}vec)
52+
WITH ({target},
53+
vector_type={vector_type},
54+
vector_dimension={self.vector_dimension},
55+
levels=2,
56+
clusters=10
57+
);
58+
"""
6859
with ydb.QuerySessionPool(self.driver) as session_pool:
6960
session_pool.execute_with_retries(create_index_sql)
7061

@@ -80,74 +71,99 @@ def predicate():
8071

8172
assert wait_for(predicate, timeout_seconds=100, step_seconds=1), "Error getting index status"
8273

83-
def write_data(self, name, vector_type, table_name):
74+
def _write_data(self, name, vector_type, table_name):
8475
values = []
8576
for key in range(self.rows_count):
8677
vector = self.get_vector(vector_type, key + 1)
87-
values.append(f'({key}, Untag({name}([{vector}]), "{vector_type}"))')
78+
user = 1 + (key % self.rows_per_user)
79+
values.append(f'({key}, {user}, Untag({name}([{vector}]), "{vector_type}Vector"))')
8880

8981
sql_upsert = f"""
90-
UPSERT INTO `{table_name}` (key, vec)
82+
UPSERT INTO `{table_name}` (key, user, vec)
9183
VALUES {",".join(values)};
9284
"""
9385
with ydb.QuerySessionPool(self.driver) as session_pool:
9486
session_pool.execute_with_retries(sql_upsert)
9587

96-
def select_from_index(self):
97-
querys = []
98-
for vector_type in self.vector_types.keys():
99-
for distance in self.targets.keys():
100-
for distance_func in self.targets[distance].keys():
101-
order = "ASC" if distance != "similarity" else "DESC"
102-
vector = self.get_vector(f"{vector_type}Vector", 1)
103-
querys.append(
104-
f"""
105-
$Target = {self.vector_types[vector_type]}(Cast([{vector}] AS List<{vector_type}>));
106-
SELECT key, vec, {self.targets[distance][distance_func]}(vec, $Target) as target
107-
FROM {vector_type}_{distance}_{distance_func}
108-
VIEW `{self.index_name}`
109-
ORDER BY {self.targets[distance][distance_func]}(vec, $Target) {order}
110-
LIMIT {self.rows_count};
111-
"""
112-
)
113-
for _ in self.roll():
114-
with ydb.QuerySessionPool(self.driver) as session_pool:
115-
for qyery in querys:
116-
result_sets = session_pool.execute_with_retries(qyery)
88+
def _get_queries(self):
89+
queries = []
90+
for prefixed in ['', '_pfx']:
91+
for vector_type in self.vector_types.keys():
92+
for distance in self.targets.keys():
93+
for distance_func in self.targets[distance].keys():
94+
table_name = f"{vector_type}_{distance}_{distance_func}{prefixed}"
95+
order = "ASC" if distance != "similarity" else "DESC"
96+
vector = self.get_vector(f"{vector_type}Vector", 1)
97+
where = ""
98+
if prefixed:
99+
where = "WHERE user=1"
100+
queries.append([
101+
True, f"""
102+
$Target = {self.vector_types[vector_type]}(Cast([{vector}] AS List<{vector_type}>));
103+
SELECT key, vec, {self.targets[distance][distance_func]}(vec, $Target) as target
104+
FROM `{table_name}`
105+
VIEW `{self.index_name}`
106+
{where}
107+
ORDER BY {self.targets[distance][distance_func]}(vec, $Target) {order}
108+
LIMIT {self.rows_count};"""
109+
])
110+
# Insert, update, upsert, delete
111+
key = self.rows_count+1
112+
vector = self.get_vector(vector_type, key+1)
113+
queries.append([
114+
False, f"""
115+
INSERT INTO `{table_name}` (key, user, vec)
116+
VALUES ({key}, {1 + (key) % self.rows_per_user},
117+
Untag({self.vector_types[vector_type]}([{vector}]), "{vector_type}Vector"))
118+
"""
119+
])
120+
vector = self.get_vector(vector_type, key+2)
121+
queries.append([
122+
False, f"""
123+
UPDATE `{table_name}` SET user=user+1,
124+
vec=Untag({self.vector_types[vector_type]}([{vector}]), "{vector_type}Vector")
125+
WHERE key={key}
126+
"""
127+
])
128+
vector = self.get_vector(vector_type, key+3)
129+
queries.append([
130+
False, f"""
131+
UPSERT INTO `{table_name}` (key, user, vec)
132+
VALUES ({key}, {1 + (key) % self.rows_per_user},
133+
Untag({self.vector_types[vector_type]}([{vector}]), "{vector_type}Vector"))
134+
"""
135+
])
136+
queries.append([
137+
False, f"""
138+
DELETE FROM `{table_name}` WHERE key={key}
139+
"""
140+
])
141+
return queries
142+
143+
def _do_queries(self, queries):
144+
with ydb.QuerySessionPool(self.driver) as session_pool:
145+
for [is_select, query] in queries:
146+
result_sets = session_pool.execute_with_retries(query)
147+
if is_select:
117148
assert len(result_sets[0].rows) > 0, "Query returned an empty set"
118149
rows = result_sets[0].rows
119150
for row in rows:
120151
assert row['target'] is not None, "the distance is None"
121152

153+
def select_from_index(self):
154+
queries = self._get_queries()
155+
for _ in self.roll():
156+
self._do_queries(queries)
157+
122158
def select_from_index_without_roll(self):
123-
querys = []
124-
for vector_type in self.vector_types.keys():
125-
for distance in self.targets.keys():
126-
for distance_func in self.targets[distance].keys():
127-
order = "ASC" if distance != "similarity" else "DESC"
128-
vector = self.get_vector(f"{vector_type}Vector", 1)
129-
querys.append(
130-
f"""
131-
$Target = {self.vector_types[vector_type]}(Cast([{vector}] AS List<{vector_type}>));
132-
SELECT key, vec, {self.targets[distance][distance_func]}(vec, $Target) as target
133-
FROM {vector_type}_{distance}_{distance_func}
134-
VIEW `{self.index_name}`
135-
ORDER BY {self.targets[distance][distance_func]}(vec, $Target) {order}
136-
LIMIT {self.rows_count};
137-
"""
138-
)
139-
with ydb.QuerySessionPool(self.driver) as session_pool:
140-
for query in querys:
141-
result_sets = session_pool.execute_with_retries(query)
142-
assert len(result_sets[0].rows) > 0, "Query returned an empty set"
143-
rows = result_sets[0].rows
144-
for row in rows:
145-
assert row['target'] is not None, "the distance is None"
159+
queries = self._get_queries()
160+
self._do_queries(queries)
146161

147162
def create_table(self, table_name):
148163
query = f"""
149164
CREATE TABLE {table_name} (
150165
key Int64 NOT NULL,
166+
user Uint64 NOT NULL,
151167
vec String NOT NULL,
152168
PRIMARY KEY (key)
153169
)
@@ -156,26 +172,22 @@ def create_table(self, table_name):
156172
session_pool.execute_with_retries(query)
157173

158174
def test_vector_index(self):
159-
for vector_type in self.vector_types.keys():
160-
for distance in self.targets.keys():
161-
for distance_func in self.targets[distance].keys():
162-
self.create_table(table_name=f"{vector_type}_{distance}_{distance_func}")
163-
self.write_data(
164-
name=self.vector_types[vector_type],
165-
vector_type=f"{vector_type}Vector",
166-
table_name=f"{vector_type}_{distance}_{distance_func}",
167-
)
168-
if distance == "similarity":
169-
self._create_index(
170-
table_name=f"{vector_type}_{distance}_{distance_func}",
175+
for prefixed in ['', '_pfx']:
176+
for vector_type in self.vector_types.keys():
177+
for distance in self.targets.keys():
178+
for distance_func in self.targets[distance].keys():
179+
table_name = f"{vector_type}_{distance}_{distance_func}{prefixed}"
180+
self.create_table(table_name)
181+
self._write_data(
182+
name=self.vector_types[vector_type],
171183
vector_type=vector_type,
172-
similarity=distance_func,
184+
table_name=table_name,
173185
)
174-
else:
175186
self._create_index(
176-
table_name=f"{vector_type}_{distance}_{distance_func}",
187+
table_name=table_name,
177188
vector_type=vector_type,
178-
distance=distance_func,
189+
target=f"{distance}={distance_func}",
190+
prefixed=prefixed,
179191
)
180192
self.wait_index_ready()
181193
self.select_from_index()

0 commit comments

Comments
 (0)