11import warnings
22from datetime import datetime
33from pathlib import Path
4- from typing import Any , List , Mapping , Optional , Tuple
4+ from typing import Any , List , Mapping , Optional , Sequence , Tuple
55
66from google .protobuf .struct_pb2 import Struct
77from grpclib .client import Channel , Stream
88
99from viam import logging
1010from viam .proto .app .data import (
11+ AddBinaryDataToDatasetByIDsRequest ,
1112 AddBoundingBoxToImageByIDRequest ,
1213 AddBoundingBoxToImageByIDResponse ,
1314 AddTagsToBinaryDataByFilterRequest ,
3233 Filter ,
3334 GetDatabaseConnectionRequest ,
3435 GetDatabaseConnectionResponse ,
36+ RemoveBinaryDataFromDatasetByIDsRequest ,
3537 RemoveBoundingBoxFromImageByIDRequest ,
3638 RemoveTagsFromBinaryDataByFilterRequest ,
3739 RemoveTagsFromBinaryDataByFilterResponse ,
4244 TagsByFilterRequest ,
4345 TagsByFilterResponse ,
4446)
47+ from viam .proto .app .dataset import (
48+ CreateDatasetRequest ,
49+ CreateDatasetResponse ,
50+ Dataset ,
51+ DatasetServiceStub ,
52+ DeleteDatasetRequest ,
53+ ListDatasetsByIDsRequest ,
54+ ListDatasetsByIDsResponse ,
55+ ListDatasetsByOrganizationIDRequest ,
56+ ListDatasetsByOrganizationIDResponse ,
57+ RenameDatasetRequest ,
58+ )
4559from viam .proto .app .datasync import (
4660 DataCaptureUploadMetadata ,
4761 DataCaptureUploadRequest ,
@@ -98,6 +112,7 @@ def __eq__(self, other: object) -> bool:
98112 return str (self ) == str (other )
99113 return False
100114
115+ # TODO (RSDK-6684): Revisit if this shadow type is necessary
101116 class BinaryData :
102117 """Class representing a piece of binary data and associated metadata.
103118
@@ -131,10 +146,12 @@ def __init__(self, channel: Channel, metadata: Mapping[str, str]):
131146 self ._metadata = metadata
132147 self ._data_client = DataServiceStub (channel )
133148 self ._data_sync_client = DataSyncServiceStub (channel )
149+ self ._dataset_client = DatasetServiceStub (channel )
134150 self ._channel = channel
135151
136152 _data_client : DataServiceStub
137153 _data_sync_client : DataSyncServiceStub
154+ _dataset_client : DatasetServiceStub
138155 _metadata : Mapping [str , str ]
139156 _channel : Channel
140157
@@ -478,6 +495,93 @@ async def get_database_connection(self, organization_id: str) -> str:
478495 async def configure_database_user (self , organization_id : str , password : str ) -> None :
479496 raise NotImplementedError ()
480497
498+ async def create_dataset (self , name : str , organization_id : str ) -> str :
499+ """Create a new dataset.
500+
501+ Args:
502+ name (str): The name of the dataset being created.
503+ organization_id (str): The ID of the organization where the dataset is being created.
504+
505+ Returns:
506+ str: The dataset ID of the created dataset.
507+ """
508+ request = CreateDatasetRequest (name = name , organization_id = organization_id )
509+ response : CreateDatasetResponse = await self ._dataset_client .CreateDataset (request , metadata = self ._metadata )
510+ return response .id
511+
512+ async def list_dataset_by_ids (self , ids : List [str ]) -> Sequence [Dataset ]:
513+ """Get a list of datasets using their IDs.
514+
515+ Args:
516+ ids (List[str]): The IDs of the datasets being called for.
517+
518+ Returns:
519+ Sequence[Dataset]: The list of datasets.
520+ """
521+ request = ListDatasetsByIDsRequest (ids = ids )
522+ response : ListDatasetsByIDsResponse = await self ._dataset_client .ListDatasetsByIDs (request , metadata = self ._metadata )
523+
524+ return response .datasets
525+
526+ async def list_datasets_by_organization_id (self , organization_id : str ) -> Sequence [Dataset ]:
527+ """Get the datasets in an organization.
528+
529+ Args:
530+ organization_id (str): The ID of the organization.
531+
532+ Returns:
533+ Sequence[Dataset]: The list of datasets in the organization.
534+ """
535+ request = ListDatasetsByOrganizationIDRequest (organization_id = organization_id )
536+ response : ListDatasetsByOrganizationIDResponse = await self ._dataset_client .ListDatasetsByOrganizationID (
537+ request , metadata = self ._metadata
538+ )
539+
540+ return response .datasets
541+
542+ async def rename_dataset (self , id : str , name : str ) -> None :
543+ """Rename a dataset specified by the dataset ID.
544+
545+ Args:
546+ id (str): The ID of the dataset.
547+ name (str): The new name of the dataset.
548+ """
549+ request = RenameDatasetRequest (id = id , name = name )
550+ await self ._dataset_client .RenameDataset (request , metadata = self ._metadata )
551+
552+ async def delete_dataset (self , id : str ) -> None :
553+ """Delete a dataset.
554+
555+ Args:
556+ id (str): The ID of the dataset.
557+ """
558+ request = DeleteDatasetRequest (id = id )
559+ await self ._dataset_client .DeleteDataset (request , metadata = self ._metadata )
560+
561+ async def add_binary_data_to_dataset_by_ids (self , binary_ids : List [BinaryID ], dataset_id : str ) -> None :
562+ """Add the BinaryData to the provided dataset.
563+
564+ This BinaryData will be tagged with the VIAM_DATASET_{id} label.
565+
566+ Args:
567+ binary_ids (List[BinaryID]): The IDs of binary data to add to dataset.
568+ dataset_id (str): The ID of the dataset to be added to.
569+ """
570+ request = AddBinaryDataToDatasetByIDsRequest (binary_ids = binary_ids , dataset_id = dataset_id )
571+ await self ._data_client .AddBinaryDataToDatasetByIDs (request , metadata = self ._metadata )
572+
573+ async def remove_binary_data_from_dataset_by_ids (self , binary_ids : List [BinaryID ], dataset_id : str ) -> None :
574+ """Remove the BinaryData from the provided dataset.
575+
576+ This BinaryData will lose the VIAM_DATASET_{id} tag.
577+
578+ Args:
579+ binary_ids (List[BinaryID]): The IDs of binary data to remove from dataset.
580+ dataset_id (str): The ID of the dataset to be removed from.
581+ """
582+ request = RemoveBinaryDataFromDatasetByIDsRequest (binary_ids = binary_ids , dataset_id = dataset_id )
583+ await self ._data_client .RemoveBinaryDataFromDatasetByIDs (request , metadata = self ._metadata )
584+
481585 async def binary_data_capture_upload (
482586 self ,
483587 binary_data : bytes ,
@@ -806,8 +910,9 @@ def create_filter(
806910 end_time : Optional [datetime ] = None ,
807911 tags : Optional [List [str ]] = None ,
808912 bbox_labels : Optional [List [str ]] = None ,
913+ dataset_id : Optional [str ] = None ,
809914 ) -> Filter :
810- warnings .warn ("DataClient.create_filter is deprecated. Use AppClient .create_filter instead." , DeprecationWarning , stacklevel = 2 )
915+ warnings .warn ("DataClient.create_filter is deprecated. Use utils .create_filter instead." , DeprecationWarning , stacklevel = 2 )
811916 return create_filter (
812917 component_name ,
813918 component_type ,
@@ -823,4 +928,5 @@ def create_filter(
823928 end_time ,
824929 tags ,
825930 bbox_labels ,
931+ dataset_id ,
826932 )
0 commit comments