Skip to content

Commit 6f45fb9

Browse files
Copilotmykaul
andcommitted
Implement TLS session cache for faster reconnections
Co-authored-by: mykaul <[email protected]>
1 parent 80cee85 commit 6f45fb9

File tree

2 files changed

+177
-2
lines changed

2 files changed

+177
-2
lines changed

cassandra/cluster.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -875,6 +875,37 @@ def default_retry_policy(self, policy):
875875
.. versionadded:: 3.17.0
876876
"""
877877

878+
tls_session_cache_enabled = True
879+
"""
880+
Enable TLS session caching for faster reconnections. When enabled, TLS sessions
881+
are cached per endpoint and reused for subsequent connections to the same server.
882+
This reduces handshake latency and CPU usage during reconnections.
883+
884+
Defaults to True when SSL/TLS is enabled. Set to False to disable session caching.
885+
886+
.. versionadded:: 3.30.0
887+
"""
888+
889+
tls_session_cache_size = 100
890+
"""
891+
Maximum number of TLS sessions to cache. When the cache is full, the least
892+
recently used session is evicted.
893+
894+
Defaults to 100.
895+
896+
.. versionadded:: 3.30.0
897+
"""
898+
899+
tls_session_cache_ttl = 3600
900+
"""
901+
Time-to-live for cached TLS sessions in seconds. Sessions older than this
902+
are not reused and are removed from the cache.
903+
904+
Defaults to 3600 seconds (1 hour).
905+
906+
.. versionadded:: 3.30.0
907+
"""
908+
878909
sockopts = None
879910
"""
880911
An optional list of tuples which will be used as arguments to
@@ -1204,6 +1235,9 @@ def __init__(self,
12041235
idle_heartbeat_timeout=30,
12051236
no_compact=False,
12061237
ssl_context=None,
1238+
tls_session_cache_enabled=True,
1239+
tls_session_cache_size=100,
1240+
tls_session_cache_ttl=3600,
12071241
endpoint_factory=None,
12081242
application_name=None,
12091243
application_version=None,
@@ -1420,6 +1454,19 @@ def __init__(self,
14201454

14211455
self.ssl_options = ssl_options
14221456
self.ssl_context = ssl_context
1457+
self.tls_session_cache_enabled = tls_session_cache_enabled
1458+
self.tls_session_cache_size = tls_session_cache_size
1459+
self.tls_session_cache_ttl = tls_session_cache_ttl
1460+
1461+
# Initialize TLS session cache if SSL is enabled
1462+
self._tls_session_cache = None
1463+
if (ssl_context or ssl_options) and tls_session_cache_enabled:
1464+
from cassandra.connection import TLSSessionCache
1465+
self._tls_session_cache = TLSSessionCache(
1466+
max_size=tls_session_cache_size,
1467+
ttl=tls_session_cache_ttl
1468+
)
1469+
14231470
self.sockopts = sockopts
14241471
self.cql_version = cql_version
14251472
self.max_schema_agreement_wait = max_schema_agreement_wait
@@ -1661,6 +1708,7 @@ def _make_connection_kwargs(self, endpoint, kwargs_dict):
16611708
kwargs_dict.setdefault('sockopts', self.sockopts)
16621709
kwargs_dict.setdefault('ssl_options', self.ssl_options)
16631710
kwargs_dict.setdefault('ssl_context', self.ssl_context)
1711+
kwargs_dict.setdefault('tls_session_cache', self._tls_session_cache)
16641712
kwargs_dict.setdefault('cql_version', self.cql_version)
16651713
kwargs_dict.setdefault('protocol_version', self.protocol_version)
16661714
kwargs_dict.setdefault('user_type_map', self._user_types)

cassandra/connection.py

Lines changed: 129 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,108 @@ def decompress(byts):
128128
frame_header_v3 = struct.Struct('>BhBi')
129129

130130

131+
class TLSSessionCache:
132+
"""
133+
Thread-safe cache for TLS sessions to enable session resumption.
134+
135+
This cache stores TLS sessions per endpoint (host:port) to allow
136+
quick TLS renegotiation when reconnecting to the same server.
137+
Sessions are automatically expired after a TTL and the cache has
138+
a maximum size with LRU eviction.
139+
"""
140+
141+
def __init__(self, max_size=100, ttl=3600):
142+
"""
143+
Initialize the TLS session cache.
144+
145+
Args:
146+
max_size: Maximum number of sessions to cache (default: 100)
147+
ttl: Time-to-live for cached sessions in seconds (default: 3600)
148+
"""
149+
self._sessions = {} # {endpoint_key: (session, timestamp, access_time)}
150+
self._lock = RLock()
151+
self._max_size = max_size
152+
self._ttl = ttl
153+
154+
def _make_key(self, host, port):
155+
"""Create a cache key from host and port."""
156+
return (host, port)
157+
158+
def get_session(self, host, port):
159+
"""
160+
Get a cached TLS session for the given endpoint.
161+
162+
Args:
163+
host: The hostname or IP address
164+
port: The port number
165+
166+
Returns:
167+
ssl.SSLSession object if a valid cached session exists, None otherwise
168+
"""
169+
key = self._make_key(host, port)
170+
with self._lock:
171+
if key not in self._sessions:
172+
return None
173+
174+
session, timestamp, _ = self._sessions[key]
175+
176+
# Check if session has expired
177+
if time.time() - timestamp > self._ttl:
178+
del self._sessions[key]
179+
return None
180+
181+
# Update access time for LRU
182+
self._sessions[key] = (session, timestamp, time.time())
183+
return session
184+
185+
def set_session(self, host, port, session):
186+
"""
187+
Store a TLS session for the given endpoint.
188+
189+
Args:
190+
host: The hostname or IP address
191+
port: The port number
192+
session: The ssl.SSLSession object to cache
193+
"""
194+
if session is None:
195+
return
196+
197+
key = self._make_key(host, port)
198+
current_time = time.time()
199+
200+
with self._lock:
201+
# If cache is at max size, remove least recently used entry
202+
if len(self._sessions) >= self._max_size and key not in self._sessions:
203+
# Find entry with oldest access time
204+
oldest_key = min(self._sessions.keys(),
205+
key=lambda k: self._sessions[k][2])
206+
del self._sessions[oldest_key]
207+
208+
# Store session with creation time and access time
209+
self._sessions[key] = (session, current_time, current_time)
210+
211+
def clear_expired(self):
212+
"""Remove all expired sessions from the cache."""
213+
current_time = time.time()
214+
with self._lock:
215+
expired_keys = [
216+
key for key, (_, timestamp, _) in self._sessions.items()
217+
if current_time - timestamp > self._ttl
218+
]
219+
for key in expired_keys:
220+
del self._sessions[key]
221+
222+
def clear(self):
223+
"""Clear all sessions from the cache."""
224+
with self._lock:
225+
self._sessions.clear()
226+
227+
def size(self):
228+
"""Return the current number of cached sessions."""
229+
with self._lock:
230+
return len(self._sessions)
231+
232+
131233
class EndPoint(object):
132234
"""
133235
Represents the information to connect to a cassandra node.
@@ -687,6 +789,8 @@ class Connection(object):
687789
endpoint = None
688790
ssl_options = None
689791
ssl_context = None
792+
tls_session_cache = None
793+
session_reused = False
690794
last_error = None
691795

692796
# The current number of operations that are in flight. More precisely,
@@ -763,14 +867,16 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
763867
ssl_options=None, sockopts=None, compression: Union[bool, str] = True,
764868
cql_version=None, protocol_version=ProtocolVersion.MAX_SUPPORTED, is_control_connection=False,
765869
user_type_map=None, connect_timeout=None, allow_beta_protocol_version=False, no_compact=False,
766-
ssl_context=None, owning_pool=None, shard_id=None, total_shards=None,
870+
ssl_context=None, tls_session_cache=None, owning_pool=None, shard_id=None, total_shards=None,
767871
on_orphaned_stream_released=None, application_info: Optional[ApplicationInfoBase] = None):
768872
# TODO next major rename host to endpoint and remove port kwarg.
769873
self.endpoint = host if isinstance(host, EndPoint) else DefaultEndPoint(host, port)
770874

771875
self.authenticator = authenticator
772876
self.ssl_options = ssl_options.copy() if ssl_options else {}
773877
self.ssl_context = ssl_context
878+
self.tls_session_cache = tls_session_cache
879+
self.session_reused = False
774880
self.sockopts = sockopts
775881
self.compression = compression
776882
self.cql_version = cql_version
@@ -913,7 +1019,28 @@ def _wrap_socket_from_context(self):
9131019
server_hostname = self.endpoint.address
9141020
opts['server_hostname'] = server_hostname
9151021

916-
return self.ssl_context.wrap_socket(self._socket, **opts)
1022+
# Try to get a cached TLS session for resumption
1023+
if self.tls_session_cache:
1024+
cached_session = self.tls_session_cache.get_session(
1025+
self.endpoint.address, self.endpoint.port)
1026+
if cached_session:
1027+
opts['session'] = cached_session
1028+
log.debug("Using cached TLS session for %s:%s",
1029+
self.endpoint.address, self.endpoint.port)
1030+
1031+
ssl_socket = self.ssl_context.wrap_socket(self._socket, **opts)
1032+
1033+
# Store the session for future reuse
1034+
if self.tls_session_cache and ssl_socket.session:
1035+
self.tls_session_cache.set_session(
1036+
self.endpoint.address, self.endpoint.port, ssl_socket.session)
1037+
# Track if the session was reused
1038+
self.session_reused = ssl_socket.session_reused
1039+
if self.session_reused:
1040+
log.debug("TLS session was reused for %s:%s",
1041+
self.endpoint.address, self.endpoint.port)
1042+
1043+
return ssl_socket
9171044

9181045
def _initiate_connection(self, sockaddr):
9191046
if self.features.shard_id is not None:

0 commit comments

Comments
 (0)