diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 5822a23aa9..8370dde9d1 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -29,7 +29,7 @@ from itertools import groupby, count, chain import json import logging -from typing import Optional +from typing import Optional, Union from warnings import warn from random import random import re @@ -51,7 +51,7 @@ from cassandra.connection import (ConnectionException, ConnectionShutdown, ConnectionHeartbeat, ProtocolVersionUnsupported, EndPoint, DefaultEndPoint, DefaultEndPointFactory, - SniEndPointFactory, ConnectionBusy) + SniEndPointFactory, ConnectionBusy, locally_supported_compressions) from cassandra.cqltypes import UserType import cassandra.cqltypes as types from cassandra.encoder import Encoder @@ -686,7 +686,7 @@ class Cluster(object): Used for testing new protocol features incrementally before the new version is complete. """ - compression = True + compression: Union[bool, str] = True """ Controls compression for communications between the driver and Cassandra. If left as the default of :const:`True`, either lz4 or snappy compression @@ -1173,7 +1173,7 @@ def token_metadata_enabled(self, enabled): def __init__(self, contact_points=_NOT_SET, port=9042, - compression=True, + compression: Union[bool, str] = True, auth_provider=None, load_balancing_policy=None, reconnection_policy=None, @@ -1302,6 +1302,24 @@ def __init__(self, self._resolve_hostnames() + if isinstance(compression, bool): + if compression and not locally_supported_compressions: + log.error( + "Compression is enabled, but no compression libraries are available. " + "Disabling compression, consider installing one of the Python packages: lz4 and/or python-snappy." + ) + compression = False + elif isinstance(compression, str): + if not locally_supported_compressions.get(compression): + raise ValueError( + "Compression '%s' was requested, but it is not available. " + "Consider installing the corresponding Python package." % compression + ) + else: + raise TypeError( + "The 'compression' option must be either a string (e.g., 'lz4' or 'snappy') " + "or a boolean (True to enable any available compression, False to disable it)." + ) self.compression = compression if protocol_version is not _NOT_SET: diff --git a/cassandra/connection.py b/cassandra/connection.py index 39baeea884..9ac02c9776 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -28,7 +28,7 @@ import weakref import random import itertools -from typing import Optional +from typing import Optional, Union from cassandra.application_info import ApplicationInfoBase from cassandra.protocol_features import ProtocolFeatures @@ -679,7 +679,7 @@ class Connection(object): protocol_version = ProtocolVersion.MAX_SUPPORTED keyspace = None - compression = True + compression: Union[bool, str] = True _compression_type = None compressor = None decompressor = None @@ -760,7 +760,7 @@ def _iobuf(self): return self._io_buffer.io_buffer def __init__(self, host='127.0.0.1', port=9042, authenticator=None, - ssl_options=None, sockopts=None, compression=True, + ssl_options=None, sockopts=None, compression: Union[bool, str] = True, cql_version=None, protocol_version=ProtocolVersion.MAX_SUPPORTED, is_control_connection=False, user_type_map=None, connect_timeout=None, allow_beta_protocol_version=False, no_compact=False, ssl_context=None, owning_pool=None, shard_id=None, total_shards=None, @@ -1383,10 +1383,11 @@ def _handle_options_response(self, options_response): overlap = (set(locally_supported_compressions.keys()) & set(remote_supported_compressions)) if len(overlap) == 0: - log.error("No available compression types supported on both ends." - " locally supported: %r. remotely supported: %r", - locally_supported_compressions.keys(), - remote_supported_compressions) + if locally_supported_compressions: + log.error("No available compression types supported on both ends." + " locally supported: %r. remotely supported: %r", + locally_supported_compressions.keys(), + remote_supported_compressions) else: compression_type = None if isinstance(self.compression, str):