Skip to content

Commit 234ef2b

Browse files
authored
RSDK-594 receive op id in python sdk (#188)
1 parent be0de8a commit 234ef2b

File tree

3 files changed

+46
-10
lines changed

3 files changed

+46
-10
lines changed

src/viam/components/arm/service.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from grpclib.server import Stream
2+
23
from viam.components.service_base import ComponentServiceBase
34
from viam.errors import ComponentNotFoundError
45
from viam.proto.component.arm import (
@@ -35,7 +36,7 @@ async def GetEndPosition(self, stream: Stream[GetEndPositionRequest, GetEndPosit
3536
except ComponentNotFoundError as e:
3637
raise e.grpc_error
3738
timeout = stream.deadline.time_remaining() if stream.deadline else None
38-
position = await arm.get_end_position(extra=struct_to_dict(request.extra), timeout=timeout)
39+
position = await arm.get_end_position(extra=struct_to_dict(request.extra), timeout=timeout, metadata=stream.metadata)
3940
response = GetEndPositionResponse(pose=position)
4041
await stream.send_message(response)
4142

@@ -48,7 +49,9 @@ async def MoveToPosition(self, stream: Stream[MoveToPositionRequest, MoveToPosit
4849
except ComponentNotFoundError as e:
4950
raise e.grpc_error
5051
timeout = stream.deadline.time_remaining() if stream.deadline else None
51-
await arm.move_to_position(request.to, request.world_state, extra=struct_to_dict(request.extra), timeout=timeout)
52+
await arm.move_to_position(
53+
request.to, request.world_state, extra=struct_to_dict(request.extra), timeout=timeout, metadata=stream.metadata
54+
)
5255
response = MoveToPositionResponse()
5356
await stream.send_message(response)
5457

@@ -61,7 +64,7 @@ async def GetJointPositions(self, stream: Stream[GetJointPositionsRequest, GetJo
6164
except ComponentNotFoundError as e:
6265
raise e.grpc_error
6366
timeout = stream.deadline.time_remaining() if stream.deadline else None
64-
positions = await arm.get_joint_positions(extra=struct_to_dict(request.extra), timeout=timeout)
67+
positions = await arm.get_joint_positions(extra=struct_to_dict(request.extra), timeout=timeout, metadata=stream.metadata)
6568
response = GetJointPositionsResponse(positions=positions)
6669
await stream.send_message(response)
6770

@@ -74,7 +77,7 @@ async def MoveToJointPositions(self, stream: Stream[MoveToJointPositionsRequest,
7477
except ComponentNotFoundError as e:
7578
raise e.grpc_error
7679
timeout = stream.deadline.time_remaining() if stream.deadline else None
77-
await arm.move_to_joint_positions(request.positions, extra=struct_to_dict(request.extra), timeout=timeout)
80+
await arm.move_to_joint_positions(request.positions, extra=struct_to_dict(request.extra), timeout=timeout, metadata=stream.metadata)
7881
response = MoveToJointPositionsResponse()
7982
await stream.send_message(response)
8083

@@ -87,6 +90,6 @@ async def Stop(self, stream: Stream[StopRequest, StopResponse]) -> None:
8790
except ComponentNotFoundError as e:
8891
raise e.grpc_error
8992
timeout = stream.deadline.time_remaining() if stream.deadline else None
90-
await arm.stop(extra=struct_to_dict(request.extra), timeout=timeout)
93+
await arm.stop(extra=struct_to_dict(request.extra), timeout=timeout, metadata=stream.metadata)
9194
response = StopResponse()
9295
await stream.send_message(response)

src/viam/operations.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import functools
33
import time
4-
from typing import Any, Callable, Coroutine, Optional, TypeVar, cast
4+
from typing import Any, Callable, Coroutine, Mapping, Optional, TypeVar, cast
55
from uuid import UUID, uuid4
66

77
from typing_extensions import Self
@@ -27,8 +27,8 @@ class Operation:
2727
_cancel_event: asyncio.Event
2828
_cancelled: bool
2929

30-
def __init__(self, method: str, cancel_event: asyncio.Event) -> None:
31-
self.id = uuid4()
30+
def __init__(self, method: str, cancel_event: asyncio.Event, opid: Optional[UUID] = None) -> None:
31+
self.id = uuid4() if opid is None else opid
3232
self.method = method
3333
self.time_started = time.time()
3434
self._cancel_event = cancel_event
@@ -61,6 +61,19 @@ def _noop(cls) -> Self:
6161
P = ParamSpec("P")
6262
T = TypeVar("T")
6363

64+
METADATA_KEY = "opid"
65+
66+
67+
def opid_from_metadata(metadata: Optional[Mapping[str, str]]) -> Optional[UUID]:
68+
if metadata is None:
69+
return None
70+
71+
opid = metadata.get(METADATA_KEY)
72+
if opid is None:
73+
return None
74+
75+
return UUID(opid)
76+
6477

6578
def run_with_operation(func: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, Coroutine[Any, Any, T]]:
6679
"""Run a component function with an ``Operation``.
@@ -89,7 +102,9 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
89102
func_name = func.__qualname__
90103
arg_names = ", ".join([str(a) for a in args])
91104
kwarg_names = ", ".join([f"{key}={value}" for (key, value) in kwargs.items()])
92-
operation = Operation(f"{func_name}({arg_names}{', ' if len(arg_names) else ''}{kwarg_names})", event)
105+
method = f"{func_name}({arg_names}{', ' if len(arg_names) else ''}{kwarg_names})"
106+
opid = opid_from_metadata(kwargs.get("metadata")) # type: ignore
107+
operation = Operation(method, event, opid=opid)
93108
kwargs[Operation.ARG_NAME] = operation
94109
timeout = kwargs.get("timeout", None)
95110
timer: Optional[asyncio.TimerHandle] = None

tests/test_operations.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import asyncio
22
import time
3+
from uuid import UUID
34

45
import pytest
56

6-
from viam.operations import Operation, run_with_operation
7+
from viam.operations import METADATA_KEY, Operation, run_with_operation
78

89

910
@pytest.mark.asyncio
@@ -57,3 +58,20 @@ async def long_running(self, **kwargs) -> bool:
5758
test_obj.long_running_task_cancelled = False
5859
assert test_obj.long_running_task_cancelled is False
5960
assert await asyncio.create_task(test_obj.long_running(timeout=0.02)) is True
61+
62+
63+
@pytest.mark.asyncio
64+
async def test_wrapper_with_metadata():
65+
test_metadata_opid = "11111111-1111-1111-1111-111111111111"
66+
67+
class TestWrapperClass:
68+
@run_with_operation
69+
async def run(self, **kwargs) -> bool:
70+
operation: Operation = kwargs.get(Operation.ARG_NAME, Operation._noop())
71+
assert operation.id == UUID(test_metadata_opid)
72+
return False
73+
74+
test_obj = TestWrapperClass()
75+
metadata = {METADATA_KEY: test_metadata_opid}
76+
result = await test_obj.run(metadata=metadata)
77+
assert result is False

0 commit comments

Comments
 (0)