Skip to content

Commit 6dae71b

Browse files
committed
Added support for single connection client
1 parent 0327f36 commit 6dae71b

File tree

4 files changed

+85
-2
lines changed

4 files changed

+85
-2
lines changed

redis/asyncio/client.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@
5353
list_or_args,
5454
)
5555
from redis.credentials import CredentialProvider
56-
from redis.event import EventDispatcher, AfterPooledConnectionsInstantiationEvent, ClientType
56+
from redis.event import EventDispatcher, AfterPooledConnectionsInstantiationEvent, ClientType, \
57+
AfterSingleConnectionInstantiationEvent
5758
from redis.exceptions import (
5859
ConnectionError,
5960
ExecAbortError,
@@ -337,6 +338,7 @@ def __init__(
337338
))
338339

339340
self.connection_pool = connection_pool
341+
self._event_dispatcher = event_dispatcher
340342
self.single_connection_client = single_connection_client
341343
self.connection: Optional[Connection] = None
342344

@@ -366,6 +368,10 @@ async def initialize(self: _RedisT) -> _RedisT:
366368
async with self._single_conn_lock:
367369
if self.connection is None:
368370
self.connection = await self.connection_pool.get_connection("_")
371+
372+
self._event_dispatcher.dispatch(
373+
AfterSingleConnectionInstantiationEvent(self.connection, ClientType.ASYNC, self._single_conn_lock)
374+
)
369375
return self
370376

371377
def set_response_callback(self, command: str, callback: ResponseCallbackT):

redis/client.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import copy
23
import re
34
import threading
@@ -27,7 +28,8 @@
2728
UnixDomainSocketConnection,
2829
)
2930
from redis.credentials import CredentialProvider
30-
from redis.event import EventDispatcher, AfterPooledConnectionsInstantiationEvent, ClientType
31+
from redis.event import EventDispatcher, AfterPooledConnectionsInstantiationEvent, ClientType, \
32+
AfterSingleConnectionInstantiationEvent
3133
from redis.exceptions import (
3234
ConnectionError,
3335
ExecAbortError,
@@ -337,9 +339,13 @@ def __init__(
337339
]:
338340
raise RedisError("Client caching is only supported with RESP version 3")
339341

342+
self._connection_lock = threading.Lock()
340343
self.connection = None
341344
if single_connection_client:
342345
self.connection = self.connection_pool.get_connection("_")
346+
event_dispatcher.dispatch(
347+
AfterSingleConnectionInstantiationEvent(self.connection, ClientType.SYNC, self._connection_lock)
348+
)
343349

344350
self.response_callbacks = CaseInsensitiveDict(_RedisCallbacks)
345351

redis/connection.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,7 @@ def on_connect(self):
413413
)
414414
auth_args = cred_provider.get_credentials()
415415
self._init_auth_args = hash(auth_args)
416+
print(auth_args)
416417

417418
# if resp version is specified and we have auth args,
418419
# we need to send them via HELLO

