Skip to content

Commit e315932

Browse files
committed
tests: check is query for query points batch
1 parent 0349bf8 commit e315932

File tree

1 file changed

+32
-15
lines changed

1 file changed

+32
-15
lines changed

tests/embed_tests/test_local_inference.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -771,40 +771,57 @@ def test_query_batch_points(prefer_grpc):
771771
local_client._client.query_batch_points, local_kwargs
772772
)
773773

774-
dense_doc_1 = models.Document(text="hello world", model=DENSE_MODEL_NAME)
775-
dense_doc_2 = models.Document(text="bye world", model=DENSE_MODEL_NAME)
776-
dense_doc_3 = models.Document(text="goodbye world", model=DENSE_MODEL_NAME)
777-
dense_doc_4 = models.Document(text="good afternoon world", model=DENSE_MODEL_NAME)
778-
dense_doc_5 = models.Document(text="good morning world", model=DENSE_MODEL_NAME)
774+
sparse_doc_1 = models.Document(text="hello world", model=SPARSE_MODEL_NAME)
775+
sparse_doc_2 = models.Document(text="bye world", model=SPARSE_MODEL_NAME)
776+
sparse_doc_3 = models.Document(text="goodbye world", model=SPARSE_MODEL_NAME)
777+
sparse_doc_4 = models.Document(text="good afternoon world", model=SPARSE_MODEL_NAME)
778+
sparse_doc_5 = models.Document(text="good morning world", model=SPARSE_MODEL_NAME)
779779

780780
points = [
781-
models.PointStruct(id=i, vector=dense_doc)
781+
models.PointStruct(id=i, vector={"sparse-text": dense_doc})
782782
for i, dense_doc in enumerate(
783-
[dense_doc_1, dense_doc_2, dense_doc_3, dense_doc_4, dense_doc_5]
783+
[sparse_doc_1, sparse_doc_2, sparse_doc_3, sparse_doc_4, sparse_doc_5]
784784
)
785785
]
786786

787-
populate_dense_collection(local_client, points)
788-
populate_dense_collection(remote_client, points)
787+
populate_sparse_collection(local_client, points, vector_name="sparse-text")
788+
populate_sparse_collection(remote_client, points, vector_name="sparse-text")
789789

790-
prefetch_1 = models.Prefetch(query=models.NearestQuery(nearest=dense_doc_2), limit=3)
791-
prefetch_2 = models.Prefetch(query=models.NearestQuery(nearest=dense_doc_3), limit=3)
790+
prefetch_1 = models.Prefetch(
791+
query=models.NearestQuery(nearest=sparse_doc_2), limit=3, using="sparse-text"
792+
)
793+
prefetch_2 = models.Prefetch(
794+
query=models.NearestQuery(nearest=sparse_doc_3), limit=3, using="sparse-text"
795+
)
792796

793797
query_requests = [
794-
models.QueryRequest(query=models.NearestQuery(nearest=dense_doc_1)),
798+
models.QueryRequest(query=models.NearestQuery(nearest=sparse_doc_1), using="sparse-text"),
795799
models.QueryRequest(
796-
query=models.NearestQuery(nearest=dense_doc_2), prefetch=[prefetch_1, prefetch_2]
800+
query=models.NearestQuery(nearest=sparse_doc_2),
801+
prefetch=[prefetch_1, prefetch_2],
802+
using="sparse-text",
797803
),
798804
]
799805

800806
local_client.query_batch_points(COLLECTION_NAME, query_requests)
801807
remote_client.query_batch_points(COLLECTION_NAME, query_requests)
802808
current_requests = local_kwargs["requests"]
803-
assert all([isinstance(request.query.nearest, list) for request in current_requests])
804809
assert all(
805-
[isinstance(prefetch.query.nearest, list) for prefetch in current_requests[1].prefetch]
810+
[isinstance(request.query.nearest, models.SparseVector) for request in current_requests]
811+
)
812+
assert all(
813+
[
814+
isinstance(prefetch.query.nearest, models.SparseVector)
815+
for prefetch in current_requests[1].prefetch
816+
]
806817
)
807818

819+
retrieved_point = local_client.retrieve(COLLECTION_NAME, ids=[0], with_vectors=True)[0]
820+
assert not np.allclose(
821+
retrieved_point.vector["sparse-text"].values,
822+
current_requests[0].query.nearest.values,
823+
atol=1e-3,
824+
)
808825
local_client.delete_collection(COLLECTION_NAME)
809826
remote_client.delete_collection(COLLECTION_NAME)
810827

0 commit comments

Comments
 (0)