Skip to content

Commit 92305d3

Browse files
Merge pull request #1792 from roboflow/fix/webrtc-worker-do-not-return-in-finally
When wrapping the WebRTC worker in an exception-handling block, do not return inside the `finally` clause
2 parents 1d134f1 + 2cd8232 commit 92305d3

File tree

3 files changed

+56
-19
lines changed

3 files changed

+56
-19
lines changed

inference/core/interfaces/webrtc_worker/modal.py

Lines changed: 22 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -204,16 +204,12 @@ async def run_rtc_peer_connection_with_watchdog(
204204
webrtc_request: WebRTCWorkerRequest,
205205
send_answer: Callable[[WebRTCWorkerResult], None],
206206
model_manager: ModelManager,
207-
) -> bool:
207+
watchdog: Watchdog,
208+
):
208209
from inference.core.interfaces.webrtc_worker.webrtc import (
209210
init_rtc_peer_connection_with_loop,
210211
)
211212

212-
watchdog = Watchdog(
213-
api_key=webrtc_request.api_key,
214-
timeout_seconds=WEBRTC_MODAL_WATCHDOG_TIMEMOUT,
215-
)
216-
217213
rtc_peer_connection_task = asyncio.create_task(
218214
init_rtc_peer_connection_with_loop(
219215
webrtc_request=webrtc_request,
@@ -238,12 +234,11 @@ def on_timeout(message: Optional[str] = ""):
238234
except modal.exception.InputCancellation:
239235
logger.warning("Modal function was cancelled")
240236
except asyncio.CancelledError as exc:
241-
logger.info("WebRTC connection task was cancelled (%s)", exc)
237+
logger.warning("WebRTC connection task was cancelled (%s)", exc)
242238
except Exception as exc:
243239
logger.error(exc)
244240
finally:
245241
watchdog.stop()
246-
return watchdog.heartbeat_occurred
247242

248243
class RTCPeerConnectionModal:
249244
_model_manager: Optional[ModelManager] = modal.parameter(
@@ -338,6 +333,10 @@ def rtc_peer_connection_modal(
338333

339334
def send_answer(obj: WebRTCWorkerResult):
340335
logger.info("Sending webrtc answer")
336+
if obj.error_message:
337+
logger.error(
338+
"Error: %s (%s)", obj.error_message, obj.exception_type
339+
)
341340
# Queue with no limit, below will never block
342341
q.put(obj)
343342

@@ -356,23 +355,35 @@ def send_answer(obj: WebRTCWorkerResult):
356355
send_answer(WebRTCWorkerResult(error_message=error_msg))
357356
return
358357

358+
watchdog = Watchdog(
359+
api_key=webrtc_request.api_key,
360+
timeout_seconds=WEBRTC_MODAL_WATCHDOG_TIMEMOUT,
361+
)
362+
359363
try:
360-
heartbeat_occurred = asyncio.run(
364+
asyncio.run(
361365
run_rtc_peer_connection_with_watchdog(
362366
webrtc_request=webrtc_request,
363367
send_answer=send_answer,
364368
model_manager=self._model_manager,
369+
watchdog=watchdog,
365370
)
366371
)
367-
except (asyncio.CancelledError, modal.exception.InputCancellation):
372+
except modal.exception.InputCancellation:
368373
logger.warning("Modal function was cancelled")
374+
except asyncio.CancelledError as exc:
375+
logger.warning("WebRTC connection task was cancelled (%s)", exc)
376+
except Exception as exc:
377+
logger.warning("Unhandled exception: %s", exc)
378+
finally:
379+
watchdog.stop()
369380

370381
_exec_session_stopped = datetime.datetime.now()
371382
logger.info(
372383
"WebRTC session stopped at %s",
373384
_exec_session_stopped.isoformat(),
374385
)
375-
if not heartbeat_occurred:
386+
if watchdog.total_heartbeats == 0:
376387
raise Exception(
377388
"WebRTC worker was terminated before processing a single frame"
378389
)

inference/core/interfaces/webrtc_worker/watchdog.py

Lines changed: 13 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,11 @@ def __init__(
2424
self._last_log_ts = datetime.datetime.now()
2525
self._log_interval_seconds = 10
2626
self._heartbeats = 0
27-
self.heartbeat_occurred = False
27+
self._total_heartbeats = 0
28+
29+
@property
30+
def total_heartbeats(self) -> int:
31+
return self._total_heartbeats
2832

2933
def start(self):
3034
logger.info("Starting watchdog with timeout %s", self.timeout_seconds)
@@ -44,25 +48,27 @@ def _watchdog_thread(self):
4448
while not self._stopping:
4549
if not self.is_alive():
4650
logger.error(
47-
"Watchdog timeout reached, heartbeats: %s", self._heartbeats
51+
"Watchdog timeout reached, heartbeats: %s", self._total_heartbeats
4852
)
4953
self.on_timeout(
50-
message=f"Timeout reached, heartbeats: {self._heartbeats}"
54+
message=f"Timeout reached, heartbeats: {self._total_heartbeats}"
5155
)
5256
break
5357
if WEBRTC_MODAL_USAGE_QUOTA_ENABLED and is_over_quota(self._api_key):
54-
logger.error("API key over quota, heartbeats: %s", self._heartbeats)
58+
logger.error(
59+
"API key over quota, heartbeats: %s", self._total_heartbeats
60+
)
5561
self.on_timeout(
56-
message=f"API key over quota, heartbeats: {self._heartbeats}"
62+
message=f"API key over quota, heartbeats: {self._total_heartbeats}"
5763
)
5864
break
5965
time.sleep(1)
60-
logger.info("Watchdog thread stopped, heartbeats: %s", self._heartbeats)
66+
logger.info("Watchdog thread stopped, heartbeats: %s", self._total_heartbeats)
6167

6268
def heartbeat(self):
6369
self.last_heartbeat = datetime.datetime.now()
6470
self._heartbeats += 1
65-
self.heartbeat_occurred = True
71+
self._total_heartbeats += 1
6672
if (
6773
datetime.datetime.now() - self._last_log_ts
6874
).total_seconds() > self._log_interval_seconds:

inference/core/interfaces/webrtc_worker/webrtc.py

Lines changed: 21 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -63,7 +63,7 @@
6363
from inference.core.workflows.core_steps.common.serializers import (
6464
serialize_wildcard_kind,
6565
)
66-
from inference.core.workflows.errors import WorkflowSyntaxError
66+
from inference.core.workflows.errors import WorkflowError, WorkflowSyntaxError
6767
from inference.core.workflows.execution_engine.entities.base import WorkflowImageData
6868
from inference.usage_tracking.collector import usage_collector
6969

@@ -927,6 +927,8 @@ async def init_rtc_peer_connection_with_loop(
927927
KeyError,
928928
NotImplementedError,
929929
) as error:
930+
# heartbeat to indicate caller error
931+
heartbeat_callback()
930932
send_answer(
931933
WebRTCWorkerResult(
932934
exception_type=error.__class__.__name__,
@@ -935,6 +937,8 @@ async def init_rtc_peer_connection_with_loop(
935937
)
936938
return
937939
except WebRTCConfigurationError as error:
940+
# heartbeat to indicate caller error
941+
heartbeat_callback()
938942
send_answer(
939943
WebRTCWorkerResult(
940944
exception_type=error.__class__.__name__,
@@ -943,6 +947,8 @@ async def init_rtc_peer_connection_with_loop(
943947
)
944948
return
945949
except RoboflowAPINotAuthorizedError:
950+
# heartbeat to indicate caller error
951+
heartbeat_callback()
946952
send_answer(
947953
WebRTCWorkerResult(
948954
exception_type=RoboflowAPINotAuthorizedError.__name__,
@@ -951,6 +957,8 @@ async def init_rtc_peer_connection_with_loop(
951957
)
952958
return
953959
except RoboflowAPINotNotFoundError:
960+
# heartbeat to indicate caller error
961+
heartbeat_callback()
954962
send_answer(
955963
WebRTCWorkerResult(
956964
exception_type=RoboflowAPINotNotFoundError.__name__,
@@ -959,6 +967,8 @@ async def init_rtc_peer_connection_with_loop(
959967
)
960968
return
961969
except WorkflowSyntaxError as error:
970+
# heartbeat to indicate caller error
971+
heartbeat_callback()
962972
send_answer(
963973
WebRTCWorkerResult(
964974
exception_type=WorkflowSyntaxError.__name__,
@@ -968,6 +978,16 @@ async def init_rtc_peer_connection_with_loop(
968978
)
969979
)
970980
return
981+
except WorkflowError as error:
982+
# heartbeat to indicate caller error
983+
heartbeat_callback()
984+
send_answer(
985+
WebRTCWorkerResult(
986+
exception_type=WorkflowError.__name__,
987+
error_message=str(error),
988+
)
989+
)
990+
return
971991
except Exception as error:
972992
send_answer(
973993
WebRTCWorkerResult(

0 commit comments

Comments
 (0)