Skip to content

Commit e974108

Browse files
authored
Merge pull request #1241 from tibor-reiss/feat_1168_fetch_objects_by_ids
feature: fetch_objects_by_ids
2 parents a89feaa + dcfb8b2 commit e974108

File tree

9 files changed

+782
-1
lines changed

9 files changed

+782
-1
lines changed

integration/test_collection_async.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
11
import datetime
22
import uuid
3+
from typing import Iterable
34

45
import pytest
56

67
import weaviate.classes as wvc
78
from weaviate.collections.classes.config import DataType, Property
9+
from weaviate.collections.classes.data import DataObject
10+
from weaviate.types import UUID
811

912
from .conftest import AsyncCollectionFactory, AsyncOpenAICollectionFactory
1013

1114
UUID1 = uuid.UUID("806827e0-2b31-43ca-9269-24fa95a221f9")
15+
UUID2 = uuid.uuid4()
16+
UUID3 = uuid.uuid4()
1217

1318
DATE1 = datetime.datetime.strptime("2012-02-09", "%Y-%m-%d").replace(tzinfo=datetime.timezone.utc)
1419

@@ -32,6 +37,51 @@ async def test_fetch_objects(async_collection_factory: AsyncCollectionFactory) -
3237
assert res.objects[0].properties["name"] == "John Doe"
3338

3439

40+
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "ids, expected_len, expected",
    [
        ([], 0, set()),
        ((), 0, set()),
        ([UUID3], 1, {UUID3}),
        ([UUID1, UUID2], 2, {UUID1, UUID2}),
        ((UUID1, UUID3), 2, {UUID1, UUID3}),
        ((UUID1, UUID3, UUID3), 2, {UUID1, UUID3}),
    ],
)
async def test_fetch_objects_by_ids(
    async_collection_factory: AsyncCollectionFactory,
    ids: Iterable[UUID],
    expected_len: int,
    expected: set,
) -> None:
    """Check that the async `query.fetch_objects_by_ids` returns exactly the requested objects.

    Three objects (UUID1, UUID2, UUID3) are inserted, then `ids` is looked up. The cases cover
    empty list/tuple inputs, a single id, several ids, and a repeated id, which is expected to
    yield the object only once.
    """
    collection = await async_collection_factory(
        properties=[Property(name="name", data_type=DataType.TEXT)],
        vectorizer_config=wvc.config.Configure.Vectorizer.none(),
    )
    seed_objects = [
        DataObject(properties={"name": label}, uuid=object_id)
        for label, object_id in (("first", UUID1), ("second", UUID2), ("third", UUID3))
    ]
    await collection.data.insert_many(seed_objects)

    response = await collection.query.fetch_objects_by_ids(ids)
    fetched = response.objects

    assert len(fetched) == expected_len
    assert {obj.uuid for obj in fetched} == expected
83+
84+
3585
@pytest.mark.asyncio
3686
async def test_config_update(async_collection_factory: AsyncCollectionFactory) -> None:
3787
collection = await async_collection_factory(
@@ -200,3 +250,52 @@ async def test_generate(async_openai_collection: AsyncOpenAICollectionFactory) -
200250
assert len(res.objects) == 2
201251
for obj in res.objects:
202252
assert obj.generated is not None
253+
254+
255+
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "ids, expected_len, expected",
    [
        ([], 0, set()),
        ((), 0, set()),
        ([UUID3], 1, {UUID3}),
        ([UUID1, UUID2], 2, {UUID1, UUID2}),
        ((UUID1, UUID3), 2, {UUID1, UUID3}),
        ((UUID1, UUID3, UUID3), 2, {UUID1, UUID3}),
    ],
)
async def test_generate_by_ids(
    async_openai_collection: AsyncOpenAICollectionFactory,
    ids: Iterable[UUID],
    expected_len: int,
    expected: set,
) -> None:
    """Check that the async `generate.fetch_objects_by_ids` returns the requested objects with generations.

    Three objects (UUID1, UUID2, UUID3) are inserted, then `ids` is looked up with both a
    single prompt and a grouped task. Every returned object, and the response as a whole,
    must carry a generated value.
    """
    collection = await async_openai_collection(
        vectorizer_config=wvc.config.Configure.Vectorizer.none(),
    )
    seed_objects = [
        DataObject(properties={"text": text}, uuid=object_id)
        for text, object_id in (("John Doe", UUID1), ("Jane Doe", UUID2), ("J. Doe", UUID3))
    ]
    await collection.data.insert_many(seed_objects)

    response = await collection.generate.fetch_objects_by_ids(
        ids,
        single_prompt="Who is this? {text}",
        grouped_task="Who are these people?",
    )

    assert response is not None
    # NOTE(review): this assertion also runs for the empty-id cases, where the client
    # returns an empty reply without querying -- confirm `generated` is non-None there.
    assert response.generated is not None

    fetched = response.objects
    assert len(fetched) == expected_len
    assert {obj.uuid for obj in fetched} == expected
    for obj in fetched:
        assert obj.generated is not None

