Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions cassandra/application_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright 2025 ScyllaDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional


class ApplicationInfoBase:
"""
A class that holds application information and adds it to startup message options
"""
def add_startup_options(self, options: dict[str, str]):
raise NotImplementedError()


class ApplicationInfo(ApplicationInfoBase):
application_name: Optional[str]
application_version: Optional[str]
client_id: Optional[str]

def __init__(
self,
application_name: Optional[str] = None,
application_version: Optional[str] = None,
client_id: Optional[str] = None
):
if application_name and not isinstance(application_name, str):
raise TypeError('application_name must be a string')
if application_version and not isinstance(application_version, str):
raise TypeError('application_version must be a string')
if client_id and not isinstance(client_id, str):
raise TypeError('client_id must be a string')

self.application_name = application_name
self.application_version = application_version
self.client_id = client_id

def add_startup_options(self, options: dict[str, str]):
if self.application_name:
options['APPLICATION_NAME'] = self.application_name
if self.application_version:
options['APPLICATION_VERSION'] = self.application_version
if self.client_id:
options['CLIENT_ID'] = self.client_id
23 changes: 23 additions & 0 deletions cassandra/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from itertools import groupby, count, chain
import json
import logging
from typing import Optional
from warnings import warn
from random import random
import re
Expand Down Expand Up @@ -95,6 +96,7 @@
from cassandra.datastax.graph.query import _request_timeout_key, _GraphSONContextRowFactory
from cassandra.datastax import cloud as dscloud
from cassandra.scylla.cloud import CloudConfiguration
from cassandra.application_info import ApplicationInfoBase

try:
from cassandra.io.twistedreactor import TwistedConnection
Expand Down Expand Up @@ -706,6 +708,19 @@
Setting this to :const:`False` disables compression.
"""

_application_info: Optional[ApplicationInfoBase] = None

@property
def application_info(self) -> Optional[ApplicationInfoBase]:
"""
An instance of any subclass of :class:`.application_info.ApplicationInfoBase`.

Defaults to None

When set makes driver sends information about application that uses driver in startup frame
"""
return self._application_info

_auth_provider = None
_auth_provider_callable = None

Expand Down Expand Up @@ -1204,6 +1219,7 @@
shard_aware_options=None,
metadata_request_timeout=None,
column_encryption_policy=None,
application_info:Optional[ApplicationInfoBase]=None
):
"""
``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as
Expand Down Expand Up @@ -1329,6 +1345,12 @@
raise TypeError("address_translator should not be a class, it should be an instance of that class")
self.address_translator = address_translator

if application_info is not None:
if not isinstance(application_info, ApplicationInfoBase):
raise TypeError(
"application_info should be an instance of any ApplicationInfoBase class")
self._application_info = application_info

if timestamp_generator is not None:
if not callable(timestamp_generator):
raise ValueError("timestamp_generator must be callable")
Expand Down Expand Up @@ -1779,6 +1801,7 @@
kwargs_dict.setdefault('user_type_map', self._user_types)
kwargs_dict.setdefault('allow_beta_protocol_version', self.allow_beta_protocol_version)
kwargs_dict.setdefault('no_compact', self.no_compact)
kwargs_dict.setdefault('application_info', self.application_info)

return kwargs_dict

Expand Down Expand Up @@ -4466,7 +4489,7 @@
self._scheduled_tasks.discard(task)
fn, args, kwargs = task
kwargs = dict(kwargs)
future = self._executor.submit(fn, *args, **kwargs)

Check failure on line 4492 in cassandra/cluster.py

View workflow job for this annotation

GitHub Actions / test asyncio (3.9)

cannot schedule new futures after shutdown
future.add_done_callback(self._log_if_failed)
else:
self._queue.put_nowait((run_at, i, task))
Expand Down
9 changes: 7 additions & 2 deletions cassandra/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
import weakref
import random
import itertools
from typing import Optional

from cassandra.application_info import ApplicationInfoBase
from cassandra.protocol_features import ProtocolFeatures

if 'gevent.monkey' in sys.modules:
Expand Down Expand Up @@ -761,8 +763,8 @@ class Connection(object):
_is_checksumming_enabled = False

_on_orphaned_stream_released = None

features = None
_application_info: Optional[ApplicationInfoBase] = None

@property
def _iobuf(self):
Expand All @@ -774,7 +776,7 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
cql_version=None, protocol_version=ProtocolVersion.MAX_SUPPORTED, is_control_connection=False,
user_type_map=None, connect_timeout=None, allow_beta_protocol_version=False, no_compact=False,
ssl_context=None, owning_pool=None, shard_id=None, total_shards=None,
on_orphaned_stream_released=None):
on_orphaned_stream_released=None, application_info: Optional[ApplicationInfoBase] = None):
# TODO next major rename host to endpoint and remove port kwarg.
self.endpoint = host if isinstance(host, EndPoint) else DefaultEndPoint(host, port)

Expand All @@ -797,6 +799,7 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
self._socket_writable = True
self.orphaned_request_ids = set()
self._on_orphaned_stream_released = on_orphaned_stream_released
self._application_info = application_info

if ssl_options:
self.ssl_options.update(self.endpoint.ssl_options or {})
Expand Down Expand Up @@ -1379,6 +1382,8 @@ def _handle_options_response(self, options_response):
self._product_type = options_response.options.get('PRODUCT_TYPE', [None])[0]

options = {}
if self._application_info:
self._application_info.add_startup_options(options)
self.features.add_startup_options(options)

if self.cql_version:
Expand Down
93 changes: 93 additions & 0 deletions tests/integration/standard/test_application_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright 2025 ScyllaDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from cassandra.application_info import ApplicationInfo
from tests.integration import TestCluster, use_single_node, remove_cluster, xfail_scylla


def setup_module():
use_single_node()


def teardown_module():
remove_cluster()


@xfail_scylla("#scylladb/scylla-enterprise#5467 - not released yet")
class ApplicationInfoTest(unittest.TestCase):
attribute_to_startup_key = {
'application_name': 'APPLICATION_NAME',
'application_version': 'APPLICATION_VERSION',
'client_id': 'CLIENT_ID',
}

def test_create_session_and_check_system_views_clients(self):
"""
Test to ensure that ApplicationInfo user provides endup in `client_options` of `system_views.clients` table
"""

for application_info_args in [
{
'application_name': None,
'application_version': None,
'client_id': None,
},
{
'application_name': 'some-application-name',
'application_version': 'some-application-version',
'client_id': 'some-client-id',
},
{
'application_name': 'some-application-name',
'application_version': None,
'client_id': None,
},
{
'application_name': None,
'application_version': 'some-application-version',
'client_id': None,
},
{
'application_name': None,
'application_version': None,
'client_id': 'some-client-id',
},
]:
with self.subTest(**application_info_args):
try:
cluster = TestCluster(
application_info=ApplicationInfo(
**application_info_args
))

found = False
for row in cluster.connect().execute("select client_options from system_views.clients"):
if not row[0]:
continue
for attribute_key, startup_key in self.attribute_to_startup_key.items():
expected_value = application_info_args.get(attribute_key)
if expected_value:
if row[0].get(startup_key) != expected_value:
break
else:
# Check that it is absent
if row[0].get(startup_key, None) is not None:
break
else:
found = True
assert found
finally:
cluster.shutdown()
Loading