diff --git a/src/robusta/core/model/env_vars.py b/src/robusta/core/model/env_vars.py
index 8c45ff7a7..27da71633 100644
--- a/src/robusta/core/model/env_vars.py
+++ b/src/robusta/core/model/env_vars.py
@@ -82,6 +82,15 @@ def load_bool(env_var, default: bool):
 # Timeout for the ping response, before killing the connection. Must be smaller than the interval
 WEBSOCKET_PING_TIMEOUT = int(os.environ.get("WEBSOCKET_PING_TIMEOUT", 30))
 
+# TCP keepalive configuration (disabled by default)
+WEBSOCKET_TCP_KEEPALIVE_ENABLED = os.environ.get("WEBSOCKET_TCP_KEEPALIVE_ENABLED", "false").lower() == "true"
+# Time in seconds before sending the first keepalive probe (Linux: TCP_KEEPIDLE, macOS: TCP_KEEPALIVE)
+WEBSOCKET_TCP_KEEPALIVE_IDLE = int(os.environ.get("WEBSOCKET_TCP_KEEPALIVE_IDLE", 2))
+# Interval in seconds between keepalive probes
+WEBSOCKET_TCP_KEEPALIVE_INTERVAL = int(os.environ.get("WEBSOCKET_TCP_KEEPALIVE_INTERVAL", 2))
+# Number of failed probes before connection is considered dead
+WEBSOCKET_TCP_KEEPALIVE_COUNT = int(os.environ.get("WEBSOCKET_TCP_KEEPALIVE_COUNT", 5))
+
 TRACE_INCOMING_REQUESTS = load_bool("TRACE_INCOMING_REQUESTS", False)
 TRACE_INCOMING_ALERTS = load_bool("TRACE_INCOMING_ALERTS", False)
 
diff --git a/src/robusta/integrations/receiver.py b/src/robusta/integrations/receiver.py
index 6f674b18f..5d001f3c4 100644
--- a/src/robusta/integrations/receiver.py
+++ b/src/robusta/integrations/receiver.py
@@ -4,6 +4,8 @@
 import json
 import logging
 import os
+import socket
+import sys
 import time
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import nullcontext
@@ -24,6 +26,10 @@
     SENTRY_ENABLED,
     WEBSOCKET_PING_INTERVAL,
     WEBSOCKET_PING_TIMEOUT,
+    WEBSOCKET_TCP_KEEPALIVE_COUNT,
+    WEBSOCKET_TCP_KEEPALIVE_ENABLED,
+    WEBSOCKET_TCP_KEEPALIVE_IDLE,
+    WEBSOCKET_TCP_KEEPALIVE_INTERVAL,
 )
 from robusta.core.playbooks.playbook_utils import to_safe_str
 from robusta.core.playbooks.playbooks_event_handler import PlaybooksEventHandler
@@ -42,6 +48,30 @@
 WEBSOCKET_THREADPOOL_SIZE = int(os.environ.get("WEBSOCKET_THREADPOOL_SIZE", 10))
 
 
+def _get_tcp_keepalive_options() -> tuple:
+    """Build TCP keepalive socket options tuple for run_forever(sockopt=...).
+
+    SO_KEEPALIVE is always set. A tuning option is skipped when this platform's
+    socket module lacks its constant, leaving that setting at the OS default
+    instead of raising AttributeError.
+    """
+    # TCP_KEEPIDLE is Linux-only; macOS uses TCP_KEEPALIVE (0x10) for the same purpose
+    if sys.platform == "darwin":
+        # Prefer the named constant when the socket module exposes it
+        tcp_keepalive_idle = getattr(socket, "TCP_KEEPALIVE", 0x10)
+    else:
+        tcp_keepalive_idle = getattr(socket, "TCP_KEEPIDLE", None)
+
+    tuning = (
+        (tcp_keepalive_idle, WEBSOCKET_TCP_KEEPALIVE_IDLE),
+        (getattr(socket, "TCP_KEEPINTVL", None), WEBSOCKET_TCP_KEEPALIVE_INTERVAL),
+        (getattr(socket, "TCP_KEEPCNT", None), WEBSOCKET_TCP_KEEPALIVE_COUNT),
+    )
+    return ((socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),) + tuple(
+        (socket.IPPROTO_TCP, option, value) for option, value in tuning if option is not None
+    )
+
+
 class ValidationResponse(BaseModel):
     http_code: int = 200
     error_code: Optional[int] = None
@@ -114,11 +144,27 @@ def start_receiver(self):
 
     def run_forever(self):
         logging.info("starting relay receiver")
+        sockopt = None
+        if WEBSOCKET_TCP_KEEPALIVE_ENABLED:
+            sockopt = _get_tcp_keepalive_options()
+            logging.info(
+                f"TCP keepalive enabled: idle={WEBSOCKET_TCP_KEEPALIVE_IDLE}s, "
+                f"interval={WEBSOCKET_TCP_KEEPALIVE_INTERVAL}s, count={WEBSOCKET_TCP_KEEPALIVE_COUNT}"
+            )
+        if WEBSOCKET_PING_INTERVAL:
+            logging.info(
+                f"Websocket keepalive enabled: interval={WEBSOCKET_PING_INTERVAL}s, "
+                f"timeout={WEBSOCKET_PING_TIMEOUT}s"
+            )
+        # Handles WEBSOCKET_PING_INTERVAL == 0 (computed once; it does not change between reconnects)
+        ping_timeout = WEBSOCKET_PING_TIMEOUT if WEBSOCKET_PING_INTERVAL else None
         while self.active:
