Skip to content

Commit 055e8ec

Browse files
authored
Merge pull request #287 from sylwiaszunejko/white_list_with_unix_sockets
Add support for unix domain sockets to `WhiteListRoundRobinPolicy`
2 parents 89d6051 + 5dfb81b commit 055e8ec

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

cassandra/policies.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import socket
2020
import warnings
2121
from cassandra import WriteType as WT
22+
from cassandra.connection import UnixSocketEndPoint
2223

2324

2425
# This is done this way because WriteType was originally
@@ -436,8 +437,13 @@ def __init__(self, hosts):
436437
connections to.
437438
"""
438439
self._allowed_hosts = tuple(hosts)
439-
self._allowed_hosts_resolved = [endpoint[4][0] for a in self._allowed_hosts
440-
for endpoint in socket.getaddrinfo(a, None, socket.AF_UNSPEC, socket.SOCK_STREAM)]
440+
self._allowed_hosts_resolved = []
441+
for h in self._allowed_hosts:
442+
if isinstance(h, UnixSocketEndPoint):
443+
self._allowed_hosts_resolved.append(h._unix_socket_path)
444+
else:
445+
self._allowed_hosts_resolved.extend([endpoint[4][0]
446+
for endpoint in socket.getaddrinfo(h, None, socket.AF_UNSPEC, socket.SOCK_STREAM)])
441447

442448
RoundRobinPolicy.__init__(self)
443449

tests/unit/test_policies.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
LoadBalancingPolicy, ConvictionPolicy, ReconnectionPolicy, FallthroughRetryPolicy,
3535
IdentityTranslator, EC2MultiRegionTranslator, HostFilterPolicy)
3636
from cassandra.pool import Host
37-
from cassandra.connection import DefaultEndPoint
37+
from cassandra.connection import DefaultEndPoint, UnixSocketEndPoint
3838
from cassandra.query import Statement
3939

4040
from six.moves import xrange
@@ -1259,6 +1259,17 @@ def test_hosts_with_hostname(self):
12591259
self.assertEqual(sorted(qplan), [host])
12601260

12611261
self.assertEqual(policy.distance(host), HostDistance.LOCAL)
1262+
1263+
def test_hosts_with_socket_hostname(self):
1264+
hosts = [UnixSocketEndPoint('/tmp/scylla-workdir/cql.m')]
1265+
policy = WhiteListRoundRobinPolicy(hosts)
1266+
host = Host(UnixSocketEndPoint('/tmp/scylla-workdir/cql.m'), SimpleConvictionPolicy)
1267+
policy.populate(None, [host])
1268+
1269+
qplan = list(policy.make_query_plan())
1270+
self.assertEqual(sorted(qplan), [host])
1271+
1272+
self.assertEqual(policy.distance(host), HostDistance.LOCAL)
12621273

12631274

12641275
class AddressTranslatorTest(unittest.TestCase):

0 commit comments

Comments
 (0)