Skip to content

Commit 3ac9654

Browse files
Copilotmykaul
andcommitted
Add comprehensive tests for TLS session caching
Co-authored-by: mykaul <[email protected]>
1 parent 6f45fb9 commit 3ac9654

File tree

2 files changed

+319
-0
lines changed

2 files changed

+319
-0
lines changed

tests/integration/long/test_ssl.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,3 +500,107 @@ def test_can_connect_with_sslcontext_default_context(self):
500500
"""
501501
ssl_context = ssl.create_default_context(cafile=CLIENT_CA_CERTS)
502502
validate_ssl_options(ssl_context=ssl_context)
503+
504+
@unittest.skipIf(USES_PYOPENSSL, "This test is for the built-in ssl.Context")
505+
def test_tls_session_cache_enabled_by_default(self):
506+
"""
507+
Test that TLS session caching is enabled by default when SSL is configured.
508+
509+
@since 3.30.0
510+
@expected_result TLS session cache is created and configured
511+
@test_category connection:ssl
512+
"""
513+
ssl_context = ssl.create_default_context(cafile=CLIENT_CA_CERTS)
514+
cluster = TestCluster(
515+
contact_points=[DefaultEndPoint('127.0.0.1')],
516+
ssl_context=ssl_context
517+
)
518+
519+
# Verify session cache was created
520+
self.assertIsNotNone(cluster._tls_session_cache)
521+
self.assertEqual(cluster.tls_session_cache_enabled, True)
522+
self.assertEqual(cluster.tls_session_cache_size, 100)
523+
self.assertEqual(cluster.tls_session_cache_ttl, 3600)
524+
525+
cluster.shutdown()
526+
527+
@unittest.skipIf(USES_PYOPENSSL, "This test is for the built-in ssl.Context")
528+
def test_tls_session_cache_can_be_disabled(self):
529+
"""
530+
Test that TLS session caching can be disabled.
531+
532+
@since 3.30.0
533+
@expected_result TLS session cache is not created when disabled
534+
@test_category connection:ssl
535+
"""
536+
ssl_context = ssl.create_default_context(cafile=CLIENT_CA_CERTS)
537+
cluster = TestCluster(
538+
contact_points=[DefaultEndPoint('127.0.0.1')],
539+
ssl_context=ssl_context,
540+
tls_session_cache_enabled=False
541+
)
542+
543+
# Verify session cache was not created
544+
self.assertIsNone(cluster._tls_session_cache)
545+
self.assertEqual(cluster.tls_session_cache_enabled, False)
546+
547+
cluster.shutdown()
548+
549+
@unittest.skipIf(USES_PYOPENSSL, "This test is for the built-in ssl.Context")
550+
def test_tls_session_reuse(self):
551+
"""
552+
Test that TLS sessions are reused across multiple connections to the same endpoint.
553+
554+
@since 3.30.0
555+
@expected_result Sessions are cached and reused, reducing handshake overhead
556+
@test_category connection:ssl
557+
"""
558+
ssl_context = ssl.create_default_context(cafile=CLIENT_CA_CERTS)
559+
cluster = TestCluster(
560+
contact_points=[DefaultEndPoint('127.0.0.1')],
561+
ssl_context=ssl_context
562+
)
563+
564+
try:
565+
session = cluster.connect(wait_for_all_pools=True)
566+
567+
# Verify session cache was populated
568+
self.assertIsNotNone(cluster._tls_session_cache)
569+
initial_cache_size = cluster._tls_session_cache.size()
570+
self.assertGreater(initial_cache_size, 0, "Session cache should contain sessions after connection")
571+
572+
# Execute a simple query
573+
result = session.execute("SELECT * FROM system.local WHERE key='local'")
574+
self.assertIsNotNone(result)
575+
576+
# Get a connection from the pool to check session_reused flag
577+
# Note: We can't easily check the exact connection that was reused,
578+
# but we can verify the cache has sessions
579+
cache_size = cluster._tls_session_cache.size()
580+
self.assertGreater(cache_size, 0, "Session cache should contain sessions")
581+
582+
finally:
583+
cluster.shutdown()
584+
585+
@unittest.skipIf(USES_PYOPENSSL, "This test is for the built-in ssl.Context")
586+
def test_tls_session_cache_configuration(self):
587+
"""
588+
Test that TLS session cache can be configured with custom parameters.
589+
590+
@since 3.30.0
591+
@expected_result Custom cache configuration is applied
592+
@test_category connection:ssl
593+
"""
594+
ssl_context = ssl.create_default_context(cafile=CLIENT_CA_CERTS)
595+
cluster = TestCluster(
596+
contact_points=[DefaultEndPoint('127.0.0.1')],
597+
ssl_context=ssl_context,
598+
tls_session_cache_size=50,
599+
tls_session_cache_ttl=1800
600+
)
601+
602+
self.assertIsNotNone(cluster._tls_session_cache)
603+
self.assertEqual(cluster.tls_session_cache_size, 50)
604+
self.assertEqual(cluster.tls_session_cache_ttl, 1800)
605+
606+
cluster.shutdown()
Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
# Copyright DataStax, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import time
16+
import unittest
17+
from unittest.mock import Mock
18+
from threading import Thread
19+
20+
from cassandra.connection import TLSSessionCache
21+
22+
23+
class TLSSessionCacheTest(unittest.TestCase):
24+
"""Test the TLSSessionCache implementation."""
25+
26+
def test_cache_basic_operations(self):
27+
"""Test basic get and set operations."""
28+
cache = TLSSessionCache(max_size=10, ttl=60)
29+
30+
# Create a mock session
31+
mock_session = Mock()
32+
33+
# Initially empty
34+
self.assertIsNone(cache.get_session('host1', 9042))
35+
self.assertEqual(cache.size(), 0)
36+
37+
# Set a session
38+
cache.set_session('host1', 9042, mock_session)
39+
self.assertEqual(cache.size(), 1)
40+
41+
# Retrieve the session
42+
retrieved = cache.get_session('host1', 9042)
43+
self.assertEqual(retrieved, mock_session)
44+
45+
def test_cache_different_endpoints(self):
46+
"""Test that different endpoints have separate cache entries."""
47+
cache = TLSSessionCache(max_size=10, ttl=60)
48+
49+
session1 = Mock(name='session1')
50+
session2 = Mock(name='session2')
51+
session3 = Mock(name='session3')
52+
53+
cache.set_session('host1', 9042, session1)
54+
cache.set_session('host2', 9042, session2)
55+
cache.set_session('host1', 9043, session3)
56+
57+
self.assertEqual(cache.size(), 3)
58+
self.assertEqual(cache.get_session('host1', 9042), session1)
59+
self.assertEqual(cache.get_session('host2', 9042), session2)
60+
self.assertEqual(cache.get_session('host1', 9043), session3)
61+
62+
def test_cache_ttl_expiration(self):
63+
"""Test that sessions expire after TTL."""
64+
cache = TLSSessionCache(max_size=10, ttl=1) # 1 second TTL
65+
66+
mock_session = Mock()
67+
cache.set_session('host1', 9042, mock_session)
68+
69+
# Should be retrievable immediately
70+
self.assertIsNotNone(cache.get_session('host1', 9042))
71+
72+
# Wait for expiration
73+
time.sleep(1.1)
74+
75+
# Should be expired
76+
self.assertIsNone(cache.get_session('host1', 9042))
77+
self.assertEqual(cache.size(), 0)
78+
79+
def test_cache_max_size_eviction(self):
80+
"""Test that LRU eviction works when cache is full."""
81+
cache = TLSSessionCache(max_size=3, ttl=60)
82+
83+
session1 = Mock(name='session1')
84+
session2 = Mock(name='session2')
85+
session3 = Mock(name='session3')
86+
session4 = Mock(name='session4')
87+
88+
# Fill cache to capacity
89+
cache.set_session('host1', 9042, session1)
90+
time.sleep(0.01) # Ensure different access times
91+
cache.set_session('host2', 9042, session2)
92+
time.sleep(0.01)
93+
cache.set_session('host3', 9042, session3)
94+
95+
self.assertEqual(cache.size(), 3)
96+
97+
# Access session2 to update its access time
98+
time.sleep(0.01)
99+
cache.get_session('host2', 9042)
100+
101+
# Add a fourth session - should evict session1 (oldest access)
102+
time.sleep(0.01)
103+
cache.set_session('host4', 9042, session4)
104+
105+
self.assertEqual(cache.size(), 3)
106+
self.assertIsNone(cache.get_session('host1', 9042))
107+
self.assertIsNotNone(cache.get_session('host2', 9042))
108+
self.assertIsNotNone(cache.get_session('host3', 9042))
109+
self.assertIsNotNone(cache.get_session('host4', 9042))
110+
111+
def test_cache_clear_expired(self):
112+
"""Test manual clearing of expired sessions."""
113+
cache = TLSSessionCache(max_size=10, ttl=1)
114+
115+
session1 = Mock(name='session1')
116+
session2 = Mock(name='session2')
117+
118+
cache.set_session('host1', 9042, session1)
119+
time.sleep(1.1) # Let session1 expire
120+
cache.set_session('host2', 9042, session2)
121+
122+
# Before clearing, both are in cache
123+
self.assertEqual(cache.size(), 2)
124+
125+
# Clear expired sessions
126+
cache.clear_expired()
127+
128+
# Only session2 should remain
129+
self.assertEqual(cache.size(), 1)
130+
self.assertIsNone(cache.get_session('host1', 9042))
131+
self.assertIsNotNone(cache.get_session('host2', 9042))
132+
133+
def test_cache_clear_all(self):
134+
"""Test clearing all sessions from cache."""
135+
cache = TLSSessionCache(max_size=10, ttl=60)
136+
137+
cache.set_session('host1', 9042, Mock())
138+
cache.set_session('host2', 9042, Mock())
139+
cache.set_session('host3', 9042, Mock())
140+
141+
self.assertEqual(cache.size(), 3)
142+
143+
cache.clear()
144+
145+
self.assertEqual(cache.size(), 0)
146+
147+
def test_cache_none_session(self):
148+
"""Test that None sessions are not cached."""
149+
cache = TLSSessionCache(max_size=10, ttl=60)
150+
151+
cache.set_session('host1', 9042, None)
152+
153+
self.assertEqual(cache.size(), 0)
154+
self.assertIsNone(cache.get_session('host1', 9042))
155+
156+
def test_cache_update_existing_session(self):
157+
"""Test that updating an existing session works correctly."""
158+
cache = TLSSessionCache(max_size=10, ttl=60)
159+
160+
session1 = Mock(name='session1')
161+
session2 = Mock(name='session2')
162+
163+
cache.set_session('host1', 9042, session1)
164+
self.assertEqual(cache.get_session('host1', 9042), session1)
165+
166+
# Update with new session
167+
cache.set_session('host1', 9042, session2)
168+
self.assertEqual(cache.get_session('host1', 9042), session2)
169+
170+
# Size should still be 1
171+
self.assertEqual(cache.size(), 1)
172+
173+
def test_cache_thread_safety(self):
174+
"""Test that cache operations are thread-safe."""
175+
cache = TLSSessionCache(max_size=100, ttl=60)
176+
errors = []
177+
178+
def set_sessions(thread_id):
179+
try:
180+
for i in range(50):
181+
session = Mock(name=f'session_{thread_id}_{i}')
182+
cache.set_session(f'host{thread_id}', 9042 + i, session)
183+
except Exception as e:
184+
errors.append(e)
185+
186+
def get_sessions(thread_id):
187+
try:
188+
for i in range(50):
189+
cache.get_session(f'host{thread_id}', 9042 + i)
190+
except Exception as e:
191+
errors.append(e)
192+
193+
# Create multiple threads doing concurrent operations
194+
threads = []
195+
for i in range(5):
196+
t1 = Thread(target=set_sessions, args=(i,))
197+
t2 = Thread(target=get_sessions, args=(i,))
198+
threads.extend([t1, t2])
199+
200+
for t in threads:
201+
t.start()
202+
203+
for t in threads:
204+
t.join()
205+
206+
# Check that no errors occurred
207+
self.assertEqual(len(errors), 0, f"Thread safety test failed with errors: {errors}")
208+
209+
# Check that cache is not empty and within max size
210+
self.assertGreater(cache.size(), 0)
211+
self.assertLessEqual(cache.size(), 100)
212+
213+
214+
if __name__ == '__main__':
215+
unittest.main()

0 commit comments

Comments
 (0)