Skip to content

Commit a421790

Browse files
committed
Vendor persistent recursive watch Kazoo support
Some changes to Kazoo are needed to support persistent recursive watches. Until those merge upstream, vendor and update the parts of Kazoo we need. Upstream PR: python-zk/kazoo#715 Change-Id: Id6372e4075b5b3ffeeef3e0f4751a71e59001ef9
1 parent e096928 commit a421790

File tree

6 files changed

+398
-2
lines changed

6 files changed

+398
-2
lines changed

nodepool/zk/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,25 @@
99
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
1010
# License for the specific language governing permissions and limitations
1111
# under the License.
12+
1213
import logging
1314
import time
1415
from abc import ABCMeta
1516
from threading import Thread
1617

17-
from kazoo.client import KazooClient
18+
import kazoo.client
19+
from nodepool.zk.vendor.client import ZuulKazooClient
20+
from nodepool.zk.vendor.connection import ZuulConnectionHandler
1821
from kazoo.handlers.threading import KazooTimeoutError
1922
from kazoo.protocol.states import KazooState
2023

2124
from nodepool.zk.exceptions import NoClientException
2225
from nodepool.zk.handler import PoolSequentialThreadingHandler
2326

2427

28+
kazoo.client.ConnectionHandler = ZuulConnectionHandler
29+
30+
2531
class ZooKeeperClient(object):
2632
log = logging.getLogger("nodepool.zk.ZooKeeperClient")
2733

@@ -135,7 +141,7 @@ def connect(self):
135141
args['keyfile'] = self.tls_key
136142
args['certfile'] = self.tls_cert
137143
args['ca'] = self.tls_ca
138-
self.client = KazooClient(**args)
144+
self.client = ZuulKazooClient(**args)
139145
self.client.add_listener(self._connectionListener)
140146
# Manually retry initial connection attempt
141147
while True:

nodepool/zk/vendor/__init__.py

Whitespace-only changes.

nodepool/zk/vendor/client.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# This file is derived from the Kazoo project
2+
# Licensed under the Apache License, Version 2.0 (the "License"); you may
3+
# not use this file except in compliance with the License. You may obtain
4+
# a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
10+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
11+
# License for the specific language governing permissions and limitations
12+
# under the License.
13+
14+
from collections import defaultdict
15+
16+
from kazoo.client import (
17+
_prefix_root,
18+
KazooClient,
19+
)
20+
from kazoo.protocol.states import (
21+
Callback,
22+
EventType,
23+
WatchedEvent,
24+
)
25+
26+
from nodepool.zk.vendor.serialization import AddWatch
27+
28+
29+
class ZuulKazooClient(KazooClient):
30+
def __init__(self, *args, **kw):
31+
self._persistent_watchers = defaultdict(set)
32+
self._persistent_recursive_watchers = defaultdict(set)
33+
super().__init__(*args, **kw)
34+
35+
def _reset_watchers(self):
36+
watchers = []
37+
for child_watchers in self._child_watchers.values():
38+
watchers.extend(child_watchers)
39+
40+
for data_watchers in self._data_watchers.values():
41+
watchers.extend(data_watchers)
42+
43+
for persistent_watchers in self._persistent_watchers.values():
44+
watchers.extend(persistent_watchers)
45+
46+
for pr_watchers in self._persistent_recursive_watchers.values():
47+
watchers.extend(pr_watchers)
48+
49+
self._child_watchers = defaultdict(set)
50+
self._data_watchers = defaultdict(set)
51+
self._persistent_watchers = defaultdict(set)
52+
self._persistent_recursive_watchers = defaultdict(set)
53+
54+
ev = WatchedEvent(EventType.NONE, self._state, None)
55+
for watch in watchers:
56+
self.handler.dispatch_callback(Callback("watch", watch, (ev,)))
57+
58+
def add_watch(self, path, watch, mode):
59+
"""Add a watch.
60+
61+
This method adds persistent watches. Unlike the data and
62+
child watches which may be set by calls to
63+
:meth:`KazooClient.exists`, :meth:`KazooClient.get`, and
64+
:meth:`KazooClient.get_children`, persistent watches are not
65+
removed after being triggered.
66+
67+
To remove a persistent watch, use
68+
:meth:`KazooClient.remove_all_watches` with an argument of
69+
:attr:`~kazoo.states.WatcherType.ANY`.
70+
71+
The `mode` argument determines whether or not the watch is
72+
recursive. To set a persistent watch, use
73+
:class:`~kazoo.states.AddWatchMode.PERSISTENT`. To set a
74+
persistent recursive watch, use
75+
:class:`~kazoo.states.AddWatchMode.PERSISTENT_RECURSIVE`.
76+
77+
:param path: Path of node to watch.
78+
:param watch: Watch callback to set for future changes
79+
to this path.
80+
:param mode: The mode to use.
81+
:type mode: int
82+
83+
:raises:
84+
:exc:`~kazoo.exceptions.MarshallingError` if mode is
85+
unknown.
86+
87+
:exc:`~kazoo.exceptions.ZookeeperError` if the server
88+
returns a non-zero error code.
89+
"""
90+
return self.add_watch_async(path, watch, mode).get()
91+
92+
def add_watch_async(self, path, watch, mode):
93+
"""Asynchronously add a watch. Takes the same arguments as
94+
:meth:`add_watch`.
95+
"""
96+
if not isinstance(path, str):
97+
raise TypeError("Invalid type for 'path' (string expected)")
98+
if not callable(watch):
99+
raise TypeError("Invalid type for 'watch' (must be a callable)")
100+
if not isinstance(mode, int):
101+
raise TypeError("Invalid type for 'mode' (int expected)")
102+
103+
async_result = self.handler.async_result()
104+
self._call(
105+
AddWatch(_prefix_root(self.chroot, path), watch, mode),
106+
async_result,
107+
)
108+
return async_result

