diff --git a/cassandra/metadata.py b/cassandra/metadata.py index bbfaf2605b..85f6c45ac6 100644 --- a/cassandra/metadata.py +++ b/cassandra/metadata.py @@ -139,8 +139,9 @@ def export_schema_as_string(self): def refresh(self, connection, timeout, target_type=None, change_type=None, fetch_size=None, metadata_request_timeout=None, **kwargs): - server_version = self.get_host(connection.original_endpoint).release_version - dse_version = self.get_host(connection.original_endpoint).dse_version + host = self.get_host(connection.original_endpoint) + server_version = host.release_version if host else None + dse_version = host.dse_version if host else None parser = get_schema_parser(connection, server_version, dse_version, timeout, metadata_request_timeout, fetch_size) if not target_type: @@ -3409,8 +3410,27 @@ def __init__( self.to_clustering_columns = to_clustering_columns +def get_column_from_system_local(connection, column_name: str, timeout, metadata_request_timeout) -> str: + success, local_result = connection.wait_for_response( + QueryMessage( + query=maybe_add_timeout_to_query( + "SELECT " + column_name + " FROM system.local WHERE key='local'", + metadata_request_timeout), + consistency_level=ConsistencyLevel.ONE) + , timeout=timeout, fail_on_error=False) + if not success or not local_result.parsed_rows: + return "" + local_rows = dict_factory(local_result.column_names, local_result.parsed_rows) + local_row = local_rows[0] + return local_row.get(column_name) + + def get_schema_parser(connection, server_version, dse_version, timeout, metadata_request_timeout, fetch_size=None): - version = Version(server_version) + if server_version is None and dse_version is None: + server_version = get_column_from_system_local(connection, "release_version", timeout, metadata_request_timeout) + dse_version = get_column_from_system_local(connection, "dse_version", timeout, metadata_request_timeout) + + version = Version(server_version or "0") if dse_version: v = Version(dse_version) if v >= Version('6.8.0'): diff --git a/tests/unit/test_metadata.py b/tests/unit/test_metadata.py index 3069f6bced..c471fab827 100644 --- a/tests/unit/test_metadata.py +++ b/tests/unit/test_metadata.py @@ -30,9 +30,11 @@ UserType, KeyspaceMetadata, get_schema_parser, _UnknownStrategy, ColumnMetadata, TableMetadata, IndexMetadata, Function, Aggregate, - Metadata, TokenMap, ReplicationFactor) + Metadata, TokenMap, ReplicationFactor, + SchemaParserDSE68) from cassandra.policies import SimpleConvictionPolicy from cassandra.pool import Host +from cassandra.protocol import QueryMessage from tests.util import assertCountEqual import pytest @@ -616,6 +618,37 @@ def test_build_index_as_cql(self): assert index_meta.as_cql_query() == "CREATE CUSTOM INDEX index_name_here ON keyspace_name_here.table_name_here (column_name_here) USING 'class_name_here'" +class SchemaParserLookupTests(unittest.TestCase): + + def test_reads_versions_from_system_local_when_missing(self): + connection = Mock() + + release_version_resp = Mock() + release_version_resp.column_names = ["release_version"] + release_version_resp.parsed_rows = [["4.0.0"]] + + dse_version_resp = Mock() + dse_version_resp.column_names = ["dse_version"] + dse_version_resp.parsed_rows = [["6.8.0"]] + + def mock_system_local(query, *args, **kwargs): + if not isinstance(query, QueryMessage): + raise RuntimeError("first argument should be a QueryMessage") + if "release_version" in query.query: + return (True, release_version_resp) + if "dse_version" in query.query: + return (True, dse_version_resp) + raise RuntimeError("unexpected query") + + connection.wait_for_response.side_effect = mock_system_local + + parser = get_schema_parser(connection, None, None, 0.1, None) + + assert isinstance(parser, SchemaParserDSE68) + message = connection.wait_for_response.call_args[0][0] + assert "system.local" in message.query + + class UnicodeIdentifiersTests(unittest.TestCase): """ Exercise cql generation with unicode characters. Keyspace, Table, and Index names