integration/test_collection_filter.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import datetime
22
import time
33
import uuid
4-
from typing import Callable, List, Optional
4+
from typing import Callable, Iterable, List, Optional
55

66
import pytest as pytest
77

@@ -21,6 +21,7 @@
2121
)
2222
from weaviate.collections.classes.grpc import MetadataQuery, QueryReference, Sort
2323
from weaviate.collections.classes.internal import ReferenceToMulti
24+
from weaviate.types import UUID
2425

2526
NOW = datetime.datetime.now(datetime.timezone.utc)
2627
LATER = NOW + datetime.timedelta(hours=1)
@@ -548,6 +549,52 @@ def test_filter_id(collection_factory: CollectionFactory, weav_filter: _FilterVa
548549
assert objects[0].uuid == UUID1
549550

550551

552+
@pytest.mark.parametrize(
    "ids, expected_len, expected",
    [
        ([], 0, set()),
        ((), 0, set()),
        ([UUID3], 1, {UUID3}),
        ([UUID1, UUID2], 2, {UUID1, UUID2}),
        ((UUID1, UUID3), 2, {UUID1, UUID3}),
        ((UUID1, UUID3, UUID3), 2, {UUID1, UUID3}),
    ],
)
def test_filter_ids(
    collection_factory: CollectionFactory,
    ids: Iterable[UUID],
    expected_len: int,
    expected: set,
) -> None:
    """Check that the sync `query.fetch_objects_by_ids` returns exactly the requested objects.

    Three objects (UUID1, UUID2, UUID3) are inserted, then `ids` is looked up. The cases cover
    empty list/tuple inputs, a single id, several ids, and a repeated id, which is expected to
    yield the object only once.
    """
    collection = collection_factory(
        properties=[Property(name="Name", data_type=DataType.TEXT)],
        vectorizer_config=Configure.Vectorizer.none(),
    )
    seed_objects = [
        DataObject(properties={"name": label}, uuid=object_id)
        for label, object_id in (("first", UUID1), ("second", UUID2), ("third", UUID3))
    ]
    collection.data.insert_many(seed_objects)

    fetched = collection.query.fetch_objects_by_ids(ids).objects

    assert len(fetched) == expected_len
    assert {obj.uuid for obj in fetched} == expected
596+
597+
551598
@pytest.mark.parametrize("path", ["_creationTimeUnix", "_lastUpdateTimeUnix"])
552599
def test_filter_timestamp_direct_path(collection_factory: CollectionFactory, path: str) -> None:
553600
collection = collection_factory(

weaviate/collections/generate.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
_FetchObjectsGenerateAsync,
88
_FetchObjectsGenerate,
99
)
10+
from weaviate.collections.queries.fetch_objects_by_ids import (
11+
_FetchObjectsByIDsGenerateAsync,
12+
_FetchObjectsByIDsGenerate,
13+
)
1014
from weaviate.collections.queries.hybrid import _HybridGenerateAsync, _HybridGenerate
1115
from weaviate.collections.queries.near_image import _NearImageGenerateAsync, _NearImageGenerate
1216
from weaviate.collections.queries.near_media import _NearMediaGenerateAsync, _NearMediaGenerate
@@ -19,6 +23,7 @@ class _GenerateCollectionAsync(
1923
Generic[TProperties, References],
2024
_BM25GenerateAsync[TProperties, References],
2125
_FetchObjectsGenerateAsync[TProperties, References],
26+
_FetchObjectsByIDsGenerateAsync[TProperties, References],
2227
_HybridGenerateAsync[TProperties, References],
2328
_NearImageGenerateAsync[TProperties, References],
2429
_NearMediaGenerateAsync[TProperties, References],
@@ -33,6 +38,7 @@ class _GenerateCollection(
3338
Generic[TProperties, References],
3439
_BM25Generate[TProperties, References],
3540
_FetchObjectsGenerate[TProperties, References],
41+
_FetchObjectsByIDsGenerate[TProperties, References],
3642
_HybridGenerate[TProperties, References],
3743
_NearImageGenerate[TProperties, References],
3844
_NearMediaGenerate[TProperties, References],
weaviate/collections/queries/fetch_objects_by_ids/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from .generate import _FetchObjectsByIDsGenerateAsync, _FetchObjectsByIDsGenerate
2+
from .query import _FetchObjectsByIDsQueryAsync, _FetchObjectsByIDsQuery
3+
4+
__all__ = [
5+
"_FetchObjectsByIDsGenerate",
6+
"_FetchObjectsByIDsGenerateAsync",
7+
"_FetchObjectsByIDsQuery",
8+
"_FetchObjectsByIDsQueryAsync",
9+
]
weaviate/collections/queries/fetch_objects_by_ids/generate.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
from typing import Generic, Iterable, List, Optional
2+
3+
from weaviate import syncify
4+
from weaviate.collections.classes.filters import Filter
5+
from weaviate.collections.classes.grpc import METADATA, Sorting
6+
from weaviate.collections.classes.internal import (
7+
GenerativeReturnType,
8+
_Generative,
9+
ReturnProperties,
10+
ReturnReferences,
11+
_QueryOptions,
12+
)
13+
from weaviate.collections.classes.types import Properties, TProperties, References, TReferences
14+
from weaviate.collections.queries.base import _Base
15+
from weaviate.proto.v1 import search_get_pb2
16+
from weaviate.types import UUID, INCLUDE_VECTOR
17+
18+
19+
class _FetchObjectsByIDsGenerateAsync(
    Generic[Properties, References], _Base[Properties, References]
):
    """Async mixin that adds a generative `fetch_objects_by_ids` query to a collection."""

    async def fetch_objects_by_ids(
        self,
        ids: Iterable[UUID],
        *,
        single_prompt: Optional[str] = None,
        grouped_task: Optional[str] = None,
        grouped_properties: Optional[List[str]] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        after: Optional[UUID] = None,
        sort: Optional[Sorting] = None,
        include_vector: INCLUDE_VECTOR = False,
        return_metadata: Optional[METADATA] = None,
        return_properties: Optional[ReturnProperties[TProperties]] = None,
        return_references: Optional[ReturnReferences[TReferences]] = None
    ) -> GenerativeReturnType[Properties, References, TProperties, TReferences]:
        """Fetch the objects whose uuid is in `ids` and run generative prompts on them.

        This is a special case of `fetch_objects`: the ids are turned into an "any of"
        filter made of one id-equality filter per entry. When `ids` is empty no request
        is sent and the result is built from an empty search reply.

        Arguments:
            `ids`
                The uuids to look up. Any iterable is accepted, including one-shot
                generators; it is consumed exactly once.
            `single_prompt`
                Passed as the `single` prompt of the generative configuration.
            `grouped_task`
                Passed as the `grouped` task of the generative configuration.
            `grouped_properties`
                Passed as the `grouped_properties` of the generative configuration.
            `limit`, `offset`, `after`, `sort`
                Forwarded unchanged to the underlying query.
            `include_vector`
                Combined with `return_metadata` when parsing the metadata to return.
            `return_metadata`
                The metadata to return for each object.
            `return_properties`
                The properties to return for each object.
            `return_references`
                The references to return for each object.

        Returns:
            The value produced by `_result_to_generative_query_return` for the reply.
        """
        # `Iterable` admits generators and iterators, which are always truthy and can be
        # walked only once. Materialising them up front makes the emptiness check reliable
        # and stops an empty generator from producing an empty `any_of` filter. Falsy
        # inputs keep their previous meaning of "nothing requested".
        requested_ids = list(ids) if ids else []

        if not requested_ids:
            res = search_get_pb2.SearchReply(results=None)
        else:
            id_filters = [Filter.by_id().equal(object_id) for object_id in requested_ids]
            res = await self._query.get(
                limit=limit,
                offset=offset,
                after=after,
                filters=Filter.any_of(id_filters),
                sort=sort,
                return_metadata=self._parse_return_metadata(return_metadata, include_vector),
                return_properties=self._parse_return_properties(return_properties),
                return_references=self._parse_return_references(return_references),
                generative=_Generative(
                    single=single_prompt,
                    grouped=grouped_task,
                    grouped_properties=grouped_properties,
                ),
            )

        return self._result_to_generative_query_return(
            res,
            _QueryOptions.from_input(
                return_metadata,
                return_properties,
                include_vector,
                self._references,
                return_references,
            ),
            return_properties,
            return_references,
        )
69+
70+
71+
@syncify.convert
class _FetchObjectsByIDsGenerate(
    Generic[Properties, References],
    _FetchObjectsByIDsGenerateAsync[Properties, References],
):
    """Counterpart of `_FetchObjectsByIDsGenerateAsync` passed through `syncify.convert`."""

0 commit comments

Comments
 (0)