Skip to content

Commit 68d08ac

Browse files
RSDK-5731: add dataset apis (#538)
1 parent fd5af88 commit 68d08ac

File tree

5 files changed

+253
-5
lines changed

5 files changed

+253
-5
lines changed

src/viam/app/data_client.py

Lines changed: 108 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import warnings
22
from datetime import datetime
33
from pathlib import Path
4-
from typing import Any, List, Mapping, Optional, Tuple
4+
from typing import Any, List, Mapping, Optional, Sequence, Tuple
55

66
from google.protobuf.struct_pb2 import Struct
77
from grpclib.client import Channel, Stream
88

99
from viam import logging
1010
from viam.proto.app.data import (
11+
AddBinaryDataToDatasetByIDsRequest,
1112
AddBoundingBoxToImageByIDRequest,
1213
AddBoundingBoxToImageByIDResponse,
1314
AddTagsToBinaryDataByFilterRequest,
@@ -32,6 +33,7 @@
3233
Filter,
3334
GetDatabaseConnectionRequest,
3435
GetDatabaseConnectionResponse,
36+
RemoveBinaryDataFromDatasetByIDsRequest,
3537
RemoveBoundingBoxFromImageByIDRequest,
3638
RemoveTagsFromBinaryDataByFilterRequest,
3739
RemoveTagsFromBinaryDataByFilterResponse,
@@ -42,6 +44,18 @@
4244
TagsByFilterRequest,
4345
TagsByFilterResponse,
4446
)
47+
from viam.proto.app.dataset import (
48+
CreateDatasetRequest,
49+
CreateDatasetResponse,
50+
Dataset,
51+
DatasetServiceStub,
52+
DeleteDatasetRequest,
53+
ListDatasetsByIDsRequest,
54+
ListDatasetsByIDsResponse,
55+
ListDatasetsByOrganizationIDRequest,
56+
ListDatasetsByOrganizationIDResponse,
57+
RenameDatasetRequest,
58+
)
4559
from viam.proto.app.datasync import (
4660
DataCaptureUploadMetadata,
4761
DataCaptureUploadRequest,
@@ -98,6 +112,7 @@ def __eq__(self, other: object) -> bool:
98112
return str(self) == str(other)
99113
return False
100114

115+
# TODO (RSDK-6684): Revisit if this shadow type is necessary
101116
class BinaryData:
102117
"""Class representing a piece of binary data and associated metadata.
103118
@@ -131,10 +146,12 @@ def __init__(self, channel: Channel, metadata: Mapping[str, str]):
131146
self._metadata = metadata
132147
self._data_client = DataServiceStub(channel)
133148
self._data_sync_client = DataSyncServiceStub(channel)
149+
self._dataset_client = DatasetServiceStub(channel)
134150
self._channel = channel
135151

136152
_data_client: DataServiceStub
137153
_data_sync_client: DataSyncServiceStub
154+
_dataset_client: DatasetServiceStub
138155
_metadata: Mapping[str, str]
139156
_channel: Channel
140157

@@ -478,6 +495,93 @@ async def get_database_connection(self, organization_id: str) -> str:
478495
async def configure_database_user(self, organization_id: str, password: str) -> None:
479496
raise NotImplementedError()
480497

498+
async def create_dataset(self, name: str, organization_id: str) -> str:
499+
"""Create a new dataset.
500+
501+
Args:
502+
name (str): The name of the dataset being created.
503+
organization_id (str): The ID of the organization where the dataset is being created.
504+
505+
Returns:
506+
str: The dataset ID of the created dataset.
507+
"""
508+
request = CreateDatasetRequest(name=name, organization_id=organization_id)
509+
response: CreateDatasetResponse = await self._dataset_client.CreateDataset(request, metadata=self._metadata)
510+
return response.id
511+
512+
async def list_dataset_by_ids(self, ids: List[str]) -> Sequence[Dataset]:
513+
"""Get a list of datasets using their IDs.
514+
515+
Args:
516+
ids (List[str]): The IDs of the datasets being called for.
517+
518+
Returns:
519+
Sequence[Dataset]: The list of datasets.
520+
"""
521+
request = ListDatasetsByIDsRequest(ids=ids)
522+
response: ListDatasetsByIDsResponse = await self._dataset_client.ListDatasetsByIDs(request, metadata=self._metadata)
523+
524+
return response.datasets
525+
526+
async def list_datasets_by_organization_id(self, organization_id: str) -> Sequence[Dataset]:
527+
"""Get the datasets in an organization.
528+
529+
Args:
530+
organization_id (str): The ID of the organization.
531+
532+
Returns:
533+
Sequence[Dataset]: The list of datasets in the organization.
534+
"""
535+
request = ListDatasetsByOrganizationIDRequest(organization_id=organization_id)
536+
response: ListDatasetsByOrganizationIDResponse = await self._dataset_client.ListDatasetsByOrganizationID(
537+
request, metadata=self._metadata
538+
)
539+
540+
return response.datasets
541+
542+
async def rename_dataset(self, id: str, name: str) -> None:
543+
"""Rename a dataset specified by the dataset ID.
544+
545+
Args:
546+
id (str): The ID of the dataset.
547+
name (str): The new name of the dataset.
548+
"""
549+
request = RenameDatasetRequest(id=id, name=name)
550+
await self._dataset_client.RenameDataset(request, metadata=self._metadata)
551+
552+
async def delete_dataset(self, id: str) -> None:
553+
"""Delete a dataset.
554+
555+
Args:
556+
id (str): The ID of the dataset.
557+
"""
558+
request = DeleteDatasetRequest(id=id)
559+
await self._dataset_client.DeleteDataset(request, metadata=self._metadata)
560+
561+
async def add_binary_data_to_dataset_by_ids(self, binary_ids: List[BinaryID], dataset_id: str) -> None:
562+
"""Add the BinaryData to the provided dataset.
563+
564+
This BinaryData will be tagged with the VIAM_DATASET_{id} label.
565+
566+
Args:
567+
binary_ids (List[BinaryID]): The IDs of binary data to add to dataset.
568+
dataset_id (str): The ID of the dataset to be added to.
569+
"""
570+
request = AddBinaryDataToDatasetByIDsRequest(binary_ids=binary_ids, dataset_id=dataset_id)
571+
await self._data_client.AddBinaryDataToDatasetByIDs(request, metadata=self._metadata)
572+
573+
async def remove_binary_data_from_dataset_by_ids(self, binary_ids: List[BinaryID], dataset_id: str) -> None:
574+
"""Remove the BinaryData from the provided dataset.
575+
576+
This BinaryData will lose the VIAM_DATASET_{id} tag.
577+
578+
Args:
579+
binary_ids (List[BinaryID]): The IDs of binary data to remove from dataset.
580+
dataset_id (str): The ID of the dataset to be removed from.
581+
"""
582+
request = RemoveBinaryDataFromDatasetByIDsRequest(binary_ids=binary_ids, dataset_id=dataset_id)
583+
await self._data_client.RemoveBinaryDataFromDatasetByIDs(request, metadata=self._metadata)
584+
481585
async def binary_data_capture_upload(
482586
self,
483587
binary_data: bytes,
@@ -806,8 +910,9 @@ def create_filter(
806910
end_time: Optional[datetime] = None,
807911
tags: Optional[List[str]] = None,
808912
bbox_labels: Optional[List[str]] = None,
913+
dataset_id: Optional[str] = None,
809914
) -> Filter:
810-
warnings.warn("DataClient.create_filter is deprecated. Use AppClient.create_filter instead.", DeprecationWarning, stacklevel=2)
915+
warnings.warn("DataClient.create_filter is deprecated. Use utils.create_filter instead.", DeprecationWarning, stacklevel=2)
811916
return create_filter(
812917
component_name,
813918
component_type,
@@ -823,4 +928,5 @@ def create_filter(
823928
end_time,
824929
tags,
825930
bbox_labels,
931+
dataset_id,
826932
)

src/viam/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ def create_filter(
284284
end_time: Optional[datetime] = None,
285285
tags: Optional[List[str]] = None,
286286
bbox_labels: Optional[List[str]] = None,
287+
dataset_id: Optional[str] = None,
287288
) -> Filter:
288289
"""Create a `Filter`.
289290
@@ -303,6 +304,7 @@ def create_filter(
303304
tags (Optional[List[str]]): Optional list of tags attached to the data being filtered (e.g., ["test"]).
304305
bbox_labels (Optional[List[str]]): Optional list of bounding box labels attached to the data being filtered (e.g., ["square",
305306
"circle"]).
307+
dataset_id (Optional[str]): Optional ID of dataset associated with data being filtered
306308
307309
Returns:
308310
viam.proto.app.data.Filter: The `Filter` object.
@@ -328,4 +330,5 @@ def create_filter(
328330
else None,
329331
tags_filter=TagsFilter(tags=tags),
330332
bbox_labels=bbox_labels,
333+
dataset_id=dataset_id if dataset_id else "",
331334
)

tests/mocks/services.py

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
1+
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
22

33
import numpy as np
44
from grpclib.server import Stream
@@ -217,6 +217,20 @@
217217
TagsByFilterRequest,
218218
TagsByFilterResponse,
219219
)
220+
from viam.proto.app.dataset import (
221+
CreateDatasetRequest,
222+
CreateDatasetResponse,
223+
Dataset,
224+
DatasetServiceBase,
225+
DeleteDatasetRequest,
226+
DeleteDatasetResponse,
227+
ListDatasetsByIDsRequest,
228+
ListDatasetsByIDsResponse,
229+
ListDatasetsByOrganizationIDRequest,
230+
ListDatasetsByOrganizationIDResponse,
231+
RenameDatasetRequest,
232+
RenameDatasetResponse,
233+
)
220234
from viam.proto.app.datasync import (
221235
DataCaptureUploadRequest,
222236
DataCaptureUploadResponse,
@@ -822,12 +836,20 @@ async def ConfigureDatabaseUser(self, stream: Stream[ConfigureDatabaseUserReques
822836
async def AddBinaryDataToDatasetByIDs(
823837
self, stream: Stream[AddBinaryDataToDatasetByIDsRequest, AddBinaryDataToDatasetByIDsResponse]
824838
) -> None:
825-
raise NotImplementedError()
839+
request = await stream.recv_message()
840+
assert request is not None
841+
self.added_data_ids = request.binary_ids
842+
self.dataset_id = request.dataset_id
843+
await stream.send_message(AddBinaryDataToDatasetByIDsResponse())
826844

827845
async def RemoveBinaryDataFromDatasetByIDs(
828846
self, stream: Stream[RemoveBinaryDataFromDatasetByIDsRequest, RemoveBinaryDataFromDatasetByIDsResponse]
829847
) -> None:
830-
raise NotImplementedError()
848+
request = await stream.recv_message()
849+
assert request is not None
850+
self.removed_data_ids = request.binary_ids
851+
self.dataset_id = request.dataset_id
852+
await stream.send_message(RemoveBinaryDataFromDatasetByIDsResponse())
831853

832854
async def TabularDataBySQL(self, stream: Stream[TabularDataBySQLRequest, TabularDataBySQLResponse]) -> None:
833855
raise NotImplementedError()
@@ -836,6 +858,46 @@ async def TabularDataByMQL(self, stream: Stream[TabularDataByMQLRequest, Tabular
836858
raise NotImplementedError()
837859

838860

861+
class MockDataset(DatasetServiceBase):
862+
def __init__(self, create_response: str, datasets_response: Sequence[Dataset]):
863+
self.create_response = create_response
864+
self.datasets_response = datasets_response
865+
866+
async def CreateDataset(self, stream: Stream[CreateDatasetRequest, CreateDatasetResponse]) -> None:
867+
request = await stream.recv_message()
868+
assert request is not None
869+
self.name = request.name
870+
self.org_id = request.organization_id
871+
await stream.send_message(CreateDatasetResponse(id=self.create_response))
872+
873+
async def DeleteDataset(self, stream: Stream[DeleteDatasetRequest, DeleteDatasetResponse]) -> None:
874+
request = await stream.recv_message()
875+
assert request is not None
876+
self.deleted_id = request.id
877+
await stream.send_message(DeleteDatasetResponse())
878+
879+
async def ListDatasetsByIDs(self, stream: Stream[ListDatasetsByIDsRequest, ListDatasetsByIDsResponse]) -> None:
880+
request = await stream.recv_message()
881+
assert request is not None
882+
self.ids = request.ids
883+
await stream.send_message(ListDatasetsByIDsResponse(datasets=self.datasets_response))
884+
885+
async def ListDatasetsByOrganizationID(
886+
self, stream: Stream[ListDatasetsByOrganizationIDRequest, ListDatasetsByOrganizationIDResponse]
887+
) -> None:
888+
request = await stream.recv_message()
889+
assert request is not None
890+
self.org_id = request.organization_id
891+
await stream.send_message(ListDatasetsByOrganizationIDResponse(datasets=self.datasets_response))
892+
893+
async def RenameDataset(self, stream: Stream[RenameDatasetRequest, RenameDatasetResponse]) -> None:
894+
request = await stream.recv_message()
895+
assert request is not None
896+
self.id = request.id
897+
self.name = request.name
898+
await stream.send_message((RenameDatasetResponse()))
899+
900+
839901
class MockDataSync(DataSyncServiceBase):
840902
def __init__(self, file_upload_response: str):
841903
self.file_upload_response = file_upload_response

tests/test_data_client.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
TAGS = ["tag"]
3737
BBOX_LABEL = "bbox_label"
3838
BBOX_LABELS = [BBOX_LABEL]
39+
DATASET_ID = "VIAM_DATASET_1"
3940
FILTER = create_filter(
4041
component_name=COMPONENT_NAME,
4142
component_type=COMPONENT_TYPE,
@@ -51,6 +52,7 @@
5152
end_time=END_DATETIME,
5253
tags=TAGS,
5354
bbox_labels=BBOX_LABELS,
55+
dataset_id=DATASET_ID,
5456
)
5557

5658
FILE_ID = "file_id"
@@ -248,6 +250,22 @@ async def test_get_database_connection(self, service: MockData):
248250
async def test_configure_database_user(self, service: MockData):
249251
assert True
250252

253+
@pytest.mark.asyncio
254+
async def test_add_binary_data_to_dataset_by_ids(self, service: MockData):
255+
async with ChannelFor([service]) as channel:
256+
client = DataClient(channel, DATA_SERVICE_METADATA)
257+
await client.add_binary_data_to_dataset_by_ids(binary_ids=BINARY_IDS, dataset_id=DATASET_ID)
258+
assert service.added_data_ids == BINARY_IDS
259+
assert service.dataset_id == DATASET_ID
260+
261+
@pytest.mark.asyncio
262+
async def test_remove_binary_data_to_dataset_by_ids(self, service: MockData):
263+
async with ChannelFor([service]) as channel:
264+
client = DataClient(channel, DATA_SERVICE_METADATA)
265+
await client.remove_binary_data_from_dataset_by_ids(binary_ids=BINARY_IDS, dataset_id=DATASET_ID)
266+
assert service.removed_data_ids == BINARY_IDS
267+
assert service.dataset_id == DATASET_ID
268+
251269
def assert_filter(self, filter: Filter) -> None:
252270
assert filter.component_name == COMPONENT_NAME
253271
assert filter.component_type == COMPONENT_TYPE

tests/test_dataset.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import pytest
2+
from google.protobuf.timestamp_pb2 import Timestamp
3+
from grpclib.testing import ChannelFor
4+
5+
from viam.app.data_client import DataClient
6+
from viam.proto.app.dataset import Dataset
7+
8+
from .mocks.services import MockDataset
9+
10+
CREATED_ID = "VIAM_DATASET_0"
11+
ID = "VIAM_DATASET_1"
12+
NAME = "dataset"
13+
ORG_ID = "org_id"
14+
SECONDS = 978310861
15+
NANOS = 0
16+
TIME = Timestamp(seconds=SECONDS, nanos=NANOS)
17+
DATASET = Dataset(id=ID, name=NAME, organization_id=ORG_ID, time_created=TIME)
18+
DATASETS = [DATASET]
19+
AUTH_TOKEN = "auth_token"
20+
DATA_SERVICE_METADATA = {"authorization": f"Bearer {AUTH_TOKEN}"}
21+
22+
23+
@pytest.fixture(scope="function")
24+
def service() -> MockDataset:
25+
return MockDataset(CREATED_ID, DATASETS)
26+
27+
28+
class TestClient:
29+
@pytest.mark.asyncio
30+
async def test_create_dataset(self, service: MockDataset):
31+
async with ChannelFor([service]) as channel:
32+
client = DataClient(channel, DATA_SERVICE_METADATA)
33+
id = await client.create_dataset(NAME, ORG_ID)
34+
assert service.name == NAME
35+
assert service.org_id == ORG_ID
36+
assert id == CREATED_ID
37+
38+
@pytest.mark.asyncio
39+
async def test_delete_dataset(self, service: MockDataset):
40+
async with ChannelFor([service]) as channel:
41+
client = DataClient(channel, DATA_SERVICE_METADATA)
42+
await client.delete_dataset(ID)
43+
assert service.deleted_id == ID
44+
45+
@pytest.mark.asyncio
46+
async def test_list_datasets_by_ids(self, service: MockDataset):
47+
async with ChannelFor([service]) as channel:
48+
client = DataClient(channel, DATA_SERVICE_METADATA)
49+
datasets = await client.list_dataset_by_ids([ID])
50+
assert service.ids == [ID]
51+
assert datasets == DATASETS
52+
53+
@pytest.mark.asyncio
54+
async def test_list_datasets_by_organization_id(self, service: MockDataset):
55+
async with ChannelFor([service]) as channel:
56+
client = DataClient(channel, DATA_SERVICE_METADATA)
57+
datasets = await client.list_datasets_by_organization_id(ORG_ID)
58+
assert service.org_id == ORG_ID
59+
assert datasets == DATASETS

0 commit comments

Comments
 (0)