Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/robusta/core/model/env_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,15 @@ def load_bool(env_var, default: bool):
# Timeout for the ping response, before killing the connection. Must be smaller than the interval
WEBSOCKET_PING_TIMEOUT = int(os.environ.get("WEBSOCKET_PING_TIMEOUT", 30))

# TCP keepalive configuration (disabled by default)
WEBSOCKET_TCP_KEEPALIVE_ENABLED = os.environ.get("WEBSOCKET_TCP_KEEPALIVE_ENABLED", "false").lower() == "true"
# Time in seconds before sending the first keepalive probe (Linux: TCP_KEEPIDLE, macOS: TCP_KEEPALIVE)
WEBSOCKET_TCP_KEEPALIVE_IDLE = int(os.environ.get("WEBSOCKET_TCP_KEEPALIVE_IDLE", 2))
# Interval in seconds between keepalive probes
WEBSOCKET_TCP_KEEPALIVE_INTERVAL = int(os.environ.get("WEBSOCKET_TCP_KEEPALIVE_INTERVAL", 2))
# Number of failed probes before connection is considered dead
WEBSOCKET_TCP_KEEPALIVE_COUNT = int(os.environ.get("WEBSOCKET_TCP_KEEPALIVE_COUNT", 5))

TRACE_INCOMING_REQUESTS = load_bool("TRACE_INCOMING_REQUESTS", False)
TRACE_INCOMING_ALERTS = load_bool("TRACE_INCOMING_ALERTS", False)

Expand Down
56 changes: 54 additions & 2 deletions src/robusta/integrations/receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import json
import logging
import os
import socket
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from contextlib import nullcontext
Expand All @@ -24,6 +26,10 @@
SENTRY_ENABLED,
WEBSOCKET_PING_INTERVAL,
WEBSOCKET_PING_TIMEOUT,
WEBSOCKET_TCP_KEEPALIVE_COUNT,
WEBSOCKET_TCP_KEEPALIVE_ENABLED,
WEBSOCKET_TCP_KEEPALIVE_IDLE,
WEBSOCKET_TCP_KEEPALIVE_INTERVAL,
)
from robusta.core.playbooks.playbook_utils import to_safe_str
from robusta.core.playbooks.playbooks_event_handler import PlaybooksEventHandler
Expand All @@ -42,6 +48,22 @@
WEBSOCKET_THREADPOOL_SIZE = int(os.environ.get("WEBSOCKET_THREADPOOL_SIZE", 10))


def _get_tcp_keepalive_options() -> tuple:
"""Build TCP keepalive socket options tuple for run_forever(sockopt=...)."""
# TCP_KEEPIDLE is Linux-only; macOS uses TCP_KEEPALIVE (0x10) for the same purpose
if sys.platform == "darwin":
tcp_keepalive_idle = 0x10
else:
tcp_keepalive_idle = socket.TCP_KEEPIDLE

