Skip to content

Commit 8272c73

Browse files
committed
StreamingCredentialProvider support
1 parent 0a8f770 commit 8272c73

File tree

3 files changed

+125
-1
lines changed

3 files changed

+125
-1
lines changed

redis/connection.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323

2424
from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser
2525
from .backoff import NoBackoff
26-
from .credentials import CredentialProvider, UsernamePasswordCredentialProvider
26+
from .credentials import CredentialProvider, UsernamePasswordCredentialProvider, StreamingCredentialProvider
27+
from .event import EventDispatcherInterface, EventDispatcher, BeforeCommandExecutionEvent
2728
from .exceptions import (
2829
AuthenticationError,
2930
AuthenticationWrongNumberOfArgsError,
@@ -229,6 +230,7 @@ def __init__(
229230
credential_provider: Optional[CredentialProvider] = None,
230231
protocol: Optional[int] = 2,
231232
command_packer: Optional[Callable[[], None]] = None,
233+
event_dispatcher: Optional[EventDispatcherInterface] = EventDispatcher()
232234
):
233235
"""
234236
Initialize a new Connection.
@@ -283,6 +285,8 @@ def __init__(
283285
self.set_parser(parser_class)
284286
self._connect_callbacks = []
285287
self._buffer_cutoff = 6000
288+
self._event_dispatcher = event_dispatcher
289+
self._init_auth_args = None
286290
try:
287291
p = int(protocol)
288292
except TypeError:
@@ -408,6 +412,7 @@ def on_connect(self):
408412
or UsernamePasswordCredentialProvider(self.username, self.password)
409413
)
410414
auth_args = cred_provider.get_credentials()
415+
self._init_auth_args = hash(auth_args)
411416

412417
# if resp version is specified and we have auth args,
413418
# we need to send them via HELLO
@@ -553,6 +558,10 @@ def send_packed_command(self, command, check_health=True):
553558

554559
def send_command(self, *args, **kwargs):
555560
"""Pack and send a command to the Redis server"""
561+
if isinstance(self.credential_provider, StreamingCredentialProvider):
562+
self._event_dispatcher.dispatch(
563+
BeforeCommandExecutionEvent(args, self._init_auth_args, self, self.credential_provider)
564+
)
556565
self.send_packed_command(
557566
self._command_packer.pack(*args),
558567
check_health=kwargs.get("check_health", True),
@@ -1318,6 +1327,12 @@ def __init__(
13181327
connection_kwargs.pop("cache", None)
13191328
connection_kwargs.pop("cache_config", None)
13201329

1330+
cred_provider = connection_kwargs.get("credential_provider")
1331+
1332+
if cred_provider is not None and isinstance(cred_provider, StreamingCredentialProvider):
1333+
cred_provider.on_next(self._re_auth)
1334+
1335+
13211336
# a lock to protect the critical section in _checkpid().
13221337
# this lock is acquired when the process id changes, such as
13231338
# after a fork. during this time, multiple threads in the child
@@ -1517,6 +1532,14 @@ def set_retry(self, retry: "Retry") -> None:
15171532
for conn in self._in_use_connections:
15181533
conn.retry = retry
15191534

1535+
def _re_auth(self, token):
1536+
with self._lock:
1537+
for conn in self._available_connections:
1538+
conn.send_command(
1539+
'AUTH', token.try_get('oid'), token.get_value()
1540+
)
1541+
conn.read_response()
1542+
15201543

15211544
class BlockingConnectionPool(ConnectionPool):
15221545
"""

redis/credentials.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,21 @@ def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]:
1212

1313

1414
class StreamingCredentialProvider(CredentialProvider, ABC):
15+
"""
16+
Credential provider that streams credentials in the background.
17+
"""
1518
@abstractmethod
1619
def on_next(self, callback: Callable[[Any], None]):
20+
"""
21+
Specifies the callback that should be invoked when the next credentials will be retrieved.
22+
23+
:param callback: Callback with
24+
:return:
25+
"""
26+
pass
27+
28+
@abstractmethod
29+
def on_error(self, callback: Callable[[Exception], None]):
1730
pass
1831

1932
@abstractmethod

redis/event.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
from abc import ABC, abstractmethod
2+
3+
from redis.credentials import StreamingCredentialProvider
4+
5+
6+
class EventListenerInterface(ABC):
7+
"""
8+
Represents a listener for given event object.
9+
"""
10+
@abstractmethod
11+
def listen(self, event: object):
12+
pass
13+
14+
15+
class EventDispatcherInterface(ABC):
16+
"""
17+
Represents a dispatcher that dispatches events to listeners associated with given event.
18+
"""
19+
@abstractmethod
20+
def dispatch(self, event: object):
21+
pass
22+
23+
24+
class EventDispatcher(EventDispatcherInterface):
25+
26+
def __init__(self):
27+
"""
28+
Mapping should be extended for any new events or listeners to be added.
29+
"""
30+
self._event_listeners_mapping = {
31+
BeforeCommandExecutionEvent: [
32+
ReAuthBeforeCommandExecutionListener(),
33+
]
34+
}
35+
36+
def dispatch(self, event: object):
37+
listeners = self._event_listeners_mapping.get(type(event))
38+
39+
for listener in listeners:
40+
listener.listen(event)
41+
42+
43+
class BeforeCommandExecutionEvent:
44+
"""
45+
Event that will be fired before each command execution.
46+
"""
47+
def __init__(self, command, initial_cred, connection, credential_provider: StreamingCredentialProvider):
48+
self._command = command
49+
self._initial_cred = initial_cred
50+
self._credential_provider = credential_provider
51+
self._connection = connection
52+
53+
@property
54+
def command(self):
55+
return self._command
56+
57+
@property
58+
def initial_cred(self):
59+
return self._initial_cred
60+
61+
@property
62+
def connection(self):
63+
return self._connection
64+
65+
@property
66+
def credential_provider(self) -> StreamingCredentialProvider:
67+
return self._credential_provider
68+
69+
70+
class ReAuthBeforeCommandExecutionListener(EventListenerInterface):
71+
"""
72+
Listener that performs re-authentication (if needed) for StreamingCredentialProviders before command execution.
73+
"""
74+
def __init__(self):
75+
self._current_cred = None
76+
77+
def listen(self, event: BeforeCommandExecutionEvent):
78+
if self._current_cred is None:
79+
self._current_cred = event.initial_cred
80+
81+
credentials = event.credential_provider.get_credentials()
82+
83+
print(hash(credentials) != self._current_cred)
84+
85+
if hash(credentials) != self._current_cred:
86+
self._current_cred = hash(credentials)
87+
event.connection.send_command('AUTH', credentials[0], credentials[1])
88+
event.connection.read_response()

0 commit comments

Comments
 (0)