Skip to content
Draft
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/robusta/core/model/env_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,15 @@ def load_bool(env_var, default: bool):
# Timeout for the ping response, before killing the connection. Must be smaller than the interval
WEBSOCKET_PING_TIMEOUT = int(os.environ.get("WEBSOCKET_PING_TIMEOUT", 30))

# TCP keepalive configuration (disabled by default)
WEBSOCKET_TCP_KEEPALIVE_ENABLED = os.environ.get("WEBSOCKET_TCP_KEEPALIVE_ENABLED", "false").lower() == "true"
# Time in seconds before sending the first keepalive probe (Linux: TCP_KEEPIDLE, macOS: TCP_KEEPALIVE)
WEBSOCKET_TCP_KEEPALIVE_IDLE = int(os.environ.get("WEBSOCKET_TCP_KEEPALIVE_IDLE", 2))
# Interval in seconds between keepalive probes
WEBSOCKET_TCP_KEEPALIVE_INTERVAL = int(os.environ.get("WEBSOCKET_TCP_KEEPALIVE_INTERVAL", 2))
# Number of failed probes before connection is considered dead
WEBSOCKET_TCP_KEEPALIVE_COUNT = int(os.environ.get("WEBSOCKET_TCP_KEEPALIVE_COUNT", 5))

TRACE_INCOMING_REQUESTS = load_bool("TRACE_INCOMING_REQUESTS", False)
TRACE_INCOMING_ALERTS = load_bool("TRACE_INCOMING_ALERTS", False)

Expand Down
56 changes: 54 additions & 2 deletions src/robusta/integrations/receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import json
import logging
import os
import socket
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from contextlib import nullcontext
Expand All @@ -24,6 +26,10 @@
SENTRY_ENABLED,
WEBSOCKET_PING_INTERVAL,
WEBSOCKET_PING_TIMEOUT,
WEBSOCKET_TCP_KEEPALIVE_COUNT,
WEBSOCKET_TCP_KEEPALIVE_ENABLED,
WEBSOCKET_TCP_KEEPALIVE_IDLE,
WEBSOCKET_TCP_KEEPALIVE_INTERVAL,
)
from robusta.core.playbooks.playbook_utils import to_safe_str
from robusta.core.playbooks.playbooks_event_handler import PlaybooksEventHandler
Expand All @@ -42,6 +48,22 @@
WEBSOCKET_THREADPOOL_SIZE = int(os.environ.get("WEBSOCKET_THREADPOOL_SIZE", 10))


def _get_tcp_keepalive_options() -> tuple:
"""Build TCP keepalive socket options tuple for run_forever(sockopt=...)."""
# TCP_KEEPIDLE is Linux-only; macOS uses TCP_KEEPALIVE (0x10) for the same purpose
if sys.platform == "darwin":
tcp_keepalive_idle = 0x10
else:
tcp_keepalive_idle = socket.TCP_KEEPIDLE

