Skip to content

Commit b6b9527

Browse files
fix #1446: monitor dies when exceptions raised before monitor created. (#1447)
--------- Co-authored-by: Kazuhiro Sera <[email protected]>
1 parent df47c32 commit b6b9527

File tree

7 files changed

+232
-45
lines changed

7 files changed

+232
-45
lines changed

slack_sdk/socket_mode/aiohttp/__init__.py

Lines changed: 61 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -342,48 +342,68 @@ async def session_id(self) -> str:
342342
return self.build_session_id(self.current_session)
343343

344344
async def connect(self):
345-
old_session: Optional[ClientWebSocketResponse] = None if self.current_session is None else self.current_session
346-
if self.wss_uri is None:
347-
# If the underlying WSS URL does not exist,
348-
# acquiring a new active WSS URL from the server-side first
349-
self.wss_uri = await self.issue_new_wss_url()
350-
351-
self.current_session = await self.aiohttp_client_session.ws_connect(
352-
self.wss_uri,
353-
autoping=False,
354-
heartbeat=self.ping_interval,
355-
proxy=self.proxy,
356-
ssl=self.web_client.ssl,
357-
)
358-
session_id: str = await self.session_id()
359-
self.auto_reconnect_enabled = self.default_auto_reconnect_enabled
360-
self.stale = False
361-
self.logger.info(f"A new session ({session_id}) has been established")
362-
363-
# The first ping from the new connection
364-
if self.logger.level <= logging.DEBUG:
365-
self.logger.debug(f"Sending a ping message with the newly established connection ({session_id})...")
366-
t = time.time()
367-
await self.current_session.ping(f"sdk-ping-pong:{t}")
368-
369-
if self.current_session_monitor is not None:
370-
self.current_session_monitor.cancel()
371-
372-
self.current_session_monitor = asyncio.ensure_future(self.monitor_current_session())
373-
if self.logger.level <= logging.DEBUG:
374-
self.logger.debug(f"A new monitor_current_session() executor has been recreated for {session_id}")
375-
376-
if self.message_receiver is not None:
377-
self.message_receiver.cancel()
378-
379-
self.message_receiver = asyncio.ensure_future(self.receive_messages())
380-
if self.logger.level <= logging.DEBUG:
381-
self.logger.debug(f"A new receive_messages() executor has been recreated for {session_id}")
345+
# This loop is used to ensure when a new session is created,
346+
# a new monitor and a new message receiver are also created.
347+
# If a new session is created but we failed to create the new
348+
# monitor or the new message, we should try it.
349+
while True:
350+
try:
351+
old_session: Optional[ClientWebSocketResponse] = (
352+
None if self.current_session is None else self.current_session
353+
)
382354

383-
if old_session is not None:
384-
await old_session.close()
385-
old_session_id = self.build_session_id(old_session)
386-
self.logger.info(f"The old session ({old_session_id}) has been abandoned")
355+
# If the old session is broken (e.g. reset by peer), it might fail to close it.
356+
# We don't want to retry when this kind of cases happen.
357+
try:
358+
# We should close old session before create a new one. Because when disconnect
359+
# reason is `too_many_websockets`, we need to close the old one first to
360+
# to decrease the number of connections.
361+
self.auto_reconnect_enabled = False
362+
if old_session is not None:
363+
await old_session.close()
364+
old_session_id = self.build_session_id(old_session)
365+
self.logger.info(f"The old session ({old_session_id}) has been abandoned")
366+
except Exception as e:
367+
self.logger.exception(f"Failed to close the old session : {e}")
368+
369+
if self.wss_uri is None:
370+
# If the underlying WSS URL does not exist,
371+
# acquiring a new active WSS URL from the server-side first
372+
self.wss_uri = await self.issue_new_wss_url()
373+
374+
self.current_session = await self.aiohttp_client_session.ws_connect(
375+
self.wss_uri,
376+
autoping=False,
377+
heartbeat=self.ping_interval,
378+
proxy=self.proxy,
379+
ssl=self.web_client.ssl,
380+
)
381+
session_id: str = await self.session_id()
382+
self.auto_reconnect_enabled = self.default_auto_reconnect_enabled
383+
self.stale = False
384+
self.logger.info(f"A new session ({session_id}) has been established")
385+
386+
# The first ping from the new connection
387+
if self.logger.level <= logging.DEBUG:
388+
self.logger.debug(f"Sending a ping message with the newly established connection ({session_id})...")
389+
t = time.time()
390+
await self.current_session.ping(f"sdk-ping-pong:{t}")
391+
392+
if self.current_session_monitor is not None:
393+
self.current_session_monitor.cancel()
394+
self.current_session_monitor = asyncio.ensure_future(self.monitor_current_session())
395+
if self.logger.level <= logging.DEBUG:
396+
self.logger.debug(f"A new monitor_current_session() executor has been recreated for {session_id}")
397+
398+
if self.message_receiver is not None:
399+
self.message_receiver.cancel()
400+
self.message_receiver = asyncio.ensure_future(self.receive_messages())
401+
if self.logger.level <= logging.DEBUG:
402+
self.logger.debug(f"A new receive_messages() executor has been recreated for {session_id}")
403+
break
404+
except Exception as e:
405+
self.logger.exception(f"Failed to connect (error: {e}); Retrying...")
406+
await asyncio.sleep(self.ping_interval)
387407

