Skip to content
9 changes: 8 additions & 1 deletion redis/_parsers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from .base import BaseParser, _AsyncRESPBase
from .base import (
AsyncPushNotificationsParser,
BaseParser,
PushNotificationsParser,
_AsyncRESPBase,
)
from .commands import AsyncCommandsParser, CommandsParser
from .encoders import Encoder
from .hiredis import _AsyncHiredisParser, _HiredisParser
Expand All @@ -11,10 +16,12 @@
"_AsyncRESPBase",
"_AsyncRESP2Parser",
"_AsyncRESP3Parser",
"AsyncPushNotificationsParser",
"CommandsParser",
"Encoder",
"BaseParser",
"_HiredisParser",
"_RESP2Parser",
"_RESP3Parser",
"PushNotificationsParser",
]
54 changes: 53 additions & 1 deletion redis/_parsers/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sys
from abc import ABC
from asyncio import IncompleteReadError, StreamReader, TimeoutError
from typing import List, Optional, Union
from typing import Callable, List, Optional, Protocol, Union

if sys.version_info.major >= 3 and sys.version_info.minor >= 11:
from asyncio import timeout as async_timeout
Expand Down Expand Up @@ -158,6 +158,58 @@ async def read_response(
raise NotImplementedError()


_INVALIDATION_MESSAGE = [b"invalidate", "invalidate"]


class PushNotificationsParser(Protocol):
"""Protocol defining RESP3-specific parsing functionality"""

pubsub_push_handler_func: Callable
invalidation_push_handler_func: Optional[Callable] = None

def handle_pubsub_push_response(self, response):
"""Handle pubsub push responses"""
raise NotImplementedError()

def handle_push_response(self, response, **kwargs):
if response[0] not in _INVALIDATION_MESSAGE:
return self.pubsub_push_handler_func(response)
if self.invalidation_push_handler_func:
return self.invalidation_push_handler_func(response)

def set_pubsub_push_handler(self, pubsub_push_handler_func):
self.pubsub_push_handler_func = pubsub_push_handler_func

def set_invalidation_push_handler(self, invalidation_push_handler_func):
self.invalidation_push_handler_func = invalidation_push_handler_func


class AsyncPushNotificationsParser(Protocol):
"""Protocol defining async RESP3-specific parsing functionality"""

pubsub_push_handler_func: Callable
invalidation_push_handler_func: Optional[Callable] = None

async def handle_pubsub_push_response(self, response):
"""Handle pubsub push responses asynchronously"""
...

async def handle_push_response(self, response, **kwargs):
"""Handle push responses asynchronously"""
if response[0] not in _INVALIDATION_MESSAGE:
return await self.pubsub_push_handler_func(response)
if self.invalidation_push_handler_func:
return await self.invalidation_push_handler_func(response)

def set_pubsub_push_handler(self, pubsub_push_handler_func):
"""Set the pubsub push handler function"""
self.pubsub_push_handler_func = pubsub_push_handler_func

def set_invalidation_push_handler(self, invalidation_push_handler_func):
"""Set the invalidation push handler function"""
self.invalidation_push_handler_func = invalidation_push_handler_func


class _AsyncRESPBase(AsyncBaseParser):
"""Base class for async resp parsing"""

Expand Down
77 changes: 72 additions & 5 deletions redis/_parsers/hiredis.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import socket
import sys
from logging import getLogger
from typing import Callable, List, Optional, TypedDict, Union

if sys.version_info.major >= 3 and sys.version_info.minor >= 11:
Expand All @@ -11,7 +12,12 @@
from ..exceptions import ConnectionError, InvalidResponse, RedisError
from ..typing import EncodableT
from ..utils import HIREDIS_AVAILABLE
from .base import AsyncBaseParser, BaseParser
from .base import (
AsyncBaseParser,
AsyncPushNotificationsParser,
BaseParser,
PushNotificationsParser,
)
from .socket import (
NONBLOCKING_EXCEPTION_ERROR_NUMBERS,
NONBLOCKING_EXCEPTIONS,
Expand All @@ -32,21 +38,29 @@ class _HiredisReaderArgs(TypedDict, total=False):
errors: Optional[str]


class _HiredisParser(BaseParser):
class _HiredisParser(BaseParser, PushNotificationsParser):
"Parser class for connections using Hiredis"

def __init__(self, socket_read_size):
if not HIREDIS_AVAILABLE:
raise RedisError("Hiredis is not installed")
self.socket_read_size = socket_read_size
self._buffer = bytearray(socket_read_size)
self.pubsub_push_handler_func = self.handle_pubsub_push_response
self.invalidation_push_handler_func = None
self._hiredis_PushNotificationType = None

def __del__(self):
try:
self.on_disconnect()
except Exception:
pass

def handle_pubsub_push_response(self, response):
logger = getLogger("push_response")
logger.debug("Push response: " + str(response))
return response

def on_connect(self, connection, **kwargs):
import hiredis

Expand All @@ -64,6 +78,12 @@ def on_connect(self, connection, **kwargs):
self._reader = hiredis.Reader(**kwargs)
self._next_response = NOT_ENOUGH_DATA

try:
self._hiredis_PushNotificationType = hiredis.PushNotification
except AttributeError:
# hiredis < 3.2
self._hiredis_PushNotificationType = None

def on_disconnect(self):
self._sock = None
self._reader = None
Expand Down Expand Up @@ -109,14 +129,24 @@ def read_from_socket(self, timeout=SENTINEL, raise_on_timeout=True):
if custom_timeout:
sock.settimeout(self._socket_timeout)

def read_response(self, disable_decoding=False):
def read_response(self, disable_decoding=False, push_request=False):
if not self._reader:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

# _next_response might be cached from a can_read() call
if self._next_response is not NOT_ENOUGH_DATA:
response = self._next_response
self._next_response = NOT_ENOUGH_DATA
if self._hiredis_PushNotificationType is not None and isinstance(
response, self._hiredis_PushNotificationType
):
response = self.handle_push_response(response)
if not push_request:
return self.read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
return response
return response

if disable_decoding:
Expand All @@ -135,6 +165,16 @@ def read_response(self, disable_decoding=False):
# happened
if isinstance(response, ConnectionError):
raise response
elif self._hiredis_PushNotificationType is not None and isinstance(
response, self._hiredis_PushNotificationType
):
response = self.handle_push_response(response)
if not push_request:
return self.read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
return response
elif (
isinstance(response, list)
and response
Expand All @@ -144,7 +184,7 @@ def read_response(self, disable_decoding=False):
return response


class _AsyncHiredisParser(AsyncBaseParser):
class _AsyncHiredisParser(AsyncBaseParser, AsyncPushNotificationsParser):
"""Async implementation of parser class for connections using Hiredis"""

__slots__ = ("_reader",)
Expand All @@ -154,6 +194,14 @@ def __init__(self, socket_read_size: int):
raise RedisError("Hiredis is not available.")
super().__init__(socket_read_size=socket_read_size)
self._reader = None
self.pubsub_push_handler_func = self.handle_pubsub_push_response
self.invalidation_push_handler_func = None
self._hiredis_PushNotificationType = None

async def handle_pubsub_push_response(self, response):
logger = getLogger("push_response")
logger.debug("Push response: " + str(response))
return response

def on_connect(self, connection):
import hiredis
Expand All @@ -171,6 +219,14 @@ def on_connect(self, connection):
self._reader = hiredis.Reader(**kwargs)
self._connected = True

try:
self._hiredis_PushNotificationType = getattr(
hiredis, "PushNotification", None
)
except AttributeError:
# hiredis < 3.2
self._hiredis_PushNotificationType = None

def on_disconnect(self):
self._connected = False

Expand All @@ -195,7 +251,7 @@ async def read_from_socket(self):
return True

async def read_response(
self, disable_decoding: bool = False
self, disable_decoding: bool = False, push_request: bool = False
) -> Union[EncodableT, List[EncodableT]]:
# If `on_disconnect()` has been called, prohibit any more reads
# even if they could happen because data might be present.
Expand All @@ -207,6 +263,7 @@ async def read_response(
response = self._reader.gets(False)
else:
response = self._reader.gets()

while response is NOT_ENOUGH_DATA:
await self.read_from_socket()
if disable_decoding:
Expand All @@ -219,6 +276,16 @@ async def read_response(
# happened
if isinstance(response, ConnectionError):
raise response
elif self._hiredis_PushNotificationType is not None and isinstance(
response, self._hiredis_PushNotificationType
):
response = await self.handle_push_response(response)
if not push_request:
return await self.read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
return response
elif (
isinstance(response, list)
and response
Expand Down
45 changes: 10 additions & 35 deletions redis/_parsers/resp3.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@

from ..exceptions import ConnectionError, InvalidResponse, ResponseError
from ..typing import EncodableT
from .base import _AsyncRESPBase, _RESPBase
from .base import (
AsyncPushNotificationsParser,
PushNotificationsParser,
_AsyncRESPBase,
_RESPBase,
)
from .socket import SERVER_CLOSED_CONNECTION_ERROR

_INVALIDATION_MESSAGE = [b"invalidate", "invalidate"]


class _RESP3Parser(_RESPBase):
class _RESP3Parser(_RESPBase, PushNotificationsParser):
"""RESP3 protocol implementation"""

def __init__(self, socket_read_size):
Expand Down Expand Up @@ -113,9 +116,7 @@ def _read_response(self, disable_decoding=False, push_request=False):
)
for _ in range(int(response))
]
response = self.handle_push_response(
response, disable_decoding, push_request
)
response = self.handle_push_response(response)
if not push_request:
return self._read_response(
disable_decoding=disable_decoding, push_request=push_request
Expand All @@ -129,20 +130,8 @@ def _read_response(self, disable_decoding=False, push_request=False):
response = self.encoder.decode(response)
return response

def handle_push_response(self, response, disable_decoding, push_request):
if response[0] not in _INVALIDATION_MESSAGE:
return self.pubsub_push_handler_func(response)
if self.invalidation_push_handler_func:
return self.invalidation_push_handler_func(response)

def set_pubsub_push_handler(self, pubsub_push_handler_func):
self.pubsub_push_handler_func = pubsub_push_handler_func

def set_invalidation_push_handler(self, invalidation_push_handler_func):
self.invalidation_push_handler_func = invalidation_push_handler_func


class _AsyncRESP3Parser(_AsyncRESPBase):
class _AsyncRESP3Parser(_AsyncRESPBase, AsyncPushNotificationsParser):
def __init__(self, socket_read_size):
super().__init__(socket_read_size)
self.pubsub_push_handler_func = self.handle_pubsub_push_response
Expand Down Expand Up @@ -253,9 +242,7 @@ async def _read_response(
)
for _ in range(int(response))
]
response = await self.handle_push_response(
response, disable_decoding, push_request
)
response = await self.handle_push_response(response)
if not push_request:
return await self._read_response(
disable_decoding=disable_decoding, push_request=push_request
Expand All @@ -268,15 +255,3 @@ async def _read_response(
if isinstance(response, bytes) and disable_decoding is False:
response = self.encoder.decode(response)
return response

async def handle_push_response(self, response, disable_decoding, push_request):
if response[0] not in _INVALIDATION_MESSAGE:
return await self.pubsub_push_handler_func(response)
if self.invalidation_push_handler_func:
return await self.invalidation_push_handler_func(response)

def set_pubsub_push_handler(self, pubsub_push_handler_func):
self.pubsub_push_handler_func = pubsub_push_handler_func

def set_invalidation_push_handler(self, invalidation_push_handler_func):
self.invalidation_push_handler_func = invalidation_push_handler_func
3 changes: 1 addition & 2 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@
)
from redis.typing import ChannelT, EncodableT, KeyT
from redis.utils import (
HIREDIS_AVAILABLE,
SSL_AVAILABLE,
_set_info_logger,
deprecated_args,
Expand Down Expand Up @@ -938,7 +937,7 @@ async def connect(self):
self.connection.register_connect_callback(self.on_connect)
else:
await self.connection.connect()
if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
if self.push_handler_func is not None:
self.connection._parser.set_pubsub_push_handler(self.push_handler_func)

self._event_dispatcher.dispatch(
Expand Down
8 changes: 2 additions & 6 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,11 +576,7 @@ async def read_response(
read_timeout = timeout if timeout is not None else self.socket_timeout
host_error = self._host_error()
try:
if (
read_timeout is not None
and self.protocol in ["3", 3]
and not HIREDIS_AVAILABLE
):
if read_timeout is not None and self.protocol in ["3", 3]:
async with async_timeout(read_timeout):
response = await self._parser.read_response(
disable_decoding=disable_decoding, push_request=push_request
Expand All @@ -590,7 +586,7 @@ async def read_response(
response = await self._parser.read_response(
disable_decoding=disable_decoding
)
elif self.protocol in ["3", 3] and not HIREDIS_AVAILABLE:
elif self.protocol in ["3", 3]:
response = await self._parser.read_response(
disable_decoding=disable_decoding, push_request=push_request
)
Expand Down
Loading
Loading