+            logging.info("relay websocket starting")
             self.ws.run_forever(
                 ping_interval=WEBSOCKET_PING_INTERVAL,
                 ping_payload="p",
-                ping_timeout=WEBSOCKET_PING_TIMEOUT,
+                ping_timeout=ping_timeout,
+                sockopt=sockopt,
             )
             logging.info("relay websocket closed")
             time.sleep(INCOMING_WEBSOCKET_RECONNECT_DELAY_SEC)
@@ -175,10 +221,14 @@ def __exec_external_request(self, action_request: ExternalActionRequest, validat
 
         if sync_response:
             http_code = 200 if response.get("success") else 500
+            logging.debug(
+                f"Sending results for `{action_request.body.action_name}` {to_safe_str(action_request.body.action_params)} - {http_code}")
             self.ws.send(data=json.dumps(self.__sync_response(http_code, action_request.request_id, response)))
+            logging.debug(
+                f"After Sending results for `{action_request.body.action_name}` {to_safe_str(action_request.body.action_params)} - {http_code}")
 
     def __exec_external_stream_request(self, action_request: ExternalActionRequest, validate_timestamp: bool):
-        logging.debug(f"Callback `{action_request.body.action_name}` {to_safe_str(action_request.body.action_params)}")
+        logging.debug(f"Stream Callback `{action_request.body.action_name}` {to_safe_str(action_request.body.action_params)}")
 
         validation_response = self.__validate_request(action_request, validate_timestamp)
         if validation_response.http_code != 200:
@@ -192,7 +242,11 @@ def __exec_external_stream_request(self, action_request: ExternalActionRequest,
                                        action_request.body.action_params,
                                        lambda data: self.__stream_response(request_id=action_request.request_id, data=data))
         res = "" if res.get("success") else f"event: error\ndata: {json.dumps(res)}\n\n"
+
+        logging.debug(f"Stream Sending result `{action_request.body.action_name}` {to_safe_str(action_request.body.action_params)} - {res}")
         self.__close_stream_response(action_request.request_id, res)
+        logging.debug(
+            f"After Stream Sending result `{action_request.body.action_name}` {to_safe_str(action_request.body.action_params)} - {res}")
 
     def _process_action(self, action: ExternalActionRequest, validate_timestamp: bool) -> None:
         self._executor.submit(self._process_action_sync, action, validate_timestamp)
@@ -251,12 +305,18 @@ def on_message(self, ws: websocket.WebSocketApp, message: str) -> None:
             return
 
         if isinstance(incoming_event, SlackActionsMessage):
+            logging.debug(
+                f"on_message got Slack callback: {len(incoming_event.actions)} action(s) from "
+                f"user={incoming_event.user.username if incoming_event.user else 'unknown'}"
+            )
             # slack callbacks have a list of 'actions'. Within each action there a 'value' field,
             # which container the actual action details we need to run.
             # This wrapper format is part of the slack API, and cannot be changed by us.
             for slack_action_request in incoming_event.actions:
                 self._process_action(slack_action_request.value, validate_timestamp=False)
         else:
+            logging.debug(
+                f"on_message got external action request: `{incoming_event.body.action_name}` {to_safe_str(incoming_event.body.action_params)}")
             self._process_action(incoming_event, validate_timestamp=True)
 
     @staticmethod
diff --git a/tests/test_slack.py b/tests/test_slack.py
index 9c0818e73..521a5ce44 100644
--- a/tests/test_slack.py
+++ b/tests/test_slack.py
@@ -23,8 +23,8 @@ def test_send_to_slack(slack_channel: SlackChannel):
     finding = Finding(title=msg, aggregation_key=msg)
     finding.add_enrichment([MarkdownBlock("testing")])
     slack_params = SlackSinkParams(name="test_slack", slack_channel=slack_channel.channel_name, api_key="")
-    slack_sender.send_finding_to_slack(finding, slack_params, False)
-    assert slack_channel.get_latest_message() == msg
+    ts = slack_sender.send_finding_to_slack(finding, slack_params, False)
+    assert slack_channel.get_message_by_ts(ts) == msg
 
 
 def test_long_slack_messages(slack_channel: SlackChannel):
diff --git a/tests/utils/slack_utils.py b/tests/utils/slack_utils.py
index 4890eeffb..040b89499 100644
--- a/tests/utils/slack_utils.py
+++ b/tests/utils/slack_utils.py
@@ -18,10 +18,26 @@ def was_message_sent_recently(self, expected) -> bool:
         return False
 
     def get_latest_message(self):
+        # Note: Prefer get_message_by_ts() to avoid race conditions when tests share a channel
         results = self.client.conversations_history(channel=self.channel_id)
         messages = results["messages"]
         return messages[0]["text"]
 
+    def get_message_by_ts(self, ts: str) -> "str | None":
+        """Get message by timestamp - avoids race conditions unlike get_latest_message()."""
+        if not ts:
+            # No timestamp means there is no specific message to look up
+            return None
+        results = self.client.conversations_history(
+            channel=self.channel_id,
+            latest=ts,
+            oldest=ts,
+            inclusive=True,
+            limit=1
+        )
+        messages = results["messages"]
+        return messages[0]["text"] if messages else None
+
     def get_complete_latest_message(self):
         results = self.client.conversations_history(channel=self.channel_id)
         messages = results["messages"]