11import time
2- import unittest
2+
33import pytest
4- import os
4+
55from cassandra .cluster import Cluster
66from cassandra .policies import ConstantReconnectionPolicy , RoundRobinPolicy , TokenAwarePolicy
77
88from tests .integration import PROTOCOL_VERSION , use_cluster
99from tests .unit .test_host_connection_pool import LOGGER
1010
11+ CCM_CLUSTER = None
12+
1113def setup_module ():
12- use_cluster ('tablets' , [3 ], start = True )
14+ global CCM_CLUSTER
15+
16+ CCM_CLUSTER = use_cluster ('tablets' , [3 ], start = True )
1317
14- class TestTabletsIntegration (unittest .TestCase ):
18+
19+ class TestTabletsIntegration :
1520 @classmethod
1621 def setup_class (cls ):
1722 cls .cluster = Cluster (contact_points = ["127.0.0.1" , "127.0.0.2" , "127.0.0.3" ], protocol_version = PROTOCOL_VERSION ,
1823 load_balancing_policy = TokenAwarePolicy (RoundRobinPolicy ()),
1924 reconnection_policy = ConstantReconnectionPolicy (1 ))
2025 cls .session = cls .cluster .connect ()
21- cls .create_ks_and_cf (cls )
26+ cls .create_ks_and_cf (cls . session )
2227 cls .create_data (cls .session )
2328
2429 @classmethod
2530 def teardown_class (cls ):
2631 cls .cluster .shutdown ()
2732
28- def verify_same_host_in_tracing (self , results ):
33+ def verify_hosts_in_tracing (self , results , expected ):
2934 traces = results .get_query_trace ()
3035 events = traces .events
3136 host_set = set ()
3237 for event in events :
3338 LOGGER .info ("TRACE EVENT: %s %s %s" , event .source , event .thread_name , event .description )
3439 host_set .add (event .source )
3540
36- self . assertEqual ( len (host_set ), 1 )
37- self . assertIn ( 'locally' , "\n " .join ([event .description for event in events ]) )
41+ assert len (host_set ) == expected
42+ assert 'locally' in "\n " .join ([event .description for event in events ])
3843
3944 trace_id = results .response_future .get_query_trace_ids ()[0 ]
4045 traces = self .session .execute ("SELECT * FROM system_traces.events WHERE session_id = %s" , (trace_id ,))
@@ -44,8 +49,12 @@ def verify_same_host_in_tracing(self, results):
4449 LOGGER .info ("TRACE EVENT: %s %s" , event .source , event .activity )
4550 host_set .add (event .source )
4651
47- self .assertEqual (len (host_set ), 1 )
48- self .assertIn ('locally' , "\n " .join ([event .activity for event in events ]))
52+ assert len (host_set ) == expected
53+ assert 'locally' in "\n " .join ([event .activity for event in events ])
54+
55+ def get_tablet_record (self , query ):
56+ metadata = self .session .cluster .metadata
57+ return metadata ._tablets .get_tablet_for_key (query .keyspace , query .table , metadata .token_map .token_class .from_key (query .routing_key ))
4958
5059 def verify_same_shard_in_tracing (self , results ):
5160 traces = results .get_query_trace ()
@@ -55,8 +64,8 @@ def verify_same_shard_in_tracing(self, results):
5564 LOGGER .info ("TRACE EVENT: %s %s %s" , event .source , event .thread_name , event .description )
5665 shard_set .add (event .thread_name )
5766
58- self . assertEqual ( len (shard_set ), 1 )
59- self . assertIn ( 'locally' , "\n " .join ([event .description for event in events ]) )
67+ assert len (shard_set ) == 1
68+ assert 'locally' in "\n " .join ([event .description for event in events ])
6069
6170 trace_id = results .response_future .get_query_trace_ids ()[0 ]
6271 traces = self .session .execute ("SELECT * FROM system_traces.events WHERE session_id = %s" , (trace_id ,))
@@ -66,27 +75,28 @@ def verify_same_shard_in_tracing(self, results):
6675 LOGGER .info ("TRACE EVENT: %s %s" , event .thread , event .activity )
6776 shard_set .add (event .thread )
6877
69- self . assertEqual ( len (shard_set ), 1 )
70- self . assertIn ( 'locally' , "\n " .join ([event .activity for event in events ]) )
78+ assert len (shard_set ) == 1
79+ assert 'locally' in "\n " .join ([event .activity for event in events ])
7180
72- def create_ks_and_cf (self ):
73- self .session .execute (
81+ @classmethod
82+ def create_ks_and_cf (cls , session ):
83+ session .execute (
7484 """
7585 DROP KEYSPACE IF EXISTS test1
7686 """
7787 )
78- self . session .execute (
88+ session .execute (
7989 """
8090 CREATE KEYSPACE test1
8191 WITH replication = {
8292 'class': 'NetworkTopologyStrategy',
83- 'replication_factor': 1
93+ 'replication_factor': 2
8494 } AND tablets = {
8595 'initial': 8
8696 }
8797 """ )
8898
89- self . session .execute (
99+ session .execute (
90100 """
91101 CREATE TABLE test1.table1 (pk int, ck int, v int, PRIMARY KEY (pk, ck));
92102 """ )
@@ -110,7 +120,7 @@ def query_data_shard_select(self, session, verify_in_tracing=True):
110120
111121 bound = prepared .bind ([(2 )])
112122 results = session .execute (bound , trace = True )
113- self . assertEqual ( results , [(2 , 2 , 0 )])
123+ assert results == [(2 , 2 , 0 )]
114124 if verify_in_tracing :
115125 self .verify_same_shard_in_tracing (results )
116126
@@ -122,9 +132,9 @@ def query_data_host_select(self, session, verify_in_tracing=True):
122132
123133 bound = prepared .bind ([(2 )])
124134 results = session .execute (bound , trace = True )
125- self . assertEqual ( results , [(2 , 2 , 0 )])
135+ assert results == [(2 , 2 , 0 )]
126136 if verify_in_tracing :
127- self .verify_same_host_in_tracing (results )
137+ self .verify_hosts_in_tracing (results , 1 )
128138
129139 def query_data_shard_insert (self , session , verify_in_tracing = True ):
130140 prepared = session .prepare (
@@ -146,7 +156,7 @@ def query_data_host_insert(self, session, verify_in_tracing=True):
146156 bound = prepared .bind ([(52 ), (1 ), (2 )])
147157 results = session .execute (bound , trace = True )
148158 if verify_in_tracing :
149- self .verify_same_host_in_tracing (results )
159+ self .verify_hosts_in_tracing (results , 2 )
150160
151161 def test_tablets (self ):
152162 self .query_data_host_select (self .session )
@@ -155,3 +165,70 @@ def test_tablets(self):
155165 def test_tablets_shard_awareness (self ):
156166 self .query_data_shard_select (self .session )
157167 self .query_data_shard_insert (self .session )
168+
169+ def test_tablets_invalidation_drop_ks_while_reconnecting (self ):
170+ def recreate_while_reconnecting (_ ):
171+ # Kill control connection
172+ conn = self .session .cluster .control_connection ._connection
173+ self .session .cluster .control_connection ._connection = None
174+ conn .close ()
175+
176+ # Drop and recreate ks and table to trigger tablets invalidation
177+ self .create_ks_and_cf (self .cluster .connect ())
178+
179+ # Start control connection
180+ self .session .cluster .control_connection ._reconnect ()
181+
182+ self .run_tablets_invalidation_test (recreate_while_reconnecting )
183+
184+ def test_tablets_invalidation_drop_ks (self ):
185+ def drop_ks (_ ):
186+ # Drop and recreate ks and table to trigger tablets invalidation
187+ self .create_ks_and_cf (self .cluster .connect ())
188+ time .sleep (3 )
189+
190+ self .run_tablets_invalidation_test (drop_ks )
191+
192+ @pytest .mark .last
193+ def test_tablets_invalidation_decommission_non_cc_node (self ):
194+ def decommission_non_cc_node (rec ):
195+ # Drop and recreate ks and table to trigger tablets invalidation
196+ for node in CCM_CLUSTER .nodes .values ():
197+ if self .cluster .control_connection ._connection .endpoint .address == node .network_interfaces ["storage" ][0 ]:
198+ # Ignore node that control connection is connected to
199+ continue
200+ for replica in rec .replicas :
201+ if str (replica [0 ]) == str (node .node_hostid ):
202+ node .decommission ()
203+ break
204+ else :
205+ continue
206+ break
207+ else :
208+ assert False , "failed to find node to decommission"
209+ time .sleep (10 )
210+
211+ self .run_tablets_invalidation_test (decommission_non_cc_node )
212+
213+
214+ def run_tablets_invalidation_test (self , invalidate ):
215+ # Make sure driver holds tablet info
216+ # By landing query to the host that is not in replica set
217+ bound = self .session .prepare (
218+ """
219+ SELECT pk, ck, v FROM test1.table1 WHERE pk = ?
220+ """ ).bind ([(2 )])
221+
222+ rec = None
223+ for host in self .cluster .metadata .all_hosts ():
224+ self .session .execute (bound , host = host )
225+ rec = self .get_tablet_record (bound )
226+ if rec is not None :
227+ break
228+
229+ assert rec is not None , "failed to find tablet record"
230+
231+ invalidate (rec )
232+
233+ # Check if tablets information was purged
234+ assert self .get_tablet_record (bound ) is None , "tablet was not deleted, invalidation did not work"
0 commit comments