Skip to content

Commit fa1f654

Browse files
committed
Add minimalistic implementation for Push notifications parsing with hiredis-py
1 parent 36619a5 commit fa1f654

File tree

14 files changed

+128
-59
lines changed

14 files changed

+128
-59
lines changed

redis/_parsers/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .base import BaseParser, _AsyncRESPBase
1+
from .base import AsyncPushNotificationsParser, BaseParser, PushNotificationsParser, _AsyncRESPBase
22
from .commands import AsyncCommandsParser, CommandsParser
33
from .encoders import Encoder
44
from .hiredis import _AsyncHiredisParser, _HiredisParser
@@ -11,10 +11,12 @@
1111
"_AsyncRESPBase",
1212
"_AsyncRESP2Parser",
1313
"_AsyncRESP3Parser",
14+
"AsyncPushNotificationsParser",
1415
"CommandsParser",
1516
"Encoder",
1617
"BaseParser",
1718
"_HiredisParser",
1819
"_RESP2Parser",
1920
"_RESP3Parser",
21+
"PushNotificationsParser",
2022
]

redis/_parsers/base.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import sys
22
from abc import ABC
33
from asyncio import IncompleteReadError, StreamReader, TimeoutError
4-
from typing import List, Optional, Union
4+
from typing import Callable, List, Optional, Protocol, Union
55

