Skip to content

Commit 8c02393

Browse files
committed
cluster: add application_info
Implement clustr.application_info to make driver send following startup options to server: 1. `APPLICATION_NAME` - ID what application is using driver, example: repo of the application 2. `APPLICATION_VERSION` - Version of the application, example: release version or commit id of the application 3. `CLIENT_ID` - unique id of the client instance, example: pod name All strings.
1 parent d5834c6 commit 8c02393

File tree

5 files changed

+103
-2
lines changed

5 files changed

+103
-2
lines changed

cassandra/application_info.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright DataStax, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
class ApplicationInfoBase(object):
17+
"""
18+
A class that holds application information and adds it to startup message options
19+
"""
20+
def add_startup_options(self, options: dict[str, str]):
21+
raise NotImplementedError()
22+
23+
24+
class ApplicationInfo(ApplicationInfoBase):
25+
def __init__(self, application_name: str = None, application_version: str = None, client_id: str = None):
26+
if application_name:
27+
if not isinstance(application_name, str):
28+
raise TypeError('application_name must be a string')
29+
self.application_name = application_name
30+
if application_version:
31+
if not isinstance(application_version, str):
32+
raise TypeError('application_version must be a string')
33+
self.application_version = application_version
34+
if client_id:
35+
if not isinstance(client_id, str):
36+
raise TypeError('client_id must be a string')
37+
self.client_id = client_id
38+
39+
def add_startup_options(self, options: dict[str, str]):
40+
if self.application_name:
41+
options['APPLICATION_NAME'] = self.application_name
42+
if self.application_version:
43+
options['APPLICATION_VERSION'] = self.application_version
44+
if self.client_id:
45+
options['CLIENT_ID'] = self.client_id

cassandra/cluster.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@
9595
from cassandra.datastax.graph.query import _request_timeout_key, _GraphSONContextRowFactory
9696
from cassandra.datastax import cloud as dscloud
9797
from cassandra.scylla.cloud import CloudConfiguration
98+
from cassandra.application_info import ApplicationInfoBase
9899

99100
try:
100101
from cassandra.io.twistedreactor import TwistedConnection
@@ -706,6 +707,15 @@ class Cluster(object):
706707
Setting this to :const:`False` disables compression.
707708
"""
708709

710+
application_info: ApplicationInfoBase = None
711+
"""
712+
An instance of any subclass of :class:`.cluster.ApplicationInfoBase`.
713+
714+
Defaults to None
715+
716+
When not None makes driver sends information about application that uses driver in startup frame
717+
"""
718+
709719
_auth_provider = None
710720
_auth_provider_callable = None
711721

@@ -1204,6 +1214,7 @@ def __init__(self,
12041214
shard_aware_options=None,
12051215
metadata_request_timeout=None,
12061216
column_encryption_policy=None,
1217+
application_info:ApplicationInfoBase=None
12071218
):
12081219
"""
12091220
``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as
@@ -1329,6 +1340,12 @@ def __init__(self,
13291340
raise TypeError("address_translator should not be a class, it should be an instance of that class")
13301341
self.address_translator = address_translator
13311342

1343+
if application_info is not None:
1344+
if not isinstance(application_info, ApplicationInfoBase):
1345+
raise TypeError(
1346+
"application_info should be an instance of any ApplicationInfoBase class")
1347+
self.application_info = application_info
1348+
13321349
if timestamp_generator is not None:
13331350
if not callable(timestamp_generator):
13341351
raise ValueError("timestamp_generator must be callable")
@@ -1779,6 +1796,7 @@ def _make_connection_kwargs(self, endpoint, kwargs_dict):
17791796
kwargs_dict.setdefault('user_type_map', self._user_types)
17801797
kwargs_dict.setdefault('allow_beta_protocol_version', self.allow_beta_protocol_version)
17811798
kwargs_dict.setdefault('no_compact', self.no_compact)
1799+
kwargs_dict.setdefault('application_info', self.application_info)
17821800

17831801
return kwargs_dict
17841802

cassandra/connection.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import random
3030
import itertools
3131

32+
from cassandra.application_info import ApplicationInfoBase
3233
from cassandra.protocol_features import ProtocolFeatures
3334

3435
if 'gevent.monkey' in sys.modules:
@@ -774,7 +775,7 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
774775
cql_version=None, protocol_version=ProtocolVersion.MAX_SUPPORTED, is_control_connection=False,
775776
user_type_map=None, connect_timeout=None, allow_beta_protocol_version=False, no_compact=False,
776777
ssl_context=None, owning_pool=None, shard_id=None, total_shards=None,
777-
on_orphaned_stream_released=None):
778+
on_orphaned_stream_released=None, application_info: ApplicationInfoBase = None):
778779
# TODO next major rename host to endpoint and remove port kwarg.
779780
self.endpoint = host if isinstance(host, EndPoint) else DefaultEndPoint(host, port)
780781

@@ -797,6 +798,7 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
797798
self._socket_writable = True
798799
self.orphaned_request_ids = set()
799800
self._on_orphaned_stream_released = on_orphaned_stream_released
801+
self._application_info = application_info
800802

801803
if ssl_options:
802804
self.ssl_options.update(self.endpoint.ssl_options or {})
@@ -1380,6 +1382,8 @@ def _handle_options_response(self, options_response):
13801382

13811383
options = {}
13821384
self.features.add_startup_options(options)
1385+
if self._application_info:
1386+
self._application_info.add_startup_options(options)
13831387

13841388
if self.cql_version:
13851389
if self.cql_version not in supported_cql_versions:

tests/integration/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ def _id_and_mark(f):
405405
notdse = unittest.skipIf(DSE_VERSION, "DSE not supported")
406406
requiredse = unittest.skipUnless(DSE_VERSION, "DSE required")
407407
requirescloudproxy = unittest.skipIf(CLOUD_PROXY_PATH is None, "Cloud Proxy path hasn't been specified")
408-
408+
notscylla = unittest.skipIf(SCYLLA_VERSION, "Does not support scylla")
409409
libevtest = unittest.skipUnless(EVENT_LOOP_MANAGER=="libev", "Test timing designed for libev loop")
410410

411411
def wait_for_node_socket(node, timeout):
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import unittest
2+
3+
from cassandra.application_info import ApplicationInfo
4+
from tests.integration import TestCluster, notscylla
5+
6+
7+
@notscylla
8+
class ApplicationInfoTest(unittest.TestCase):
9+
def setUp(self):
10+
self.cluster = TestCluster(application_info=ApplicationInfo(
11+
application_name='TestApplicationInfo',
12+
application_version='1.0.0-test',
13+
client_id='some-client-id',
14+
))
15+
self.session = self.cluster.connect()
16+
17+
def tearDown(self):
18+
self.cluster.shutdown()
19+
20+
def test_application_info_endup_in_system_views_clients(self):
21+
"""
22+
Test to ensure that ApplicationInfo user provides endup in `client_options` of `system_views.clients` table
23+
"""
24+
found = False
25+
26+
for row in self.session.execute("select client_options from system_views.clients"):
27+
if not row['client_options']:
28+
continue
29+
for marker in ("'APPLICATION_NAME': 'app_name'", "'APPLICATION_VERSION': 'app_version'", "'CLIENT_ID': 'client_id'"):
30+
if marker not in row['client_options']:
31+
break
32+
else:
33+
found = True
34+
assert found

0 commit comments

Comments
 (0)