Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 95 additions & 22 deletions cassandra/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -3441,6 +3441,57 @@ def _clear_watcher(conn, expiring_weakref):
pass


def _fetch_remaining_pages(connection, query_msg, timeout, fail_on_error=True):
"""
Fetch all pages for a paged query.
Executes the query and fetches all pages if the result is paged.

:param connection: The connection to use for querying
:param query_msg: The QueryMessage to execute (must have fetch_size set for paging)
:param timeout: Timeout for each query operation
:param fail_on_error: If True, raise exceptions on query failure. If False, return (success, result) tuple. Defaults to True (same as connection.wait_for_response)
:return: If fail_on_error=True, returns the result with all parsed_rows combined from all pages.
If fail_on_error=False, returns (success, result) tuple where result has all parsed_rows combined.
"""
# Execute the query to get the first page
response = connection.wait_for_response(query_msg, timeout=timeout, fail_on_error=fail_on_error)

# Handle fail_on_error=False case where response is (success, result) tuple
if not fail_on_error:
success, result = response
if not success:
return response # Return (False, exception) tuple
else:
result = response

if not result or not result.paging_state:
return response if not fail_on_error else result

all_rows = list(result.parsed_rows) if result.parsed_rows else []

# Fetch remaining pages
while result and result.paging_state:
query_msg.paging_state = result.paging_state
page_response = connection.wait_for_response(query_msg, timeout=timeout, fail_on_error=fail_on_error)

if not fail_on_error:
page_success, page_result = page_response
if not page_success:
return page_response # Return (False, exception) tuple
result = page_result
else:
result = page_response

if result and result.parsed_rows:
all_rows.extend(result.parsed_rows)

# Update the result with all rows
if result:
result.parsed_rows = all_rows

return (True, result) if not fail_on_error else result