66
if sys.version_info.major >= 3 and sys.version_info.minor >= 11:
77
from asyncio import timeout as async_timeout
@@ -158,6 +158,58 @@ async def read_response(
158158
raise NotImplementedError()
159159

160160

161+
_INVALIDATION_MESSAGE = [b"invalidate", "invalidate"]
162+
163+
164+
class PushNotificationsParser(Protocol):
165+
"""Protocol defining RESP3-specific parsing functionality"""
166+
167+
pubsub_push_handler_func: Callable
168+
invalidation_push_handler_func: Optional[Callable] = None
169+
170+
def handle_pubsub_push_response(self, response):
171+
"""Handle pubsub push responses"""
172+
raise NotImplementedError()
173+
174+
def handle_push_response(self, response, **kwargs):
175+
if response[0] not in _INVALIDATION_MESSAGE:
176+
return self.pubsub_push_handler_func(response)
177+
if self.invalidation_push_handler_func:
178+
return self.invalidation_push_handler_func(response)
179+
180+
def set_pubsub_push_handler(self, pubsub_push_handler_func):
181+
self.pubsub_push_handler_func = pubsub_push_handler_func
182+
183+
def set_invalidation_push_handler(self, invalidation_push_handler_func):
184+
self.invalidation_push_handler_func = invalidation_push_handler_func
185+
186+
187+
class AsyncPushNotificationsParser(Protocol):
188+
"""Protocol defining async RESP3-specific parsing functionality"""
189+
190+
pubsub_push_handler_func: Callable
191+
invalidation_push_handler_func: Optional[Callable] = None
192+
193+
async def handle_pubsub_push_response(self, response):
194+
"""Handle pubsub push responses asynchronously"""
195+
...
196+
197+
async def handle_push_response(self, response, **kwargs):
198+
"""Handle push responses asynchronously"""
199+
if response[0] not in _INVALIDATION_MESSAGE:
200+
return await self.pubsub_push_handler_func(response)
201+
if self.invalidation_push_handler_func:
202+
return await self.invalidation_push_handler_func(response)
203+
204+
def set_pubsub_push_handler(self, pubsub_push_handler_func):
205+
"""Set the pubsub push handler function"""
206+
self.pubsub_push_handler_func = pubsub_push_handler_func
207+
208+
def set_invalidation_push_handler(self, invalidation_push_handler_func):
209+
"""Set the invalidation push handler function"""
210+
self.invalidation_push_handler_func = invalidation_push_handler_func
211+
212+
161213
class _AsyncRESPBase(AsyncBaseParser):
162214
"""Base class for async resp parsing"""
163215

redis/_parsers/hiredis.py

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import socket
33
import sys
4+
from logging import getLogger
45
from typing import Callable, List, Optional, TypedDict, Union
56

67
if sys.version_info.major >= 3 and sys.version_info.minor >= 11:
@@ -11,7 +12,12 @@
1112
from ..exceptions import ConnectionError, InvalidResponse, RedisError
1213
from ..typing import EncodableT
1314
from ..utils import HIREDIS_AVAILABLE
14-
from .base import AsyncBaseParser, BaseParser
15+
from .base import (
16+
AsyncBaseParser,
17+
BaseParser,
18+
PushNotificationsParser,
19+
AsyncPushNotificationsParser,
20+
)
1521
from .socket import (
1622
NONBLOCKING_EXCEPTION_ERROR_NUMBERS,
1723
NONBLOCKING_EXCEPTIONS,
@@ -32,21 +38,29 @@ class _HiredisReaderArgs(TypedDict, total=False):
3238
errors: Optional[str]
3339

3440

35-
class _HiredisParser(BaseParser):
41+
class _HiredisParser(BaseParser, PushNotificationsParser):
3642
"Parser class for connections using Hiredis"
3743

3844
def __init__(self, socket_read_size):
3945
if not HIREDIS_AVAILABLE:
4046
raise RedisError("Hiredis is not installed")
4147
self.socket_read_size = socket_read_size
4248
self._buffer = bytearray(socket_read_size)
49+
self.pubsub_push_handler_func = self.handle_pubsub_push_response
50+
self.invalidation_push_handler_func = None
51+
self._hiredis_PushNotificationType = None
4352

4453
def __del__(self):
4554
try:
4655
self.on_disconnect()
4756
except Exception:
4857
pass
4958

59+
def handle_pubsub_push_response(self, response):
60+
logger = getLogger("push_response")
61+
logger.debug("Push response: " + str(response))
62+
return response
63+
5064
def on_connect(self, connection, **kwargs):
5165
import hiredis
5266

@@ -64,6 +78,12 @@ def on_connect(self, connection, **kwargs):
6478
self._reader = hiredis.Reader(**kwargs)
6579
self._next_response = NOT_ENOUGH_DATA
6680

81+
try:
82+
self._hiredis_PushNotificationType = hiredis.PushNotification
83+
except AttributeError:
84+
# hiredis < 3.2
85+
self._hiredis_PushNotificationType = None
86+
6787
def on_disconnect(self):
6888
self._sock = None
6989
self._reader = None
@@ -109,7 +129,7 @@ def read_from_socket(self, timeout=SENTINEL, raise_on_timeout=True):
109129
if custom_timeout:
110130
sock.settimeout(self._socket_timeout)
111131

112-
def read_response(self, disable_decoding=False):
132+
def read_response(self, disable_decoding=False, push_request=False):
113133
if not self._reader:
114134
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
115135

@@ -135,6 +155,16 @@ def read_response(self, disable_decoding=False):
135155
# happened
136156
if isinstance(response, ConnectionError):
137157
raise response
158+
elif self._hiredis_PushNotificationType is not None and isinstance(
159+
response, self._hiredis_PushNotificationType
160+
):
161+
response = self.handle_push_response(response)
162+
if not push_request:
163+
return self.read_response(
164+
disable_decoding=disable_decoding, push_request=push_request
165+
)
166+
else:
167+
return response
138168
elif (
139169
isinstance(response, list)
140170
and response
@@ -154,6 +184,8 @@ def __init__(self, socket_read_size: int):
154184
raise RedisError("Hiredis is not available.")
155185
super().__init__(socket_read_size=socket_read_size)
156186
self._reader = None
187+
self.pubsub_push_handler_func = self.handle_pubsub_push_response
188+
self.invalidation_push_handler_func = None
157189

158190
def on_connect(self, connection):
159191
import hiredis
@@ -171,6 +203,14 @@ def on_connect(self, connection):
171203
self._reader = hiredis.Reader(**kwargs)
172204
self._connected = True
173205

206+
try:
207+
self._hiredis_PushNotificationType = getattr(
208+
hiredis, "PushNotificationType", None
209+
)
210+
except AttributeError:
211+
# hiredis < 3.2
212+
self._hiredis_PushNotificationType = None
213+
174214
def on_disconnect(self):
175215
self._connected = False
176216

@@ -195,7 +235,7 @@ async def read_from_socket(self):
195235
return True
196236

197237
async def read_response(
198-
self, disable_decoding: bool = False
238+
self, disable_decoding: bool = False, push_request: bool = False
199239
) -> Union[EncodableT, List[EncodableT]]:
200240
# If `on_disconnect()` has been called, prohibit any more reads
201241
# even if they could happen because data might be present.
@@ -219,6 +259,16 @@ async def read_response(
219259
# happened
220260
if isinstance(response, ConnectionError):
221261
raise response
262+
elif self._hiredis_PushNotificationType is not None and isinstance(
263+
response, self._hiredis_PushNotificationType
264+
):
265+
response = await self.handle_push_response(response)
266+
if not push_request:
267+
return await self.read_response(
268+
disable_decoding=disable_decoding, push_request=push_request
269+
)
270+
else:
271+
return response
222272
elif (
223273
isinstance(response, list)
224274
and response

redis/_parsers/resp3.py

Lines changed: 10 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,16 @@
33

44
from ..exceptions import ConnectionError, InvalidResponse, ResponseError
55
from ..typing import EncodableT
6-
from .base import _AsyncRESPBase, _RESPBase
6+
from .base import (
7+
AsyncPushNotificationsParser,
8+
PushNotificationsParser,
9+
_AsyncRESPBase,
10+
_RESPBase,
11+
)
712
from .socket import SERVER_CLOSED_CONNECTION_ERROR
813

9-
_INVALIDATION_MESSAGE = [b"invalidate", "invalidate"]
1014

11-
12-
class _RESP3Parser(_RESPBase):
15+
class _RESP3Parser(_RESPBase, PushNotificationsParser):
1316
"""RESP3 protocol implementation"""
1417

1518
def __init__(self, socket_read_size):
@@ -113,9 +116,7 @@ def _read_response(self, disable_decoding=False, push_request=False):
113116
)
114117
for _ in range(int(response))
115118
]
116-
response = self.handle_push_response(
117-
response, disable_decoding, push_request
118-
)
119+
response = self.handle_push_response(response)
119120
if not push_request:
120121
return self._read_response(
121122
disable_decoding=disable_decoding, push_request=push_request
@@ -129,20 +130,8 @@ def _read_response(self, disable_decoding=False, push_request=False):
129130
response = self.encoder.decode(response)
130131
return response
131132

132-
def handle_push_response(self, response, disable_decoding, push_request):
133-
if response[0] not in _INVALIDATION_MESSAGE:
134-
return self.pubsub_push_handler_func(response)
135-
if self.invalidation_push_handler_func:
136-
return self.invalidation_push_handler_func(response)
137-
138-
def set_pubsub_push_handler(self, pubsub_push_handler_func):
139-
self.pubsub_push_handler_func = pubsub_push_handler_func
140-
141-
def set_invalidation_push_handler(self, invalidation_push_handler_func):
142-
self.invalidation_push_handler_func = invalidation_push_handler_func
143-
144133

145-
class _AsyncRESP3Parser(_AsyncRESPBase):
134+
class _AsyncRESP3Parser(_AsyncRESPBase, AsyncPushNotificationsParser):
146135
def __init__(self, socket_read_size):
147136
super().__init__(socket_read_size)
148137
self.pubsub_push_handler_func = self.handle_pubsub_push_response
@@ -253,9 +242,7 @@ async def _read_response(
253242
)
254243
for _ in range(int(response))
255244
]
256-
response = await self.handle_push_response(
257-
response, disable_decoding, push_request
258-
)
245+
response = await self.handle_push_response(response)
259246
if not push_request:
260247
return await self._read_response(
261248
disable_decoding=disable_decoding, push_request=push_request
@@ -268,15 +255,3 @@ async def _read_response(
268255
if isinstance(response, bytes) and disable_decoding is False:
269256
response = self.encoder.decode(response)
270257
return response
271-
272-
async def handle_push_response(self, response, disable_decoding, push_request):
273-
if response[0] not in _INVALIDATION_MESSAGE:
274-
return await self.pubsub_push_handler_func(response)
275-
if self.invalidation_push_handler_func:
276-
return await self.invalidation_push_handler_func(response)
277-
278-
def set_pubsub_push_handler(self, pubsub_push_handler_func):
279-
self.pubsub_push_handler_func = pubsub_push_handler_func
280-
281-
def set_invalidation_push_handler(self, invalidation_push_handler_func):
282-
self.invalidation_push_handler_func = invalidation_push_handler_func

redis/asyncio/client.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@
7070
)
7171
from redis.typing import ChannelT, EncodableT, KeyT
7272
from redis.utils import (
73-
HIREDIS_AVAILABLE,
7473
SSL_AVAILABLE,
7574
_set_info_logger,
7675
deprecated_args,
@@ -938,7 +937,7 @@ async def connect(self):
938937
self.connection.register_connect_callback(self.on_connect)
939938
else:
940939
await self.connection.connect()
941-
if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
940+
if self.push_handler_func is not None:
942941
self.connection._parser.set_pubsub_push_handler(self.push_handler_func)
943942

944943
self._event_dispatcher.dispatch(

redis/asyncio/connection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -590,7 +590,7 @@ async def read_response(
590590
response = await self._parser.read_response(
591591
disable_decoding=disable_decoding
592592
)
593-
elif self.protocol in ["3", 3] and not HIREDIS_AVAILABLE:
593+
elif self.protocol in ["3", 3]:
594594
response = await self._parser.read_response(
595595
disable_decoding=disable_decoding, push_request=push_request
596596
)

redis/client.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@
5858
from redis.lock import Lock
5959
from redis.retry import Retry
6060
from redis.utils import (
61-
HIREDIS_AVAILABLE,
6261
_set_info_logger,
6362
deprecated_args,
6463
get_lib_version,
@@ -861,7 +860,7 @@ def execute_command(self, *args):
861860
# register a callback that re-subscribes to any channels we
862861
# were listening to when we were disconnected
863862
self.connection.register_connect_callback(self.on_connect)
864-
if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
863+
if self.push_handler_func is not None:
865864
self.connection._parser.set_pubsub_push_handler(self.push_handler_func)
866865
self._event_dispatcher.dispatch(
867866
AfterPubSubConnectionInstantiationEvent(

redis/cluster.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@
5151
from redis.lock import Lock
5252
from redis.retry import Retry
5353
from redis.utils import (
54-
HIREDIS_AVAILABLE,
5554
deprecated_args,
5655
dict_merge,
5756
list_keys_to_dict,
@@ -1999,7 +1998,7 @@ def execute_command(self, *args):
19991998
# register a callback that re-subscribes to any channels we
20001999
# were listening to when we were disconnected
20012000
self.connection.register_connect_callback(self.on_connect)
2002-
if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
2001+
if self.push_handler_func is not None:
20032002
self.connection._parser.set_pubsub_push_handler(self.push_handler_func)
20042003
self._event_dispatcher.dispatch(
20052004
AfterPubSubConnectionInstantiationEvent(

redis/connection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -636,7 +636,7 @@ def read_response(
636636
host_error = self._host_error()
637637

638638
try:
639-
if self.protocol in ["3", 3] and not HIREDIS_AVAILABLE:
639+
if self.protocol in ["3", 3]:
640640
response = self._parser.read_response(
641641
disable_decoding=disable_decoding, push_request=push_request
642642
)

redis/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@
1212
import hiredis # noqa
1313

1414
# Only support Hiredis >= 3.0:
15-
HIREDIS_AVAILABLE = int(hiredis.__version__.split(".")[0]) >= 3
15+
hiredis_version = hiredis.__version__.split(".")
16+
HIREDIS_AVAILABLE = int(hiredis_version[0]) >= 3 and int(hiredis_version[1]) >= 2
1617
if not HIREDIS_AVAILABLE:
17-
raise ImportError("hiredis package should be >= 3.0.0")
18+
raise ImportError("hiredis package should be >= 3.2.0")
1819
except ImportError:
1920
HIREDIS_AVAILABLE = False
2021

0 commit comments

Comments
 (0)