388408
async def disconnect(self):
389409
if self.current_session is not None:

tests/slack_sdk/socket_mode/mock_socket_mode_server.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import logging
33
import os
4+
import time
45

56
from aiohttp import WSMsgType, web
67

@@ -24,6 +25,8 @@
2425

2526
socket_mode_hello_message = """{"type":"hello","num_connections":2,"debug_info":{"host":"applink-111-xxx","build_number":10,"approximate_connection_time":18060},"connection_info":{"app_id":"A111"}}"""
2627

28+
socket_mode_disconnect_message = """{"type":"disconnect","reason":"too_many_websockets","num_connections":2,"debug_info":{"host":"applink-111-xxx"},"connection_info":{"app_id":"A111"}}"""
29+
2730

2831
def start_socket_mode_server(self, port: int):
2932
logger = logging.getLogger(__name__)
@@ -82,3 +85,77 @@ def run_server():
8285
loop.close()
8386

8487
return run_server
88+
89+
90+
def start_socket_mode_server_with_disconnection(self, port: int):
91+
logger = logging.getLogger(__name__)
92+
state = {}
93+
94+
def reset_server_state():
95+
state.update(
96+
hello_sent=False,
97+
disconnect_sent=False,
98+
envelopes_to_consume=list(socket_mode_envelopes),
99+
)
100+
101+
self.reset_server_state = reset_server_state
102+
103+
async def link(request):
104+
disconnected = False
105+
ws = web.WebSocketResponse()
106+
await ws.prepare(request)
107+
108+
async for msg in ws:
109+
# To ensure disconnect message is received and handled,
110+
# need to keep this ws alive to bypass client ping-pong check.
111+
if msg.type == WSMsgType.PING:
112+
t = time.time()
113+
await ws.pong(f"sdk-ping-pong:{t}")
114+
continue
115+
if msg.type != WSMsgType.TEXT:
116+
continue
117+
118+
message = msg.data
119+
logger.debug(f"Server received a message: {message}")
120+
121+
if not state["hello_sent"]:
122+
state["hello_sent"] = True
123+
await ws.send_str(socket_mode_hello_message)
124+
125+
if not state["disconnect_sent"]:
126+
state["hello_sent"] = False
127+
state["disconnect_sent"] = True
128+
disconnected = True
129+
await ws.send_str(socket_mode_disconnect_message)
130+
logger.debug(f"Disconnect message sent")
131+
132+
if state["envelopes_to_consume"] and not disconnected:
133+
e = state["envelopes_to_consume"].pop(0)
134+
logger.debug(f"Send an envelope: {e}")
135+
await ws.send_str(e)
136+
137+
await ws.send_str(message)
138+
139+
return ws
140+
141+
app = web.Application()
142+
app.add_routes([web.get("/link", link)])
143+
runner = web.AppRunner(app)
144+
145+
def run_server():
146+
reset_server_state()
147+
148+
self.loop = loop = asyncio.new_event_loop()
149+
asyncio.set_event_loop(loop)
150+
loop.run_until_complete(runner.setup())
151+
site = web.TCPSite(runner, "127.0.0.1", port, reuse_port=True)
152+
loop.run_until_complete(site.start())
153+
154+
# run until it's stopped from the main thread
155+
loop.run_forever()
156+
157+
loop.run_until_complete(runner.cleanup())
158+
loop.run_until_complete(asyncio.sleep(1))
159+
loop.close()
160+
161+
return run_server

tests/slack_sdk/socket_mode/mock_web_api_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def _handle(self):
8989
if self.path == "/apps.connections.open":
9090
body = {
9191
"ok": True,
92-
"url": "wss://test-server/link/?ticket=xxx&app_id=yyy",
92+
"url": "ws://0.0.0.0:3001/link",
9393
}
9494
if self.path == "/api.test" and request_body:
9595
body = {"ok": True, "args": request_body}

tests/slack_sdk/socket_mode/test_builtin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def test_issue_new_wss_url(self):
5656
web_client=self.web_client,
5757
)
5858
url = client.issue_new_wss_url()
59-
self.assertTrue(url.startswith("wss://"))
59+
self.assertTrue(url.startswith("ws://"))
6060

6161
legacy_client = LegacyWebClient(token="xoxb-api_test", base_url="http://localhost:8888")
6262
response = legacy_client.apps_connections_open(app_token="xapp-A111-222-xyz")

