Skip to content

Commit 7b34a10

Browse files
RSDK-7189: add data wrappers (#603)
1 parent 7acf717 commit 7b34a10

File tree

3 files changed

+93
-8
lines changed

3 files changed

+93
-8
lines changed

src/viam/app/data_client.py

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from dataclasses import dataclass
33
from datetime import datetime
44
from pathlib import Path
5-
from typing import Any, List, Mapping, Optional, Sequence, Tuple
5+
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
66

77
from google.protobuf.struct_pb2 import Struct
88
from grpclib.client import Channel, Stream
@@ -23,6 +23,7 @@
2323
BoundingBoxLabelsByFilterRequest,
2424
BoundingBoxLabelsByFilterResponse,
2525
CaptureMetadata,
26+
ConfigureDatabaseUserRequest,
2627
DataRequest,
2728
DataServiceStub,
2829
DeleteBinaryDataByFilterRequest,
@@ -43,6 +44,10 @@
4344
RemoveTagsFromBinaryDataByIDsResponse,
4445
TabularDataByFilterRequest,
4546
TabularDataByFilterResponse,
47+
TabularDataByMQLRequest,
48+
TabularDataByMQLResponse,
49+
TabularDataBySQLRequest,
50+
TabularDataBySQLResponse,
4651
TagsByFilterRequest,
4752
TagsByFilterResponse,
4853
)
@@ -73,7 +78,7 @@
7378
StreamingDataCaptureUploadResponse,
7479
UploadMetadata,
7580
)
76-
from viam.utils import create_filter, datetime_to_timestamp, struct_to_dict
81+
from viam.utils import ValueTypes, create_filter, datetime_to_timestamp, struct_to_dict
7782

7883
LOGGER = logging.getLogger(__name__)
7984

