Skip to content

Commit 311a5a7

Browse files
committed
deprecate: replace discover and context with query points in test_discovery
1 parent 6d6b81f commit 311a5a7

File tree

1 file changed

+117
-76
lines changed

1 file changed

+117
-76
lines changed

tests/congruence_tests/test_discovery.py

Lines changed: 117 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,14 @@ def test_context_cosine(
6262
grpc_client,
6363
):
6464
def f(client: QdrantBase, **kwargs: dict[str, Any]) -> list[models.ScoredPoint]:
65-
return client.discover(
65+
# test single context pair
66+
return client.query_points(
6667
collection_name=COLLECTION_NAME,
67-
context=[models.ContextExamplePair(positive=10, negative=19)],
68+
query=models.ContextQuery(context=models.ContextPair(positive=10, negative=19)),
6869
with_payload=True,
6970
limit=1000,
7071
using="text",
71-
)
72+
).points
7273

7374
compare_client_results(grpc_client, http_client, f, is_context_search=True)
7475
compare_client_results(local_client, http_client, f, is_context_search=True)
@@ -80,13 +81,14 @@ def test_context_dot(
8081
grpc_client,
8182
):
8283
def f(client: QdrantBase, **kwargs: dict[str, Any]) -> list[models.ScoredPoint]:
83-
return client.discover(
84+
# test list context pair
85+
return client.query_points(
8486
collection_name=COLLECTION_NAME,
85-
context=[models.ContextExamplePair(positive=10, negative=19)],
87+
query=models.ContextQuery(context=models.ContextPair(positive=10, negative=19)),
8688
with_payload=True,
8789
limit=1000,
8890
using="image",
89-
)
91+
).points
9092

9193
compare_client_results(grpc_client, http_client, f, is_context_search=True)
9294
compare_client_results(local_client, http_client, f, is_context_search=True)
@@ -98,13 +100,13 @@ def test_context_euclidean(
98100
grpc_client,
99101
):
100102
def f(client: QdrantBase, **kwargs: dict[str, Any]) -> list[models.ScoredPoint]:
101-
return client.discover(
103+
return client.query_points(
102104
collection_name=COLLECTION_NAME,
103-
context=[models.ContextExamplePair(positive=11, negative=19)],
105+
query=models.ContextQuery(context=models.ContextPair(positive=11, negative=19)),
104106
with_payload=True,
105107
limit=1000,
106108
using="code",
107-
)
109+
).points
108110