redis/event.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import asyncio
2+
import threading
13
from abc import ABC, abstractmethod
24
from enum import Enum
35
from typing import List, Union, Optional
@@ -9,6 +11,7 @@ class EventListenerInterface(ABC):
911
"""
1012
Represents a listener for given event object.
1113
"""
14+
1215
@abstractmethod
1316
def listen(self, event: object):
1417
pass
@@ -18,6 +21,7 @@ class AsyncEventListenerInterface(ABC):
1821
"""
1922
Represents an async listener for given event object.
2023
"""
24+
2125
@abstractmethod
2226
async def listen(self, event: object):
2327
pass
@@ -27,6 +31,7 @@ class EventDispatcherInterface(ABC):
2731
"""
2832
Represents a dispatcher that dispatches events to listeners associated with given event.
2933
"""
34+
3035
@abstractmethod
3136
def dispatch(self, event: object):
3237
pass
@@ -49,6 +54,9 @@ def __init__(self):
4954
AfterPooledConnectionsInstantiationEvent: [
5055
RegisterReAuthForPooledConnections()
5156
],
57+
AfterSingleConnectionInstantiationEvent: [
58+
RegisterReAuthForSingleConnection()
59+
],
5260
AsyncBeforeCommandExecutionEvent: [
5361
AsyncReAuthBeforeCommandExecutionListener(),
5462
],
@@ -71,6 +79,7 @@ class BeforeCommandExecutionEvent:
7179
"""
7280
Event that will be fired before each command execution.
7381
"""
82+
7483
def __init__(self, command, initial_cred, connection, credential_provider: StreamingCredentialProvider):
7584
self._command = command
7685
self._initial_cred = initial_cred
@@ -107,6 +116,7 @@ class AfterPooledConnectionsInstantiationEvent:
107116
"""
108117
Event that will be fired after pooled connection instances was created.
109118
"""
119+
110120
def __init__(
111121
self,
112122
connection_pools: List,
@@ -130,10 +140,40 @@ def credential_provider(self) -> Union[CredentialProvider, None]:
130140
return self._credential_provider
131141

132142

143+
class AfterSingleConnectionInstantiationEvent:
144+
"""
145+
Event that will be fired after single connection instances was created.
146+
147+
:param connection_lock: For sync client thread-lock should be provided, for async asyncio.Lock
148+
"""
149+
def __init__(
150+
self,
151+
connection,
152+
client_type: ClientType,
153+
connection_lock: Union[threading.Lock, asyncio.Lock]
154+
):
155+
self._connection = connection
156+
self._client_type = client_type
157+
self._connection_lock = connection_lock
158+
159+
@property
160+
def connection(self):
161+
return self._connection
162+
163+
@property
164+
def client_type(self) -> ClientType:
165+
return self._client_type
166+
167+
@property
168+
def connection_lock(self) -> Union[threading.Lock, asyncio.Lock]:
169+
return self._connection_lock
170+
171+
133172
class ReAuthBeforeCommandExecutionListener(EventListenerInterface):
134173
"""
135174
Listener that performs re-authentication (if needed) for StreamingCredentialProviders before command execution.
136175
"""
176+
137177
def __init__(self):
138178
self._current_cred = None
139179

@@ -156,6 +196,7 @@ class AsyncReAuthBeforeCommandExecutionListener(AsyncEventListenerInterface):
156196
"""
157197
Async listener that performs re-authentication (if needed) for StreamingCredentialProviders before command execution
158198
"""
199+
159200
def __init__(self):
160201
self._current_cred = None
161202

@@ -176,6 +217,7 @@ class RegisterReAuthForPooledConnections(EventListenerInterface):
176217
Listener that registers a re-authentication callback for pooled connections.
177218
Required by :class:`StreamingCredentialProvider`.
178219
"""
220+
179221
def __init__(self):
180222
self._event = None
181223

@@ -195,3 +237,31 @@ def _re_auth(self, token):
195237
async def _re_auth_async(self, token):
196238
for pool in self._event.connection_pools:
197239
await pool.re_auth_callback(token)
240+
241+
242+
class RegisterReAuthForSingleConnection(EventListenerInterface):
243+
"""
244+
Listener that registers a re-authentication callback for single connection.
245+
Required by :class:`StreamingCredentialProvider`.
246+
"""
247+
def __init__(self):
248+
self._event = None
249+
250+
def listen(self, event: AfterSingleConnectionInstantiationEvent):
251+
if isinstance(event.connection.credential_provider, StreamingCredentialProvider):
252+
self._event = event
253+
254+
if event.client_type == ClientType.SYNC:
255+
event.connection.credential_provider.on_next(self._re_auth)
256+
else:
257+
event.connection.credential_provider.on_next(self._re_auth_async)
258+
259+
def _re_auth(self, token):
260+
with self._event.connection_lock:
261+
self._event.connection.send_command('AUTH', token.try_get('oid'), token.get_value())
262+
self._event.connection.read_response()
263+
264+
async def _re_auth_async(self, token):
265+
async with self._event.connection_lock:
266+
await self._event.connection.send_command('AUTH', token.try_get('oid'), token.get_value())
267+
await self._event.connection.read_response()

0 commit comments

Comments
 (0)