class ControlConnection(object):
"""
Internal
Expand Down Expand Up @@ -3638,23 +3689,31 @@ def _try_connect(self, host):
sel_peers = self._get_peers_query(self.PeersQueryType.PEERS, connection)
sel_local = self._SELECT_LOCAL if self._token_meta_enabled else self._SELECT_LOCAL_NO_TOKENS
peers_query = QueryMessage(query=maybe_add_timeout_to_query(sel_peers, self._metadata_request_timeout),
consistency_level=ConsistencyLevel.ONE)
consistency_level=ConsistencyLevel.ONE,
fetch_size=self._schema_meta_page_size)
local_query = QueryMessage(query=maybe_add_timeout_to_query(sel_local, self._metadata_request_timeout),
consistency_level=ConsistencyLevel.ONE)
(peers_success, peers_result), (local_success, local_result) = connection.wait_for_responses(
peers_query, local_query, timeout=self._timeout, fail_on_error=False)

if not local_success:
raise local_result

consistency_level=ConsistencyLevel.ONE,
fetch_size=self._schema_meta_page_size)

# Try to execute peers query (might be peers_v2)
# Use fail_on_error=False to handle peers_v2 fallback gracefully
peers_success, peers_result = _fetch_remaining_pages(connection, peers_query, self._timeout, fail_on_error=False)
if not peers_success:
# error with the peers v2 query, fallback to peers v1
self._uses_peers_v2 = False
sel_peers = self._get_peers_query(self.PeersQueryType.PEERS, connection)
peers_query = QueryMessage(query=maybe_add_timeout_to_query(sel_peers, self._metadata_request_timeout),
consistency_level=ConsistencyLevel.ONE)
peers_result = connection.wait_for_response(
peers_query, timeout=self._timeout)
consistency_level=ConsistencyLevel.ONE,
fetch_size=self._schema_meta_page_size)
peers_result = _fetch_remaining_pages(connection, peers_query, self._timeout)

# Fetch local query (note: system.local always has exactly 1 row, so it will never have additional pages)
# Use fail_on_error=False to match original behavior
local_success, local_result = _fetch_remaining_pages(connection, local_query, self._timeout, fail_on_error=False)

if not local_success:
raise local_result

shared_results = (peers_result, local_result)
self._refresh_node_list_and_token_map(connection, preloaded_results=shared_results)
Expand Down Expand Up @@ -3797,11 +3856,17 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None,
log.debug("[control connection] Refreshing node list and token map")
sel_local = self._SELECT_LOCAL
peers_query = QueryMessage(query=maybe_add_timeout_to_query(sel_peers, self._metadata_request_timeout),
consistency_level=cl)
consistency_level=cl,
fetch_size=self._schema_meta_page_size)
local_query = QueryMessage(query=maybe_add_timeout_to_query(sel_local, self._metadata_request_timeout),
consistency_level=cl)
peers_result, local_result = connection.wait_for_responses(
peers_query, local_query, timeout=self._timeout)
consistency_level=cl,
fetch_size=self._schema_meta_page_size)

# Fetch all pages for both queries
# Note: system.local always has exactly 1 row, so it will never have additional pages
# system.peers might have multiple pages for very large clusters (>1000 nodes)
peers_result = _fetch_remaining_pages(connection, peers_query, self._timeout)
local_result = _fetch_remaining_pages(connection, local_query, self._timeout)

peers_result = dict_factory(peers_result.column_names, peers_result.parsed_rows)

Expand Down Expand Up @@ -3856,9 +3921,11 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None,
# in system.local. See CASSANDRA-9436.
local_rpc_address_query = QueryMessage(
query=maybe_add_timeout_to_query(self._SELECT_LOCAL_NO_TOKENS_RPC_ADDRESS, self._metadata_request_timeout),
consistency_level=ConsistencyLevel.ONE)
success, local_rpc_address_result = connection.wait_for_response(
local_rpc_address_query, timeout=self._timeout, fail_on_error=False)
consistency_level=ConsistencyLevel.ONE,
fetch_size=self._schema_meta_page_size)
# Fetch all pages (system.local table always contains exactly one row, so this is effectively a no-op)
success, local_rpc_address_result = _fetch_remaining_pages(
connection, local_rpc_address_query, self._timeout, fail_on_error=False)
if success:
row = dict_factory(
local_rpc_address_result.column_names,
Expand Down Expand Up @@ -4092,13 +4159,19 @@ def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wai

while elapsed < total_timeout:
peers_query = QueryMessage(query=maybe_add_timeout_to_query(select_peers_query, self._metadata_request_timeout),
consistency_level=cl)
consistency_level=cl,
fetch_size=self._schema_meta_page_size)
local_query = QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_SCHEMA_LOCAL, self._metadata_request_timeout),
consistency_level=cl)
consistency_level=cl,
fetch_size=self._schema_meta_page_size)
try:
timeout = min(self._timeout, total_timeout - elapsed)
peers_result, local_result = connection.wait_for_responses(
peers_query, local_query, timeout=timeout)

# Fetch all pages if there are more results
# Note: system.local always has exactly 1 row, so it will never have additional pages
# system.peers might have multiple pages for very large clusters (>1000 nodes)
peers_result = _fetch_remaining_pages(connection, peers_query, timeout)
local_result = _fetch_remaining_pages(connection, local_query, timeout)
except OperationTimedOut as timeout:
log.debug("[control connection] Timed out waiting for "
"response during schema agreement check: %s", timeout)
Expand Down
82 changes: 79 additions & 3 deletions tests/unit/test_control_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
import unittest

from concurrent.futures import ThreadPoolExecutor
from unittest.mock import Mock, ANY, call
from unittest.mock import Mock, ANY, call, MagicMock

from cassandra import OperationTimedOut, SchemaTargetType, SchemaChangeType
from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS
from cassandra import OperationTimedOut, SchemaTargetType, SchemaChangeType, ConsistencyLevel
from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS, QueryMessage
from cassandra.cluster import ControlConnection, _Scheduler, ProfileManager, EXEC_PROFILE_DEFAULT, ExecutionProfile
from cassandra.pool import Host
from cassandra.connection import EndPoint, DefaultEndPoint, DefaultEndPointFactory
Expand Down Expand Up @@ -167,6 +167,20 @@ def __init__(self):
["192.168.1.2", 9042, "10.0.0.2", 7040, "a", "dc1", "rack1", ["2", "102", "202"], "uuid3"]]
]
self.wait_for_responses = Mock(return_value=_node_meta_results(self.local_results, self.peer_results))
# Set up wait_for_response to return the appropriate result based on the query
def wait_for_response_side_effect(query_msg, timeout=None, fail_on_error=True):
# Create a result that matches the expected format
result = ResultMessage(kind=RESULT_KIND_ROWS)
# Return peer or local results based on some simple heuristic
if "peers" in query_msg.query.lower():
result.column_names = self.peer_results[0]
result.parsed_rows = self.peer_results[1]
else:
result.column_names = self.local_results[0]
result.parsed_rows = self.local_results[1]
result.paging_state = None
return result
self.wait_for_response = Mock(side_effect=wait_for_response_side_effect)


class FakeTime(object):
Expand Down Expand Up @@ -305,6 +319,68 @@ def test_refresh_nodes_and_tokens(self):

assert self.connection.wait_for_responses.call_count == 1

def test_topology_queries_use_paging(self):
    """
    Verify that the topology refresh issues its system.peers / system.local
    queries as QueryMessage instances carrying the paging fetch_size.
    """
    # Trigger a node-list/token-map refresh; the queries now go through
    # wait_for_response (one call per query) rather than wait_for_responses.
    self.control_connection.refresh_node_list_and_token_map()

    assert self.connection.wait_for_response.called

    expected_page_size = self.control_connection._schema_meta_page_size
    # Every recorded call's first positional argument is the query message;
    # each one must request paging with the schema metadata page size.
    # (local variable renamed from `call` to avoid shadowing mock.call)
    for recorded in self.connection.wait_for_response.call_args_list:
        message = recorded[0][0]
        assert isinstance(message, QueryMessage)
        assert message.fetch_size == expected_page_size

def test_topology_queries_fetch_all_pages(self):
    """
    Verify that _fetch_remaining_pages keeps requesting pages until the
    server stops returning a paging_state, combining all rows.
    """
    from cassandra.cluster import _fetch_remaining_pages

    mock_connection = MagicMock()
    mock_connection.endpoint = DefaultEndPoint("192.168.1.0")
    mock_connection.original_endpoint = mock_connection.endpoint

    columns = ["rpc_address", "peer", "schema_version", "data_center", "rack", "tokens", "host_id"]

    def make_page(rows, paging_state):
        # Build a ROWS result carrying the given rows and paging marker.
        page = ResultMessage(kind=RESULT_KIND_ROWS)
        page.column_names = columns
        page.parsed_rows = rows
        page.paging_state = paging_state
        return page

    # First page advertises more data; second page terminates the sequence.
    first_page = make_page(
        [["192.168.1.1", "10.0.0.1", "a", "dc1", "rack1", ["1", "101", "201"], "uuid2"]],
        b"has_more_pages")
    second_page = make_page(
        [["192.168.1.2", "10.0.0.2", "a", "dc1", "rack1", ["2", "102", "202"], "uuid3"]],
        None)
    mock_connection.wait_for_response.side_effect = [first_page, second_page]

    self.control_connection._connection = mock_connection
    query_msg = QueryMessage(query="SELECT * FROM system.peers",
                             consistency_level=ConsistencyLevel.ONE,
                             fetch_size=self.control_connection._schema_meta_page_size)

    result = _fetch_remaining_pages(mock_connection, query_msg, timeout=5)

    # Rows from both pages are combined onto a single result.
    assert len(result.parsed_rows) == 2
    assert result.parsed_rows[0][0] == "192.168.1.1"
    assert result.parsed_rows[1][0] == "192.168.1.2"
    assert result.paging_state is None

    # Exactly one request per page: the initial query plus one follow-up.
    assert mock_connection.wait_for_response.call_count == 2

def test_refresh_nodes_and_tokens_with_invalid_peers(self):
def refresh_and_validate_added_hosts():
self.connection.wait_for_responses = Mock(return_value=_node_meta_results(
Expand Down