109111
compare_client_results(grpc_client, http_client, f, is_context_search=True)
110112
compare_client_results(local_client, http_client, f, is_context_search=True)
@@ -119,21 +121,23 @@ def test_context_many_pairs(
119121
random_image_vector_2 = random_vector(image_vector_size)
120122

121123
def f(client: QdrantBase, **kwargs: dict[str, Any]) -> list[models.ScoredPoint]:
122-
return client.discover(
124+
return client.query_points(
123125
collection_name=COLLECTION_NAME,
124-
context=[
125-
models.ContextExamplePair(positive=11, negative=19),
126-
models.ContextExamplePair(positive=400, negative=200),
127-
models.ContextExamplePair(
128-
positive=random_image_vector_1, negative=random_image_vector_2
129-
),
130-
models.ContextExamplePair(positive=30, negative=random_image_vector_2),
131-
models.ContextExamplePair(positive=random_image_vector_1, negative=15),
132-
],
126+
query=models.ContextQuery(
127+
context=[
128+
models.ContextPair(positive=11, negative=19),
129+
models.ContextPair(positive=400, negative=200),
130+
models.ContextPair(
131+
positive=random_image_vector_1, negative=random_image_vector_2
132+
),
133+
models.ContextPair(positive=30, negative=random_image_vector_2),
134+
models.ContextPair(positive=random_image_vector_1, negative=15),
135+
]
136+
),
133137
with_payload=True,
134138
limit=1000,
135139
using="image",
136-
)
140+
).points
137141

138142
compare_client_results(grpc_client, http_client, f, is_context_search=True)
139143
compare_client_results(local_client, http_client, f, is_context_search=True)
@@ -145,14 +149,19 @@ def test_discover_cosine(
145149
grpc_client,
146150
):
147151
def f(client: QdrantBase, **kwargs: dict[str, Any]) -> list[models.ScoredPoint]:
148-
return client.discover(
152+
# test single context pair
153+
return client.query_points(
149154
collection_name=COLLECTION_NAME,
150-
target=10,
151-
context=[models.ContextExamplePair(positive=11, negative=19)],
155+
query=models.DiscoverQuery(
156+
discover=models.DiscoverInput(
157+
target=10,
158+
context=models.ContextPair(positive=11, negative=19),
159+
)
160+
),
152161
with_payload=True,
153162
limit=10,
154163
using="text",
155-
)
164+
).points
156165

157166
compare_client_results(grpc_client, http_client, f)
158167
compare_client_results(local_client, http_client, f)
@@ -164,14 +173,18 @@ def test_discover_dot(
164173
grpc_client,
165174
):
166175
def f(client: QdrantBase, **kwargs: dict[str, Any]) -> list[models.ScoredPoint]:
167-
return client.discover(
176+
# test list context pair
177+
return client.query_points(
168178
collection_name=COLLECTION_NAME,
169-
target=10,
170-
context=[models.ContextExamplePair(positive=11, negative=19)],
179+
query=models.DiscoverQuery(
180+
discover=models.DiscoverInput(
181+
target=10, context=[models.ContextPair(positive=11, negative=19)]
182+
)
183+
),
171184
with_payload=True,
172185
limit=10,
173186
using="image",
174-
)
187+
).points
175188

176189
compare_client_results(grpc_client, http_client, f)
177190
compare_client_results(local_client, http_client, f)
@@ -183,14 +196,17 @@ def test_discover_euclidean(
183196
grpc_client,
184197
):
185198
def f(client: QdrantBase, **kwargs: dict[str, Any]) -> list[models.ScoredPoint]:
186-
return client.discover(
199+
return client.query_points(
187200
collection_name=COLLECTION_NAME,
188-
target=10,
189-
context=[models.ContextExamplePair(positive=11, negative=19)],
201+
query=models.DiscoverQuery(
202+
discover=models.DiscoverInput(
203+
target=10, context=[models.ContextPair(positive=11, negative=19)]
204+
)
205+
),
190206
with_payload=True,
191207
limit=10,
192208
using="code",
193-
)
209+
).points
194210

195211
compare_client_results(grpc_client, http_client, f)
196212
compare_client_results(local_client, http_client, f)
@@ -204,13 +220,17 @@ def test_discover_raw_target(
204220
random_image_vector = random_vector(image_vector_size)
205221

206222
def f(client: QdrantBase, **kwargs: dict[str, Any]) -> list[models.ScoredPoint]:
207-
return client.discover(
223+
return client.query_points(
208224
collection_name=COLLECTION_NAME,
209-
target=random_image_vector,
210-
context=[models.ContextExamplePair(positive=10, negative=19)],
225+
query=models.DiscoverQuery(
226+
discover=models.DiscoverInput(
227+
target=random_image_vector,
228+
context=[models.ContextPair(positive=10, negative=19)],
229+
)
230+
),
211231
limit=10,
212232
using="image",
213-
)
233+
).points
214234

215235
compare_client_results(grpc_client, http_client, f)
216236
compare_client_results(local_client, http_client, f)
@@ -224,13 +244,17 @@ def test_context_raw_positive(
224244
random_image_vector = random_vector(image_vector_size)
225245

226246
def f(client: QdrantBase, **kwargs: dict[str, Any]) -> list[models.ScoredPoint]:
227-
return client.discover(
247+
return client.query_points(
228248
collection_name=COLLECTION_NAME,
229-
target=10,
230-
context=[models.ContextExamplePair(positive=random_image_vector, negative=19)],
249+
query=models.DiscoverQuery(
250+
discover=models.DiscoverInput(
251+
target=10,
252+
context=[models.ContextPair(positive=random_image_vector, negative=19)],
253+
)
254+
),
231255
limit=10,
232256
using="image",
233-
)
257+
).points
234258

235259
compare_client_results(grpc_client, http_client, f)
236260
compare_client_results(local_client, http_client, f)
@@ -242,13 +266,13 @@ def test_only_target(
242266
grpc_client,
243267
):
244268
def f(client: QdrantBase, **kwargs: dict[str, Any]) -> list[models.ScoredPoint]:
245-
return client.discover(
269+
return client.query_points(
246270
collection_name=COLLECTION_NAME,
247-
target=10,
271+
query=models.DiscoverQuery(discover=models.DiscoverInput(target=10, context=[])),
248272
with_payload=True,
249273
limit=10,
250274
using="image",
251-
)
275+
).points
252276

253277
compare_client_results(grpc_client, http_client, f)
254278
compare_client_results(local_client, http_client, f)
@@ -261,20 +285,24 @@ def discover_from_another_collection(
261285
positive_point_id: Optional[int] = None,
262286
**kwargs: dict[str, Any],
263287
) -> list[models.ScoredPoint]:
264-
return client.discover(
288+
return client.query_points(
265289
collection_name=collection_name,
266-
target=5,
267-
context=[models.ContextExamplePair(positive=positive_point_id, negative=6)]
268-
if positive_point_id is not None
269-
else [],
290+
query=models.DiscoverQuery(
291+
discover=models.DiscoverInput(
292+
target=5,
293+
context=[models.ContextPair(positive=positive_point_id, negative=6)]
294+
if positive_point_id is not None
295+
else [],
296+
)
297+
),
270298
with_payload=True,
271299
limit=10,
272300
using="image",
273301
lookup_from=models.LookupLocation(
274302
collection=lookup_collection_name,
275303
vector="image",
276304
),
277-
)
305+
).points
278306

279307

280308
def test_discover_from_another_collection(
@@ -317,19 +345,26 @@ def test_discover_batch(
317345
http_client,
318346
grpc_client,
319347
):
320-
def f(client: QdrantBase, **kwargs: dict[str, Any]) -> list[list[models.ScoredPoint]]:
321-
return client.discover_batch(
348+
def f(client: QdrantBase, **kwargs: dict[str, Any]) -> list[models.QueryResponse]:
349+
return client.query_batch_points(
322350
collection_name=COLLECTION_NAME,
323351
requests=[
324-
models.DiscoverRequest(
325-
target=10,
326-
context=[models.ContextExamplePair(positive=15, negative=7)],
352+
models.QueryRequest(
353+
query=models.DiscoverQuery(
354+
discover=models.DiscoverInput(
355+
target=10,
356+
context=[models.ContextPair(positive=15, negative=7)],
357+
)
358+
),
327359
limit=5,
328360
using="image",
329361
),
330-
models.DiscoverRequest(
331-
target=11,
332-
context=[models.ContextExamplePair(positive=15, negative=17)],
362+
models.QueryRequest(
363+
query=models.DiscoverQuery(
364+
discover=models.DiscoverInput(
365+
target=11, context=[models.ContextPair(positive=15, negative=17)]
366+
)
367+
),
333368
limit=6,
334369
using="image",
335370
lookup_from=models.LookupLocation(
@@ -347,26 +382,32 @@ def f(client: QdrantBase, **kwargs: dict[str, Any]) -> list[list[models.ScoredPo
347382
@pytest.mark.parametrize("filter", [one_random_filter_please() for _ in range(10)])
348383
def test_discover_with_filters(local_client, http_client, grpc_client, filter: models.Filter):
349384
def f(client: QdrantBase, **kwargs: dict[str, Any]) -> list[models.ScoredPoint]:
350-
return client.discover(
385+
return client.query_points(
351386
collection_name=COLLECTION_NAME,
352-
target=10,
353-
context=[models.ContextExamplePair(positive=15, negative=7)],
387+
query=models.DiscoverQuery(
388+
discover=models.DiscoverInput(
389+
target=10, context=[models.ContextPair(positive=15, negative=7)]
390+
)
391+
),
354392
limit=15,
355393
using="image",
356394
query_filter=filter,
357-
)
395+
).points
396+
397+
compare_client_results(grpc_client, http_client, f)
398+
compare_client_results(local_client, http_client, f)
358399

359400

360401
@pytest.mark.parametrize("filter", [one_random_filter_please() for _ in range(10)])
361402
def test_context_with_filters(local_client, http_client, grpc_client, filter: models.Filter):
362403
def f(client: QdrantBase, **kwargs: dict[str, Any]) -> list[models.ScoredPoint]:
363-
return client.discover(
404+
return client.query_points(
364405
collection_name=COLLECTION_NAME,
365-
context=[models.ContextExamplePair(positive=15, negative=7)],
406+
query=models.ContextQuery(context=models.ContextPair(positive=15, negative=7)),
366407
limit=1000,
367408
using="image",
368409
query_filter=filter,
369-
)
410+
).points
370411

371412
compare_client_results(grpc_client, http_client, f, is_context_search=True)
372413
compare_client_results(local_client, http_client, f, is_context_search=True)
@@ -386,38 +427,38 @@ def test_query_with_nan():
386427
init_client(remote_client, fixture_points)
387428

388429
with pytest.raises(AssertionError):
389-
local_client.discover(
430+
local_client.query_points(
390431
collection_name=COLLECTION_NAME,
391-
target=vector,
432+
query=models.DiscoverQuery(discover=models.DiscoverInput(target=vector, context=[])),
392433
using=using,
393434
)
394435
with pytest.raises(UnexpectedResponse):
395-
remote_client.discover(
436+
remote_client.query_points(
396437
collection_name=COLLECTION_NAME,
397-
target=vector,
438+
query=models.DiscoverQuery(discover=models.DiscoverInput(target=vector, context=[])),
398439
using=using,
399440
)
400441
with pytest.raises(AssertionError):
401-
local_client.discover(
442+
local_client.query_points(
402443
collection_name=COLLECTION_NAME,
403-
context=[models.ContextExamplePair(positive=vector, negative=1)],
444+
query=models.ContextQuery(context=models.ContextPair(positive=vector, negative=1)),
404445
using=using,
405446
)
406447
with pytest.raises(UnexpectedResponse):
407-
remote_client.discover(
448+
remote_client.query_points(
408449
collection_name=COLLECTION_NAME,
409-
context=[models.ContextExamplePair(positive=vector, negative=1)],
450+
query=models.ContextQuery(context=models.ContextPair(positive=vector, negative=1)),
410451
using=using,
411452
)
412453
with pytest.raises(AssertionError):
413-
local_client.discover(
454+
local_client.query_points(
414455
collection_name=COLLECTION_NAME,
415-
context=[models.ContextExamplePair(positive=1, negative=vector)],
456+
query=models.ContextQuery(context=models.ContextPair(positive=1, negative=vector)),
416457
using=using,
417458
)
418459
with pytest.raises(UnexpectedResponse):
419-
remote_client.discover(
460+
remote_client.query_points(
420461
collection_name=COLLECTION_NAME,
421-
context=[models.ContextExamplePair(positive=1, negative=vector)],
462+
query=models.ContextQuery(context=models.ContextPair(positive=1, negative=vector)),
422463
using=using,
423464
)

0 commit comments

Comments
 (0)