return (
(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
(socket.IPPROTO_TCP, tcp_keepalive_idle, WEBSOCKET_TCP_KEEPALIVE_IDLE),
(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, WEBSOCKET_TCP_KEEPALIVE_INTERVAL),
(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, WEBSOCKET_TCP_KEEPALIVE_COUNT),
)


class ValidationResponse(BaseModel):
http_code: int = 200
error_code: Optional[int] = None
Expand Down Expand Up @@ -114,11 +136,27 @@ def start_receiver(self):

def run_forever(self):
logging.info("starting relay receiver")
sockopt = None
if WEBSOCKET_TCP_KEEPALIVE_ENABLED:
sockopt = _get_tcp_keepalive_options()
logging.info(
f"TCP keepalive enabled: idle={WEBSOCKET_TCP_KEEPALIVE_IDLE}s, "
f"interval={WEBSOCKET_TCP_KEEPALIVE_INTERVAL}s, count={WEBSOCKET_TCP_KEEPALIVE_COUNT}"
)
if WEBSOCKET_PING_INTERVAL:
logging.info(
f"Websocket keepalive enabled: interval={WEBSOCKET_PING_INTERVAL}s, "
f"timeout={WEBSOCKET_PING_TIMEOUT}s"
)
while self.active:
# Handles WEBSOCKET_PING_INTERVAL == 0
ping_timeout = WEBSOCKET_PING_TIMEOUT if WEBSOCKET_PING_INTERVAL else None
logging.info("relay websocket starting")
self.ws.run_forever(
ping_interval=WEBSOCKET_PING_INTERVAL,
ping_payload="p",
ping_timeout=WEBSOCKET_PING_TIMEOUT,
ping_timeout=ping_timeout,
sockopt=sockopt,
)
logging.info("relay websocket closed")
time.sleep(INCOMING_WEBSOCKET_RECONNECT_DELAY_SEC)
Expand Down Expand Up @@ -175,10 +213,14 @@ def __exec_external_request(self, action_request: ExternalActionRequest, validat

if sync_response:
http_code = 200 if response.get("success") else 500
logging.debug(
f"Sending results for `{action_request.body.action_name}` {to_safe_str(action_request.body.action_params)} - {http_code}")
self.ws.send(data=json.dumps(self.__sync_response(http_code, action_request.request_id, response)))
logging.debug(
f"After Sending results for `{action_request.body.action_name}` {to_safe_str(action_request.body.action_params)} - {http_code}")

def __exec_external_stream_request(self, action_request: ExternalActionRequest, validate_timestamp: bool):
logging.debug(f"Callback `{action_request.body.action_name}` {to_safe_str(action_request.body.action_params)}")
logging.debug(f"Stream Callback `{action_request.body.action_name}` {to_safe_str(action_request.body.action_params)}")

validation_response = self.__validate_request(action_request, validate_timestamp)
if validation_response.http_code != 200:
Expand All @@ -192,7 +234,11 @@ def __exec_external_stream_request(self, action_request: ExternalActionRequest,
action_request.body.action_params,
lambda data: self.__stream_response(request_id=action_request.request_id, data=data))
res = "" if res.get("success") else f"event: error\ndata: {json.dumps(res)}\n\n"

logging.debug(f"Stream Sending result `{action_request.body.action_name}` {to_safe_str(action_request.body.action_params)} - {res}")
self.__close_stream_response(action_request.request_id, res)
logging.debug(
f"After Stream Sending result `{action_request.body.action_name}` {to_safe_str(action_request.body.action_params)} - {res}")

def _process_action(self, action: ExternalActionRequest, validate_timestamp: bool) -> None:
self._executor.submit(self._process_action_sync, action, validate_timestamp)
Expand Down Expand Up @@ -251,12 +297,18 @@ def on_message(self, ws: websocket.WebSocketApp, message: str) -> None:
return

if isinstance(incoming_event, SlackActionsMessage):
logging.debug(
f"on_message got Slack callback: {len(incoming_event.actions)} action(s) from "
f"user={incoming_event.user.username if incoming_event.user else 'unknown'}"
)
# slack callbacks have a list of 'actions'. Within each action there a 'value' field,
# which container the actual action details we need to run.
# This wrapper format is part of the slack API, and cannot be changed by us.
for slack_action_request in incoming_event.actions:
self._process_action(slack_action_request.value, validate_timestamp=False)
else:
logging.debug(
f"on_message got external action request: `{incoming_event.body.action_name}` {to_safe_str(incoming_event.body.action_params)}")
self._process_action(incoming_event, validate_timestamp=True)

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions tests/test_slack.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ def test_send_to_slack(slack_channel: SlackChannel):
finding = Finding(title=msg, aggregation_key=msg)
finding.add_enrichment([MarkdownBlock("testing")])
slack_params = SlackSinkParams(name="test_slack", slack_channel=slack_channel.channel_name, api_key="")
slack_sender.send_finding_to_slack(finding, slack_params, False)
assert slack_channel.get_latest_message() == msg
ts = slack_sender.send_finding_to_slack(finding, slack_params, False)
assert slack_channel.get_message_by_ts(ts) == msg


def test_long_slack_messages(slack_channel: SlackChannel):
Expand Down
13 changes: 13 additions & 0 deletions tests/utils/slack_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,23 @@ def was_message_sent_recently(self, expected) -> bool:
return False

def get_latest_message(self):
# Note: Prefer get_message_by_ts() to avoid race conditions when tests share a channel
results = self.client.conversations_history(channel=self.channel_id)
messages = results["messages"]
return messages[0]["text"]

def get_message_by_ts(self, ts: str) -> str | None:
"""Get message by timestamp - avoids race conditions unlike get_latest_message()."""
results = self.client.conversations_history(
channel=self.channel_id,
latest=ts,
oldest=ts,
inclusive=True,
limit=1
)
messages = results["messages"]
return messages[0]["text"] if messages else None

def get_complete_latest_message(self):
results = self.client.conversations_history(channel=self.channel_id)
messages = results["messages"]
Expand Down
Loading