nodepool/zk/vendor/connection.py

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
# This file is derived from the Kazoo project
2+
# Licensed under the Apache License, Version 2.0 (the "License"); you may
3+
# not use this file except in compliance with the License. You may obtain
4+
# a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
10+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
11+
# License for the specific language governing permissions and limitations
12+
# under the License.
13+
14+
from kazoo.exceptions import (
15+
EXCEPTIONS,
16+
NoNodeError,
17+
)
18+
from kazoo.loggingsupport import BLATHER
19+
from kazoo.protocol.connection import (
20+
ConnectionHandler,
21+
CREATED_EVENT,
22+
DELETED_EVENT,
23+
CHANGED_EVENT,
24+
CHILD_EVENT,
25+
CLOSE_RESPONSE,
26+
)
27+
from kazoo.protocol.serialization import (
28+
Close,
29+
Exists,
30+
Transaction,
31+
GetChildren,
32+
GetChildren2,
33+
Watch,
34+
)
35+
from kazoo.protocol.states import (
36+
Callback,
37+
WatchedEvent,
38+
EVENT_TYPE_MAP,
39+
)
40+
41+
from nodepool.zk.vendor.serialization import (
42+
AddWatch,
43+
RemoveWatches,
44+
)
45+
from nodepool.zk.vendor.states import (
46+
AddWatchMode,
47+
WatcherType,
48+
)
49+
50+
51+
class ZuulConnectionHandler(ConnectionHandler):
52+
def _find_persistent_recursive_watchers(self, path):
53+
parts = path.split("/")
54+
watchers = []
55+
for count in range(len(parts)):
56+
candidate = "/".join(parts[: count + 1])
57+
if not candidate:
58+
candidate = '/'
59+
watchers.extend(
60+
self.client._persistent_recursive_watchers.get(candidate, [])
61+
)
62+
return watchers
63+
64+
def _read_watch_event(self, buffer, offset):
65+
client = self.client
66+
watch, offset = Watch.deserialize(buffer, offset)
67+
path = watch.path
68+
69+
self.logger.debug("Received EVENT: %s", watch)
70+
71+
watchers = []
72+
73+
if watch.type in (CREATED_EVENT, CHANGED_EVENT):
74+
watchers.extend(client._data_watchers.pop(path, []))
75+
watchers.extend(client._persistent_watchers.get(path, []))
76+
watchers.extend(self._find_persistent_recursive_watchers(path))
77+
elif watch.type == DELETED_EVENT:
78+
watchers.extend(client._data_watchers.pop(path, []))
79+
watchers.extend(client._child_watchers.pop(path, []))
80+
watchers.extend(client._persistent_watchers.get(path, []))
81+
watchers.extend(self._find_persistent_recursive_watchers(path))
82+
elif watch.type == CHILD_EVENT:
83+
watchers.extend(client._child_watchers.pop(path, []))
84+
else:
85+
self.logger.warn("Received unknown event %r", watch.type)
86+
return
87+
88+
# Strip the chroot if needed
89+
path = client.unchroot(path)
90+
ev = WatchedEvent(EVENT_TYPE_MAP[watch.type], client._state, path)
91+
92+
# Last check to ignore watches if we've been stopped
93+
if client._stopped.is_set():
94+
return
95+
96+
# Dump the watchers to the watch thread
97+
for watch in watchers:
98+
client.handler.dispatch_callback(Callback("watch", watch, (ev,)))
99+
100+
def _read_response(self, header, buffer, offset):
101+
client = self.client
102+
request, async_object, xid = client._pending.popleft()
103+
if header.zxid and header.zxid > 0:
104+
client.last_zxid = header.zxid
105+
if header.xid != xid:
106+
exc = RuntimeError(
107+
"xids do not match, expected %r " "received %r",
108+
xid,
109+
header.xid,
110+
)
111+
async_object.set_exception(exc)
112+
raise exc
113+
114+
# Determine if its an exists request and a no node error
115+
exists_error = (
116+
header.err == NoNodeError.code and request.type == Exists.type
117+
)
118+
119+
# Set the exception if its not an exists error
120+
if header.err and not exists_error:
121+
callback_exception = EXCEPTIONS[header.err]()
122+
self.logger.debug(
123+
"Received error(xid=%s) %r", xid, callback_exception
124+
)
125+
if async_object:
126+
async_object.set_exception(callback_exception)
127+
elif request and async_object:
128+
if exists_error:
129+
# It's a NoNodeError, which is fine for an exists
130+
# request
131+
async_object.set(None)
132+
else:
133+
try:
134+
response = request.deserialize(buffer, offset)
135+
except Exception as exc:
136+
self.logger.exception(
137+
"Exception raised during deserialization "
138+
"of request: %s",
139+
request,
140+
)
141+
async_object.set_exception(exc)
142+
return
143+
self.logger.debug(
144+
"Received response(xid=%s): %r", xid, response
145+
)
146+
147+
# We special case a Transaction as we have to unchroot things
148+
if request.type == Transaction.type:
149+
response = Transaction.unchroot(client, response)
150+
151+
async_object.set(response)
152+
153+
# Determine if watchers should be registered or unregistered
154+
if not client._stopped.is_set():
155+
watcher = getattr(request, "watcher", None)
156+
if watcher:
157+
if isinstance(request, AddWatch):
158+
if request.mode == AddWatchMode.PERSISTENT:
159+
client._persistent_watchers[request.path].add(
160+
watcher
161+
)
162+
elif request.mode == AddWatchMode.PERSISTENT_RECURSIVE:
163+
client._persistent_recursive_watchers[
164+
request.path
165+
].add(watcher)
166+
elif isinstance(request, (GetChildren, GetChildren2)):
167+
client._child_watchers[request.path].add(watcher)
168+
else:
169+
client._data_watchers[request.path].add(watcher)
170+
if isinstance(request, RemoveWatches):
171+
if request.watcher_type == WatcherType.CHILDREN:
172+
client._child_watchers.pop(request.path, None)
173+
elif request.watcher_type == WatcherType.DATA:
174+
client._data_watchers.pop(request.path, None)
175+
elif request.watcher_type == WatcherType.ANY:
176+
client._child_watchers.pop(request.path, None)
177+
client._data_watchers.pop(request.path, None)
178+
client._persistent_watchers.pop(request.path, None)
179+
client._persistent_recursive_watchers.pop(
180+
request.path, None
181+
)
182+
183+
if isinstance(request, Close):
184+
self.logger.log(BLATHER, "Read close response")
185+
return CLOSE_RESPONSE

nodepool/zk/vendor/serialization.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License"); you may
2+
# not use this file except in compliance with the License. You may obtain
3+
# a copy of the License at
4+
#
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
9+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
10+
# License for the specific language governing permissions and limitations
11+
# under the License.
12+
13+
from collections import namedtuple
14+
15+
from kazoo.protocol.serialization import (
16+
int_struct,
17+
write_string,
18+
)
19+
20+
21+
class RemoveWatches(namedtuple("RemoveWatches", "path watcher_type")):
22+
type = 18
23+
24+
def serialize(self):
25+
b = bytearray()
26+
b.extend(write_string(self.path))
27+
b.extend(int_struct.pack(self.watcher_type))
28+
return b
29+
30+
@classmethod
31+
def deserialize(cls, bytes, offset):
32+
return None
33+
34+
35+
class AddWatch(namedtuple("AddWatch", "path watcher mode")):
36+
type = 106
37+
38+
def serialize(self):
39+
b = bytearray()
40+
b.extend(write_string(self.path))
41+
b.extend(int_struct.pack(self.mode))
42+
return b
43+
44+
@classmethod
45+
def deserialize(cls, bytes, offset):
46+
return None

0 commit comments

Comments
 (0)