|
17 | 17 | import cv2 |
18 | 18 | import numpy as np |
19 | 19 | import requests |
| 20 | +from aiortc import RTCConfiguration, RTCIceServer |
20 | 21 |
|
21 | 22 | from inference_sdk.config import ( |
| 23 | + ALL_ROBOFLOW_API_URLS, |
| 24 | + RF_API_BASE_URL, |
22 | 25 | WEBRTC_EVENT_LOOP_SHUTDOWN_TIMEOUT, |
23 | 26 | WEBRTC_INITIAL_FRAME_TIMEOUT, |
24 | 27 | WEBRTC_VIDEO_QUEUE_MAX_SIZE, |
@@ -557,25 +560,72 @@ def _invoke_data_handler( |
557 | 560 | ) |
558 | 561 | raise |
559 | 562 |
|
560 | | - async def _get_turn_config(self) -> Optional[dict]: |
561 | | - """Get TURN configuration from user-provided config. |
| 563 | + @staticmethod |
| 564 | + def _to_list(value: Any) -> List[Any]: |
| 565 | + """Convert value to list if it is not already a list.""" |
| 566 | + if isinstance(value, list): |
| 567 | + return value |
| 568 | + return [value] |
| 569 | + |
| 570 | + async def _get_turn_config(self) -> Optional[RTCConfiguration]: |
| 571 | + """Get TURN configuration from user-provided config or Roboflow API. |
562 | 572 |
|
563 | 573 | Priority order: |
564 | 574 | 1. User-provided config via StreamConfig.turn_server (highest priority) |
565 | | - 2. Skip TURN for localhost connections |
566 | | - 3. Return None if not provided |
| 575 | + 2. Auto-fetch from Roboflow API for serverless connections |
| 576 | + 3. Return None for non-serverless connections |
567 | 577 |
|
568 | 578 | Returns: |
569 | 579 | TURN configuration dict or None |
570 | 580 | """ |
| 581 | + turn_config = None |
571 | 582 | # 1. Use user-provided config if available |
572 | 583 | if self._config.turn_server: |
| 584 | + turn_config = self._config.turn_server |
573 | 585 | logger.debug("Using user-provided TURN configuration") |
574 | | - return self._config.turn_server |
575 | 586 |
|
576 | | - # 3. No TURN config provided |
577 | | - logger.debug("No TURN configuration provided, proceeding without TURN server") |
578 | | - return None |
| 587 | + # 2. Auto-fetch from Roboflow API for Roboflow-hosted connections |
| 588 | + elif self._api_url in ALL_ROBOFLOW_API_URLS: |
| 589 | + try: |
| 590 | + logger.debug( |
| 591 | + "Fetching TURN config from Roboflow API for serverless connection" |
| 592 | + ) |
| 593 | + response = requests.get( |
| 594 | + f"{RF_API_BASE_URL}/webrtc_turn_config", |
| 595 | + params={"api_key": self._api_key}, |
| 596 | + timeout=5, |
| 597 | + ) |
| 598 | + response.raise_for_status() |
| 599 | + turn_config = response.json() |
| 600 | + logger.debug("Successfully fetched TURN config from Roboflow API") |
| 601 | + except Exception as e: |
| 602 | + logger.warning(f"Failed to fetch TURN config from Roboflow API: {e}") |
| 603 | + return None |
| 604 | + # standardize the TURN config to the iceServers format |
| 605 | + if turn_config and "iceServers" in turn_config: |
| 606 | + turn_config = RTCConfiguration( |
| 607 | + iceServers=[ |
| 608 | + RTCIceServer( |
| 609 | + urls=WebRTCSession._to_list(server.get("urls", [])), |
| 610 | + username=server.get("username"), |
| 611 | + credential=server.get("credential"), |
| 612 | + ) |
| 613 | + for server in turn_config["iceServers"] |
| 614 | + ] |
| 615 | + ) |
| 616 | + logger.debug("Successfully converted TURN config to iceServers format") |
| 617 | + elif turn_config and "urls" in turn_config: |
| 618 | + turn_config = RTCConfiguration( |
| 619 | + iceServers=[ |
| 620 | + RTCIceServer( |
| 621 | + urls=[turn_config["urls"]], |
| 622 | + username=turn_config["username"], |
| 623 | + credential=turn_config["credential"], |
| 624 | + ) |
| 625 | + ] |
| 626 | + ) |
| 627 | + logger.debug("Successfully converted TURN config to iceServers format") |
| 628 | + return turn_config |
579 | 629 |
|
580 | 630 | def _handle_datachannel_video_frame( |
581 | 631 | self, serialized_data: Any, metadata: Optional[VideoMetadata] |
@@ -627,17 +677,7 @@ async def _init(self) -> None: |
627 | 677 | # Fetch TURN configuration (auto-fetch or user-provided) |
628 | 678 | turn_config = await self._get_turn_config() |
629 | 679 |
|
630 | | - # Create peer connection with TURN config if available |
631 | | - configuration = None |
632 | | - if turn_config: |
633 | | - ice = RTCIceServer( |
634 | | - urls=[turn_config.get("urls")], |
635 | | - username=turn_config.get("username"), |
636 | | - credential=turn_config.get("credential"), |
637 | | - ) |
638 | | - configuration = RTCConfiguration(iceServers=[ice]) |
639 | | - |
640 | | - pc = RTCPeerConnection(configuration=configuration) |
| 680 | + pc = RTCPeerConnection(configuration=turn_config) |
641 | 681 | relay = MediaRelay() |
642 | 682 |
|
643 | 683 | # Setup video receiver for frames from server |
@@ -812,9 +852,19 @@ def _on_data_message(message: Any) -> None: # noqa: ANN401 |
812 | 852 | "data_output": self._config.data_output, |
813 | 853 | } |
814 | 854 |
|
815 | | - # Add TURN config if available (auto-fetched or user-provided) |
| 855 | + # Add WebRTC config if available (auto-fetched or user-provided) |
| 856 | + # Server accepts webrtc_config with iceServers array format |
816 | 857 | if turn_config: |
817 | | - payload["webrtc_turn_config"] = turn_config |
| 858 | + payload["webrtc_config"] = { |
| 859 | + "iceServers": [ |
| 860 | + { |
| 861 | + "urls": ice_server.urls, |
| 862 | + "username": ice_server.username, |
| 863 | + "credential": ice_server.credential, |
| 864 | + } |
| 865 | + for ice_server in turn_config.iceServers |
| 866 | + ] |
| 867 | + } |
818 | 868 |
|
819 | 869 | # Add FPS if provided |
820 | 870 | if self._config.declared_fps: |
|
0 commit comments