Skip to content

Commit 8e9b4c9

Browse files
authored
[RSDK-9328] Exclude non-actuating methods (viamrobotics#857)
1 parent 718c881 commit 8e9b4c9

File tree

3 files changed

+64
-17
lines changed

3 files changed

+64
-17
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ dependencies = [
1515
"grpclib>=0.4.7",
1616
"protobuf==5.29.2",
1717
"typing-extensions>=4.12.2",
18-
"pymongo>=4.10.1"
18+
"pymongo>=4.10.1",
1919
]
2020

2121
[project.urls]

src/viam/sessions_client.py

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import asyncio
2+
import importlib
3+
import pkgutil
24
from copy import deepcopy
35
from datetime import timedelta
46
from enum import IntEnum
57
from threading import Lock, Thread
6-
from typing import Optional
8+
from typing import MutableMapping, Optional
79

810
from grpclib import Status
911
from grpclib.client import Channel
@@ -12,26 +14,13 @@
1214
from grpclib.metadata import _MetadataLike
1315

1416
from viam import logging
17+
from viam.gen.common.v1.common_pb2 import safety_heartbeat_monitored
1518
from viam.proto.robot import RobotServiceStub, SendSessionHeartbeatRequest, StartSessionRequest, StartSessionResponse
1619
from viam.rpc.dial import DialOptions, dial
1720

1821
LOGGER = logging.getLogger(__name__)
1922
SESSION_METADATA_KEY = "viam-sid"
2023

21-
EXEMPT_METADATA_METHODS = frozenset(
22-
[
23-
"/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo",
24-
"/proto.rpc.webrtc.v1.SignalingService/Call",
25-
"/proto.rpc.webrtc.v1.SignalingService/CallUpdate",
26-
"/proto.rpc.webrtc.v1.SignalingService/OptionalWebRTCConfig",
27-
"/proto.rpc.v1.AuthService/Authenticate",
28-
"/viam.robot.v1.RobotService/ResourceNames",
29-
"/viam.robot.v1.RobotService/ResourceRPCSubtypes",
30-
"/viam.robot.v1.RobotService/StartSession",
31-
"/viam.robot.v1.RobotService/SendSessionHeartbeat",
32-
]
33-
)
34-
3524

3625
class _SupportedState(IntEnum):
3726
UNKNOWN = 0
@@ -56,6 +45,8 @@ class SessionsClient:
5645
_supported: _SupportedState
5746
_thread: Optional[Thread]
5847

48+
_HEARTBEAT_MONITORED_METHODS: MutableMapping[str, bool] = {}
49+
5950
def __init__(self, channel: Channel, direct_dial_address: str, dial_options: Optional[DialOptions], *, disabled: bool = False):
6051
self.channel = channel
6152
self.client = RobotServiceStub(channel)
@@ -92,7 +83,7 @@ async def _send_request(self, event: SendRequest):
9283
if self._disabled:
9384
return
9485

95-
if event.method_name in EXEMPT_METADATA_METHODS:
86+
if not self._is_safety_heartbeat_monitored(event.method_name):
9687
return
9788

9889
event.metadata.update(await self.metadata)
@@ -183,3 +174,49 @@ def _metadata(self) -> _MetadataLike:
183174
return {SESSION_METADATA_KEY: self._current_id}
184175

185176
return {}
177+
178+
def _is_safety_heartbeat_monitored(self, method: str) -> bool:
179+
if method in self._HEARTBEAT_MONITORED_METHODS:
180+
return self._HEARTBEAT_MONITORED_METHODS[method]
181+
182+
parts = method.split("/")
183+
if len(parts) != 3:
184+
self._HEARTBEAT_MONITORED_METHODS[method] = False
185+
return False
186+
service_path = parts[1]
187+
method_name = parts[2]
188+
189+
parts = service_path.split(".")
190+
if len(parts) < 5:
191+
self._HEARTBEAT_MONITORED_METHODS[method] = False
192+
return False
193+
if parts[0] != "viam":
194+
self._HEARTBEAT_MONITORED_METHODS[method] = False
195+
return False
196+
resource_type = parts[1]
197+
resource_subtype = parts[2]
198+
version = parts[3]
199+
service_name = parts[4]
200+
try:
201+
module = importlib.import_module(f"viam.gen.{resource_type}.{resource_subtype}.{version}")
202+
submods = pkgutil.iter_modules(module.__path__)
203+
for mod in submods:
204+
if "_pb2" in mod.name:
205+
submod = getattr(module, mod.name)
206+
DESCRIPTOR = getattr(submod, "DESCRIPTOR")
207+
for service in DESCRIPTOR.services_by_name.values():
208+
if service.name == service_name:
209+
for method_actual in service.methods:
210+
if method_actual.name == method_name:
211+
options = method_actual.GetOptions()
212+
if options.HasExtension(safety_heartbeat_monitored):
213+
is_monitored = options.Extensions[safety_heartbeat_monitored]
214+
self._HEARTBEAT_MONITORED_METHODS[method] = is_monitored
215+
return is_monitored
216+
self._HEARTBEAT_MONITORED_METHODS[method] = False
217+
return False
218+
self._HEARTBEAT_MONITORED_METHODS[method] = False
219+
return False
220+
except Exception:
221+
self._HEARTBEAT_MONITORED_METHODS[method] = False
222+
return False

tests/test_sessions_client.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,13 @@ async def test_sessions_disabled(service: MockRobot):
123123
assert await client.metadata == {}
124124
assert client._supported == _SupportedState.UNKNOWN
125125
assert not client._heartbeat_interval
126+
127+
128+
async def test_safete_heartbeat_monitored():
129+
async with ChannelFor([]) as channel:
130+
client = SessionsClient(channel, "", None, disabled=True)
131+
is_monitored = client._is_safety_heartbeat_monitored("/viam.component.arm.v1.ArmService/MoveToPosition")
132+
assert is_monitored is True
133+
134+
is_monitored = client._is_safety_heartbeat_monitored("/viam.component.camera.v1.CameraService/GetImage")
135+
assert is_monitored is False

0 commit comments

Comments
 (0)