tests/slack_sdk_async/socket_mode/test_aiohttp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ async def test_issue_new_wss_url(self):
4242
)
4343
try:
4444
url = await client.issue_new_wss_url()
45-
self.assertTrue(url.startswith("wss://"))
45+
self.assertTrue(url.startswith("ws://"))
4646
finally:
4747
await client.close()
4848

tests/slack_sdk_async/socket_mode/test_interactions_aiohttp.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
from tests.helpers import is_ci_unstable_test_skip_enabled
1818
from tests.slack_sdk.socket_mode.mock_socket_mode_server import (
1919
start_socket_mode_server,
20+
start_socket_mode_server_with_disconnection,
2021
socket_mode_envelopes,
2122
socket_mode_hello_message,
23+
socket_mode_disconnect_message,
2224
)
2325
from tests.slack_sdk.socket_mode.mock_web_api_server import (
2426
setup_mock_web_api_server,
@@ -104,6 +106,94 @@ async def socket_mode_listener(
104106
self.loop.stop()
105107
t.join(timeout=5)
106108

109+
@async_test
110+
async def test_interactions_with_disconnection(self):
111+
if is_ci_unstable_test_skip_enabled():
112+
return
113+
t = Thread(target=start_socket_mode_server_with_disconnection(self, 3001))
114+
t.daemon = True
115+
t.start()
116+
117+
self.disconnected = False
118+
received_messages = []
119+
received_socket_mode_requests = []
120+
121+
async def message_handler(message: WSMessage):
122+
session_id = client.build_session_id(client.current_session)
123+
if "wait_for_disconnect" in message.data:
124+
return
125+
self.logger.info(f"Raw Message: {message}")
126+
await asyncio.sleep(randint(50, 200) / 1000)
127+
self.disconnected = "disconnect" in message.data
128+
received_messages.append(message.data + "_" + session_id)
129+
130+
async def socket_mode_listener(
131+
self: AsyncBaseSocketModeClient,
132+
request: SocketModeRequest,
133+
):
134+
self.logger.info(f"Socket Mode Request: {request.payload}")
135+
await asyncio.sleep(randint(50, 200) / 1000)
136+
received_socket_mode_requests.append(request.payload)
137+
138+
client = SocketModeClient(
139+
app_token="xapp-A111-222-xyz",
140+
web_client=self.web_client,
141+
on_message_listeners=[message_handler],
142+
auto_reconnect_enabled=True,
143+
trace_enabled=True,
144+
)
145+
client.socket_mode_request_listeners.append(socket_mode_listener)
146+
147+
try:
148+
time.sleep(1) # wait for the server
149+
client.wss_uri = "ws://0.0.0.0:3001/link"
150+
await client.connect()
151+
await asyncio.sleep(1) # wait for the message receiver
152+
153+
# Because we want to check the expected messages of new session,
154+
# we need to ensure we send messaged after disconnected.
155+
count = 0
156+
while not self.disconnected and count < 10:
157+
try:
158+
await client.send_message("wait_for_disconnect")
159+
except Exception as e:
160+
self.logger.exception(e)
161+
finally:
162+
await asyncio.sleep(1)
163+
count += 1
164+
await asyncio.sleep(10)
165+
expected_session_id = client.build_session_id(client.current_session)
166+
167+
for _ in range(10):
168+
await client.send_message("foo")
169+
await client.send_message("bar")
170+
await client.send_message("baz")
171+
172+
expected = socket_mode_envelopes + [socket_mode_hello_message] + ["foo", "bar", "baz"] * 10
173+
expected.sort()
174+
175+
count = 0
176+
while count < 10 and (
177+
len([msg for msg in received_messages if expected_session_id in msg]) < len(expected)
178+
or len(received_socket_mode_requests) < len(socket_mode_envelopes)
179+
):
180+
await asyncio.sleep(0.2)
181+
count += 0.2
182+
183+
received_messages.sort()
184+
185+
# Only check messages of current alive session. Ignore the disconnected session.
186+
received_messages = [msg for msg in received_messages if expected_session_id in msg]
187+
expected = [msg + "_" + expected_session_id for msg in expected]
188+
189+
self.assertEqual(received_messages, expected)
190+
191+
self.assertEqual(len(socket_mode_envelopes), len(received_socket_mode_requests))
192+
finally:
193+
await client.close()
194+
self.loop.stop()
195+
t.join(timeout=5)
196+
107197
@async_test
108198
async def test_send_message_while_disconnection(self):
109199
if is_ci_unstable_test_skip_enabled():

tests/slack_sdk_async/socket_mode/test_websockets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ async def test_issue_new_wss_url(self):
3636
)
3737
try:
3838
url = await client.issue_new_wss_url()
39-
self.assertTrue(url.startswith("wss://"))
39+
self.assertTrue(url.startswith("ws://"))
4040
finally:
4141
await client.close()
4242

0 commit comments

Comments
 (0)