return (
(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
(socket.IPPROTO_TCP, tcp_keepalive_idle, WEBSOCKET_TCP_KEEPALIVE_IDLE),
(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, WEBSOCKET_TCP_KEEPALIVE_INTERVAL),
(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, WEBSOCKET_TCP_KEEPALIVE_COUNT),
)


class ValidationResponse(BaseModel):
http_code: int = 200
error_code: Optional[int] = None
Expand Down Expand Up @@ -114,11 +136,27 @@ def start_receiver(self):

def run_forever(self):
logging.info("starting relay receiver")
sockopt = None
if WEBSOCKET_TCP_KEEPALIVE_ENABLED:
sockopt = _get_tcp_keepalive_options()
logging.info(
f"TCP keepalive enabled: idle={WEBSOCKET_TCP_KEEPALIVE_IDLE}s, "
f"interval={WEBSOCKET_TCP_KEEPALIVE_INTERVAL}s, count={WEBSOCKET_TCP_KEEPALIVE_COUNT}"
)
if WEBSOCKET_PING_INTERVAL:
logging.info(
f"Websocket keepalive enabled: interval={WEBSOCKET_PING_INTERVAL}s, "
f"timeout={WEBSOCKET_PING_TIMEOUT}s"
)
while self.active:
# Handles WEBSOCKET_PING_INTERVAL == 0
ping_timeout = WEBSOCKET_PING_TIMEOUT if WEBSOCKET_PING_INTERVAL else None
logging.info("relay websocket starting")
self.ws.run_forever(
ping_interval=WEBSOCKET_PING_INTERVAL,
ping_payload="p",
ping_timeout=WEBSOCKET_PING_TIMEOUT,
ping_timeout=ping_timeout,
sockopt=sockopt,
)
logging.info("relay websocket closed")
time.sleep(INCOMING_WEBSOCKET_RECONNECT_DELAY_SEC)
Expand Down Expand Up @@ -175,10 +213,14 @@ def __exec_external_request(self, action_request: ExternalActionRequest, validat

if sync_response:
http_code = 200 if response.get("success") else 500
logging.debug(
f"Sending results for `{action_request.body.action_name}` {to_safe_str(action_request.body.action_params)} - {http_code}")
self.ws.send(data=json.dumps(self.__sync_response(http_code, action_request.request_id, response)))
logging.debug(
f"After Sending results for `{action_request.body.action_name}` {to_safe_str(action_request.body.action_params)} - {http_code}")

def __exec_external_stream_request(self, action_request: ExternalActionRequest, validate_timestamp: bool):
logging.debug(f"Callback `{action_request.body.action_name}` {to_safe_str(action_request.body.action_params)}")
logging.debug(f"Stream Callback `{action_request.body.action_name}` {to_safe_str(action_request.body.action_params)}")

validation_response = self.__validate_request(action_request, validate_timestamp)
if validation_response.http_code != 200:
Expand All @@ -192,7 +234,11 @@ def __exec_external_stream_request(self, action_request: ExternalActionRequest,
action_request.body.action_params,
lambda data: self.__stream_response(request_id=action_request.request_id, data=data))
res = "" if res.get("success") else f"event: error\ndata: {json.dumps(res)}\n\n"

logging.debug(f"Stream Sending result `{action_request.body.action_name}` {to_safe_str(action_request.body.action_params)} - {res}")
self.__close_stream_response(action_request.request_id, res)
logging.debug(
f"After Stream Sending result `{action_request.body.action_name}` {to_safe_str(action_request.body.action_params)} - {res}")

def _process_action(self, action: ExternalActionRequest, validate_timestamp: bool) -> None:
self._executor.submit(self._process_action_sync, action, validate_timestamp)
Expand Down Expand Up @@ -251,12 +297,18 @@ def on_message(self, ws: websocket.WebSocketApp, message: str) -> None:
return

if isinstance(incoming_event, SlackActionsMessage):
logging.debug(
f"on_message got Slack callback: {len(incoming_event.actions)} action(s) from "
f"user={incoming_event.user.username if incoming_event.user else 'unknown'}"
)
# slack callbacks have a list of 'actions'. Within each action there a 'value' field,
# which container the actual action details we need to run.
# This wrapper format is part of the slack API, and cannot be changed by us.
for slack_action_request in incoming_event.actions:
self._process_action(slack_action_request.value, validate_timestamp=False)
else:
logging.debug(
f"on_message got external action request: `{incoming_event.body.action_name}` {to_safe_str(incoming_event.body.action_params)}")
self._process_action(incoming_event, validate_timestamp=True)

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions tests/test_slack.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ def test_send_to_slack(slack_channel: SlackChannel):
finding = Finding(title=msg, aggregation_key=msg)
finding.add_enrichment([MarkdownBlock("testing")])
slack_params = SlackSinkParams(name="test_slack", slack_channel=slack_channel.channel_name, api_key="")
slack_sender.send_finding_to_slack(finding, slack_params, False)
assert slack_channel.get_latest_message() == msg
ts = slack_sender.send_finding_to_slack(finding, slack_params, False)
assert slack_channel.get_message_by_ts(ts) == msg


def test_long_slack_messages(slack_channel: SlackChannel):
Expand Down
13 changes: 13 additions & 0 deletions tests/utils/slack_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,23 @@ def was_message_sent_recently(self, expected) -> bool:
return False

def get_latest_message(self):
# Note: Prefer get_message_by_ts() to avoid race conditions when tests share a channel
results = self.client.conversations_history(channel=self.channel_id)
messages = results["messages"]
return messages[0]["text"]

def get_message_by_ts(self, ts: str) -> str:
"""Get message by timestamp - avoids race conditions unlike get_latest_message()."""
results = self.client.conversations_history(
channel=self.channel_id,
latest=ts,
oldest=ts,
inclusive=True,
limit=1
)
messages = results["messages"]
return messages[0]["text"] if messages else None

def get_complete_latest_message(self):
results = self.client.conversations_history(channel=self.channel_id)
messages = results["messages"]
Expand Down
Loading