|
23 | 23 |
|
24 | 24 | from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser
|
25 | 25 | from .backoff import NoBackoff
|
26 |
| -from .credentials import CredentialProvider, UsernamePasswordCredentialProvider |
| 26 | +from .credentials import CredentialProvider, UsernamePasswordCredentialProvider, StreamingCredentialProvider |
| 27 | +from .event import EventDispatcherInterface, EventDispatcher, BeforeCommandExecutionEvent |
27 | 28 | from .exceptions import (
|
28 | 29 | AuthenticationError,
|
29 | 30 | AuthenticationWrongNumberOfArgsError,
|
@@ -229,6 +230,7 @@ def __init__(
|
229 | 230 | credential_provider: Optional[CredentialProvider] = None,
|
230 | 231 | protocol: Optional[int] = 2,
|
231 | 232 | command_packer: Optional[Callable[[], None]] = None,
|
| 233 | + event_dispatcher: Optional[EventDispatcherInterface] = EventDispatcher() |
232 | 234 | ):
|
233 | 235 | """
|
234 | 236 | Initialize a new Connection.
|
@@ -283,6 +285,8 @@ def __init__(
|
283 | 285 | self.set_parser(parser_class)
|
284 | 286 | self._connect_callbacks = []
|
285 | 287 | self._buffer_cutoff = 6000
|
| 288 | + self._event_dispatcher = event_dispatcher |
| 289 | + self._init_auth_args = None |
286 | 290 | try:
|
287 | 291 | p = int(protocol)
|
288 | 292 | except TypeError:
|
@@ -408,6 +412,7 @@ def on_connect(self):
|
408 | 412 | or UsernamePasswordCredentialProvider(self.username, self.password)
|
409 | 413 | )
|
410 | 414 | auth_args = cred_provider.get_credentials()
|
| 415 | + self._init_auth_args = hash(auth_args) |
411 | 416 |
|
412 | 417 | # if resp version is specified and we have auth args,
|
413 | 418 | # we need to send them via HELLO
|
@@ -553,6 +558,10 @@ def send_packed_command(self, command, check_health=True):
|
553 | 558 |
|
554 | 559 | def send_command(self, *args, **kwargs):
|
555 | 560 | """Pack and send a command to the Redis server"""
|
| 561 | + if isinstance(self.credential_provider, StreamingCredentialProvider): |
| 562 | + self._event_dispatcher.dispatch( |
| 563 | + BeforeCommandExecutionEvent(args, self._init_auth_args, self, self.credential_provider) |
| 564 | + ) |
556 | 565 | self.send_packed_command(
|
557 | 566 | self._command_packer.pack(*args),
|
558 | 567 | check_health=kwargs.get("check_health", True),
|
@@ -1318,6 +1327,12 @@ def __init__(
|
1318 | 1327 | connection_kwargs.pop("cache", None)
|
1319 | 1328 | connection_kwargs.pop("cache_config", None)
|
1320 | 1329 |
|
| 1330 | + cred_provider = connection_kwargs.get("credential_provider") |
| 1331 | + |
| 1332 | + if cred_provider is not None and isinstance(cred_provider, StreamingCredentialProvider): |
| 1333 | + cred_provider.on_next(self._re_auth) |
| 1334 | + |
| 1335 | + |
1321 | 1336 | # a lock to protect the critical section in _checkpid().
|
1322 | 1337 | # this lock is acquired when the process id changes, such as
|
1323 | 1338 | # after a fork. during this time, multiple threads in the child
|
@@ -1517,6 +1532,14 @@ def set_retry(self, retry: "Retry") -> None:
|
1517 | 1532 | for conn in self._in_use_connections:
|
1518 | 1533 | conn.retry = retry
|
1519 | 1534 |
|
| 1535 | + def _re_auth(self, token): |
| 1536 | + with self._lock: |
| 1537 | + for conn in self._available_connections: |
| 1538 | + conn.send_command( |
| 1539 | + 'AUTH', token.try_get('oid'), token.get_value() |
| 1540 | + ) |
| 1541 | + conn.read_response() |
| 1542 | + |
1520 | 1543 |
|
1521 | 1544 | class BlockingConnectionPool(ConnectionPool):
|
1522 | 1545 | """
|
|
0 commit comments