@@ -248,6 +253,44 @@ async def tabular_data_by_filter(
248253
LOGGER.error(f"Failed to write tabular data to file {dest}", exc_info=e)
249254
return data, response.count, response.last
250255

256+
async def tabular_data_by_sql(self, organization_id: str, sql_query: str) -> List[Dict[str, ValueTypes]]:
257+
"""Obtain unified tabular data and metadata, queried with SQL.
258+
259+
::
260+
261+
data = await data_client.tabular_data_by_sql(org_id="<your-org-id>", sql_query="<sql-query>")
262+
263+
264+
Args:
265+
organization_id (str): The ID of the organization that owns the data.
266+
sql_query (str): The SQL query to run.
267+
268+
Returns:
269+
List[Dict[str, ValueTypes]]: An array of data objects.
270+
"""
271+
request = TabularDataBySQLRequest(organization_id=organization_id, sql_query=sql_query)
272+
response: TabularDataBySQLResponse = await self._data_client.TabularDataBySQL(request, metadata=self._metadata)
273+
return [struct_to_dict(struct) for struct in response.data]
274+
275+
async def tabular_data_by_mql(self, organization_id: str, mql_binary: List[bytes]) -> List[Dict[str, ValueTypes]]:
276+
"""Obtain unified tabular data and metadata, queried with MQL.
277+
278+
::
279+
280+
data = await data_client.tabular_data_by_mql(org_id="<your-org-id>", mql_binary=[<mql-bytes-1>, <mql-bytes-2>])
281+
282+
283+
Args:
284+
organization_id (str): The ID of the organization that owns the data.
285+
mql_binary (List[bytes]):The MQL query to run as a list of BSON documents.
286+
287+
Returns:
288+
List[Dict[str, ValueTypes]]: An array of data objects.
289+
"""
290+
request = TabularDataByMQLRequest(organization_id=organization_id, mql_binary=mql_binary)
291+
response: TabularDataByMQLResponse = await self._data_client.TabularDataByMQL(request, metadata=self._metadata)
292+
return [struct_to_dict(struct) for struct in response.data]
293+
251294
async def binary_data_by_filter(
252295
self,
253296
filter: Optional[Filter] = None,
@@ -733,9 +776,16 @@ async def get_database_connection(self, organization_id: str) -> str:
733776
response: GetDatabaseConnectionResponse = await self._data_client.GetDatabaseConnection(request, metadata=self._metadata)
734777
return response.hostname
735778

736-
# TODO(RSDK-5569): implement
737779
async def configure_database_user(self, organization_id: str, password: str) -> None:
738-
raise NotImplementedError()
780+
"""Configure a database user for the Viam organization's MongoDB Atlas Data Federation instance. It can also be used to reset the
781+
password of the existing database user.
782+
783+
Args:
784+
organization_id (str): The ID of the organization.
785+
password (str): The password of the user.
786+
"""
787+
request = ConfigureDatabaseUserRequest(organization_id=organization_id, password=password)
788+
await self._data_client.ConfigureDatabaseUser(request, metadata=self._metadata)
739789

740790
async def create_dataset(self, name: str, organization_id: str) -> str:
741791
"""Create a new dataset.

tests/mocks/services.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -753,13 +753,15 @@ class MockData(DataServiceBase):
753753
def __init__(
754754
self,
755755
tabular_response: List[DataClient.TabularData],
756+
tabular_query_response: List[Dict[str, ValueTypes]],
756757
binary_response: List[DataClient.BinaryData],
757758
delete_remove_response: int,
758759
tags_response: List[str],
759760
bbox_labels_response: List[str],
760761
hostname_response: str,
761762
):
762763
self.tabular_response = tabular_response
764+
self.tabular_query_response = tabular_query_response
763765
self.binary_response = binary_response
764766
self.delete_remove_response = delete_remove_response
765767
self.tags_response = tags_response
@@ -916,7 +918,11 @@ async def GetDatabaseConnection(self, stream: Stream[GetDatabaseConnectionReques
916918
await stream.send_message(GetDatabaseConnectionResponse(hostname=self.hostname_response))
917919

918920
async def ConfigureDatabaseUser(self, stream: Stream[ConfigureDatabaseUserRequest, ConfigureDatabaseUserResponse]) -> None:
919-
raise NotImplementedError()
921+
request = await stream.recv_message()
922+
assert request is not None
923+
self.organization_id = request.organization_id
924+
self.password = request.password
925+
await stream.send_message(ConfigureDatabaseUserResponse())
920926

921927
async def AddBinaryDataToDatasetByIDs(
922928
self, stream: Stream[AddBinaryDataToDatasetByIDsRequest, AddBinaryDataToDatasetByIDsResponse]
@@ -937,10 +943,14 @@ async def RemoveBinaryDataFromDatasetByIDs(
937943
await stream.send_message(RemoveBinaryDataFromDatasetByIDsResponse())
938944

939945
async def TabularDataBySQL(self, stream: Stream[TabularDataBySQLRequest, TabularDataBySQLResponse]) -> None:
940-
raise NotImplementedError()
946+
request = await stream.recv_message()
947+
assert request is not None
948+
await stream.send_message(TabularDataBySQLResponse(data=[dict_to_struct(dict) for dict in self.tabular_query_response]))
941949

942950
async def TabularDataByMQL(self, stream: Stream[TabularDataByMQLRequest, TabularDataByMQLResponse]) -> None:
943-
raise NotImplementedError()
951+
request = await stream.recv_message()
952+
assert request is not None
953+
await stream.send_message(TabularDataByMQLResponse(data=[dict_to_struct(dict) for dict in self.tabular_query_response]))
944954

945955

946956
class MockDataset(DatasetServiceBase):

tests/test_data_client.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
LOCATION_IDS = [LOCATION_ID]
2323
ORG_ID = "organization_id"
2424
ORG_IDS = [ORG_ID]
25+
PASSWORD = "password"
2526
MIME_TYPE = "mime_type"
2627
MIME_TYPES = [MIME_TYPE]
2728
URI = "some.robot.uri"
@@ -70,6 +71,8 @@
7071
y_max_normalized=0.3,
7172
)
7273
BBOXES = [BBOX]
74+
SQL_QUERY = "sql_query"
75+
MQL_BINARY = [b"mql_binary"]
7376
TABULAR_DATA = {"key": "value"}
7477
TABULAR_METADATA = CaptureMetadata(
7578
organization_id=ORG_ID,
@@ -97,6 +100,9 @@
97100
)
98101

99102
TABULAR_RESPONSE = [DataClient.TabularData(TABULAR_DATA, TABULAR_METADATA, START_DATETIME, END_DATETIME)]
103+
TABULAR_QUERY_RESPONSE = [
104+
{"key1": 1, "key2": "2", "key3": [1, 2, 3], "key4": {"key4sub1": 1}},
105+
]
100106
BINARY_RESPONSE = [DataClient.BinaryData(BINARY_DATA, BINARY_METADATA)]
101107
DELETE_REMOVE_RESPONSE = 1
102108
TAGS_RESPONSE = ["tag"]
@@ -110,6 +116,7 @@
110116
def service() -> MockData:
111117
return MockData(
112118
tabular_response=TABULAR_RESPONSE,
119+
tabular_query_response=TABULAR_QUERY_RESPONSE,
113120
binary_response=BINARY_RESPONSE,
114121
delete_remove_response=DELETE_REMOVE_RESPONSE,
115122
tags_response=TAGS_RESPONSE,
@@ -143,6 +150,20 @@ async def test_tabular_data_by_filter(self, service: MockData):
143150
assert last_response != ""
144151
self.assert_filter(filter=service.filter)
145152

153+
@pytest.mark.asyncio
154+
async def test_tabular_data_by_sql(self, service: MockData):
155+
async with ChannelFor([service]) as channel:
156+
client = DataClient(channel, DATA_SERVICE_METADATA)
157+
response = await client.tabular_data_by_sql(ORG_ID, SQL_QUERY)
158+
assert response == TABULAR_QUERY_RESPONSE
159+
160+
@pytest.mark.asyncio
161+
async def test_tabular_data_by_mql(self, service: MockData):
162+
async with ChannelFor([service]) as channel:
163+
client = DataClient(channel, DATA_SERVICE_METADATA)
164+
response = await client.tabular_data_by_mql(ORG_ID, MQL_BINARY)
165+
assert response == TABULAR_QUERY_RESPONSE
166+
146167
@pytest.mark.asyncio
147168
async def test_binary_data_by_filter(self, service: MockData):
148169
async with ChannelFor([service]) as channel:
@@ -283,7 +304,11 @@ async def test_get_database_connection(self, service: MockData):
283304

284305
@pytest.mark.asyncio
285306
async def test_configure_database_user(self, service: MockData):
286-
assert True
307+
async with ChannelFor([service]) as channel:
308+
client = DataClient(channel, DATA_SERVICE_METADATA)
309+
await client.configure_database_user(ORG_ID, PASSWORD)
310+
assert service.organization_id == ORG_ID
311+
assert service.password == PASSWORD
287312

288313
@pytest.mark.asyncio
289314
async def test_add_binary_data_to_dataset_by_ids(self, service: MockData):

0 commit comments

Comments
 (0)