Skip to content

Commit 5b349e5

Browse files
websockets: fix ping_timeout (#3376)
* websockets: fix ping_timeout
  * Closes #3258
  * Closes #2905
  * Closes #2655
  * Fixes an issue with the calculation of the ping timeout interval that could cause connections to be erroneously timed out and closed from the server end.
* websocket: Fix lint, remove hard-coded 30s default timeout
* websocket_test: Improve assertion error messages
* websocket_test: Allow a little slack in ping timing. Appears to be necessary on Windows.

---------

Co-authored-by: Ben Darnell <ben@bendarnell.com>
1 parent 5e4fff4 commit 5b349e5

File tree

2 files changed

+159
-50
lines changed

2 files changed

+159
-50
lines changed

tornado/test/websocket_test.py

Lines changed: 75 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -810,7 +810,11 @@ class PingHandler(TestWebSocketHandler):
810810
def on_pong(self, data):
811811
self.write_message("got pong")
812812

813-
return Application([("/", PingHandler)], websocket_ping_interval=0.01)
813+
return Application(
814+
[("/", PingHandler)],
815+
websocket_ping_interval=0.01,
816+
websocket_ping_timeout=0,
817+
)
814818

815819
@gen_test
816820
def test_server_ping(self):
@@ -831,14 +835,82 @@ def on_ping(self, data):
831835

832836
@gen_test
def test_client_ping(self):
    """Connect with a 0.01s ping interval and expect three "got ping" messages."""
    # ping_timeout=0 is passed so that no pong deadline applies here; only
    # the periodic sending of pings is being checked.
    conn = yield self.ws_connect("/", ping_interval=0.01, ping_timeout=0)
    outstanding = 3
    while outstanding:
        message = yield conn.read_message()
        self.assertEqual(message, "got ping")
        outstanding -= 1
    conn.close()
840843

841844

845+
class ServerPingTimeoutTest(WebSocketBaseTestCase):
    """Exercise the websocket ping timeout.

    The connection must stay open while pongs are being processed and must
    be closed with code 1000 / reason "ping timed out" once they are not.
    """

    def get_app(self):
        """Build an app whose handler instances are recorded on the test."""
        # Handler instances in the order they were initialized, so the test
        # can inspect close_code / close_reason afterwards.
        self.handlers: list[WebSocketHandler] = []
        test = self

        class PingHandler(TestWebSocketHandler):
            def initialize(self, close_future=None, compression_options=None):
                self.handlers = test.handlers
                # capture the handler instance so we can interrogate it later
                self.handlers.append(self)
                return super().initialize(
                    close_future=close_future, compression_options=compression_options
                )

        app = Application([("/", PingHandler)])
        return app

    @staticmethod
    def suppress_pong(ws):
        """Stop ``ws`` from processing "pong" frames.

        Wraps ``ws.protocol._handle_message`` so that frames with opcode
        ``0xA`` are dropped before the wrapped handler sees them; every
        other opcode is passed through unchanged.
        """

        def wrapper(fcn):
            def _inner(opcode: int, data: bytes):
                if opcode == 0xA:  # NOTE: 0x9=ping, 0xA=pong
                    # prevent pong responses
                    return
                # leave all other responses unchanged
                return fcn(opcode, data)

            return _inner

        ws.protocol._handle_message = wrapper(ws.protocol._handle_message)

    @gen_test
    def test_client_ping_timeout(self):
        # websocket client
        interval = 0.2
        ws = yield self.ws_connect(
            "/", ping_interval=interval, ping_timeout=interval / 4
        )

        # websocket handler (server side)
        handler = self.handlers[0]

        for _ in range(5):
            # wait for the ping period (same value as the configured
            # interval, so the two cannot drift apart if one is edited)
            yield gen.sleep(interval)

            # connection should still be open from the server end
            self.assertIsNone(handler.close_code)
            self.assertIsNone(handler.close_reason)

            # connection should still be open from the client end
            # (assertIsNone rather than a bare assert: it survives ``-O``
            # and reports the unexpected value on failure)
            self.assertIsNone(ws.protocol.close_code)

        # suppress the pong response message
        self.suppress_pong(ws)

        # give the server time to register this
        yield gen.sleep(interval * 1.5)

        # connection should be closed from the server side
        self.assertEqual(handler.close_code, 1000)
        self.assertEqual(handler.close_reason, "ping timed out")

        # client should have received a close operation
        self.assertEqual(ws.protocol.close_code, 1000)
912+
913+
842914
class ManualPingTest(WebSocketBaseTestCase):
843915
def get_app(self):
844916
class PingHandler(TestWebSocketHandler):

tornado/websocket.py

Lines changed: 84 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
import abc
1515
import asyncio
1616
import base64
17+
import functools
1718
import hashlib
19+
import logging
1820
import os
1921
import sys
2022
import struct
@@ -26,7 +28,7 @@
2628
from tornado.concurrent import Future, future_set_result_unless_cancelled
2729
from tornado.escape import utf8, native_str, to_unicode
2830
from tornado import gen, httpclient, httputil
29-
from tornado.ioloop import IOLoop, PeriodicCallback
31+
from tornado.ioloop import IOLoop
3032
from tornado.iostream import StreamClosedError, IOStream
3133
from tornado.log import gen_log, app_log
3234
from tornado.netutil import Resolver
@@ -97,6 +99,9 @@ def log_exception(
9799

98100
_default_max_message_size = 10 * 1024 * 1024
99101

102+
# log to "gen_log" but suppress duplicate log messages
103+
# NOTE: functools.lru_cache memoizes on the call arguments, so a repeated
# call with the same (level, message) pair is answered from the cache and
# gen_log.log is not invoked again while that entry remains cached.
# All arguments must therefore be hashable.
de_dupe_gen_log = functools.lru_cache(gen_log.log)
104+
100105

101106
class WebSocketError(Exception):
102107
pass
@@ -274,17 +279,41 @@ async def get(self, *args: Any, **kwargs: Any) -> None:
274279

275280
@property
def ping_interval(self) -> Optional[float]:
    """Seconds between the pings sent on this websocket.

    When this is non-zero, a ping is sent every ``ping_interval``
    seconds and the client answers each one with a "pong".  Use
    ``websocket_ping_timeout`` to have the connection time out when a
    pong is delivered late.

    Set ``websocket_ping_interval = 0`` to disable pings.

    Default: ``0``

    The value is read from the ``websocket_ping_interval`` application
    setting; this property returns ``None`` when that setting is absent.
    """
    configured = self.settings.get("websocket_ping_interval", None)
    return configured
282294

283295
@property
def ping_timeout(self) -> Optional[float]:
    """Timeout if no pong is received in this many seconds.

    To be used in combination with ``websocket_ping_interval > 0``.
    If a ping response (a "pong") is not received within
    ``websocket_ping_timeout`` seconds, then the websocket connection
    will be closed.

    This can help to clean up clients which have disconnected without
    cleanly closing the websocket connection.

    Note, the ping timeout cannot be longer than the ping interval.

    Set ``websocket_ping_timeout = 0`` to disable the ping timeout.

    Default: the ping interval (``websocket_ping_interval``).  There is
    no longer a fixed 30 second component in the default.

    The value is read from the ``websocket_ping_timeout`` application
    setting; this property returns ``None`` when that setting is absent.

    .. versionchanged:: 6.5.0
       Default changed from the max of 3 pings or 30 seconds.
       The ping timeout can no longer be configured longer than the
       ping interval.
    """
    return self.settings.get("websocket_ping_timeout", None)
290319

@@ -831,11 +860,10 @@ def __init__(
831860
# the effect of compression, frame overhead, and control frames.
832861
self._wire_bytes_in = 0
833862
self._wire_bytes_out = 0
834-
self.ping_callback = None # type: Optional[PeriodicCallback]
835-
self.last_ping = 0.0
836-
self.last_pong = 0.0
863+
self._received_pong = False # type: bool
837864
self.close_code = None # type: Optional[int]
838865
self.close_reason = None # type: Optional[str]
866+
self._ping_coroutine = None # type: Optional[asyncio.Task]
839867

840868
# Use a property for this to satisfy the abc.
841869
@property
@@ -1232,7 +1260,7 @@ def _handle_message(self, opcode: int, data: bytes) -> "Optional[Future[None]]":
12321260
self._run_callback(self.handler.on_ping, data)
12331261
elif opcode == 0xA:
12341262
# Pong
1235-
self.last_pong = IOLoop.current().time()
1263+
self._received_pong = True
12361264
return self._run_callback(self.handler.on_pong, data)
12371265
else:
12381266
self._abort()
@@ -1266,9 +1294,9 @@ def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> Non
12661294
self._waiting = self.stream.io_loop.add_timeout(
12671295
self.stream.io_loop.time() + 5, self._abort
12681296
)
1269-
if self.ping_callback:
1270-
self.ping_callback.stop()
1271-
self.ping_callback = None
1297+
if self._ping_coroutine:
1298+
self._ping_coroutine.cancel()
1299+
self._ping_coroutine = None
12721300

12731301
def is_closing(self) -> bool:
12741302
"""Return ``True`` if this connection is closing.
@@ -1279,60 +1307,69 @@ def is_closing(self) -> bool:
12791307
"""
12801308
return self.stream.closed() or self.client_terminated or self.server_terminated
12811309

1310+
def set_nodelay(self, x: bool) -> None:
    """Forward the no-delay flag ``x`` to the underlying stream."""
    stream = self.stream
    stream.set_nodelay(x)
1312+
12821313
@property
def ping_interval(self) -> float:
    """Ping interval from the connection params; ``0`` when it is ``None``."""
    configured = self.params.ping_interval
    return 0 if configured is None else configured
12881319

12891320
@property
def ping_timeout(self) -> float:
    """Effective pong timeout for this connection.

    Returns the ping interval when no timeout is configured.  A configured
    timeout is returned as-is unless a non-zero ping interval is set and the
    timeout exceeds it; in that case a warning is logged and the ping
    interval is returned instead.
    """
    timeout = self.params.ping_timeout
    if timeout is None:
        # nothing configured: fall back to the ping interval
        return self.ping_interval

    if not (self.ping_interval and timeout > self.ping_interval):
        return timeout

    # Note: using de_dupe_gen_log to prevent this message from
    # being duplicated for each connection
    de_dupe_gen_log(
        logging.WARNING,
        f"The websocket_ping_timeout ({timeout}) cannot be longer"
        f" than the websocket_ping_interval ({self.ping_interval})."
        f"\nSetting websocket_ping_timeout={self.ping_interval}",
    )
    return self.ping_interval
12961336

12971337
def start_pinging(self) -> None:
    """Start sending periodic pings to keep the connection alive"""
    if self._ping_coroutine:
        # a ping task already exists; never run two in parallel
        return
    if self.ping_interval > 0:
        # only run the ping coroutine if a ping interval is configured
        self._ping_coroutine = asyncio.create_task(self.periodic_ping())
13061346

1307-
async def periodic_ping(self) -> None:
    """Send pings on a fixed schedule and enforce the pong timeout.

    Waits one ping interval, then loops: send a ping, wait
    ``ping_timeout`` seconds, and close the connection with reason
    "ping timed out" if a timeout is configured (``> 0``) and no pong was
    registered in that window.  Otherwise sleep for whatever is left of
    the interval so that consecutive pings are ``ping_interval`` seconds
    apart.

    The loop only returns after a ping timeout; ``ping_interval`` and
    ``ping_timeout`` are read once, before the first ping.
    """
    interval = self.ping_interval
    timeout = self.ping_timeout

    await asyncio.sleep(interval)

    while True:
        # send a ping
        self._received_pong = False
        ping_time = IOLoop.current().time()
        self.write_ping(b"")

        # wait until the ping timeout
        await asyncio.sleep(timeout)

        # make sure we received a pong within the timeout
        if timeout > 0 and not self._received_pong:
            self.close(reason="ping timed out")
            return

        # Wait until the next scheduled ping, i.e. ``ping_time + interval``.
        # The time already spent waiting for the pong must be subtracted,
        # not added, otherwise pings drift to every interval + 2 * timeout.
        # Clamp at zero in case the timeout wait overran the interval.
        remaining = ping_time + interval - IOLoop.current().time()
        await asyncio.sleep(max(0, remaining))
13361373

13371374

13381375
class WebSocketClientConnection(simple_httpclient._HTTPConnection):

0 commit comments

Comments
 (0)