|
23 | 23 |
|
24 | 24 | from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser |
25 | 25 | from .backoff import NoBackoff |
26 | | -from .credentials import CredentialProvider, UsernamePasswordCredentialProvider |
| 26 | +from .credentials import CredentialProvider, UsernamePasswordCredentialProvider, StreamingCredentialProvider |
| 27 | +from .event import EventDispatcherInterface, EventDispatcher, BeforeCommandExecutionEvent |
27 | 28 | from .exceptions import ( |
28 | 29 | AuthenticationError, |
29 | 30 | AuthenticationWrongNumberOfArgsError, |
@@ -229,6 +230,7 @@ def __init__( |
229 | 230 | credential_provider: Optional[CredentialProvider] = None, |
230 | 231 | protocol: Optional[int] = 2, |
231 | 232 | command_packer: Optional[Callable[[], None]] = None, |
| 233 | + event_dispatcher: Optional[EventDispatcherInterface] = EventDispatcher() |
232 | 234 | ): |
233 | 235 | """ |
234 | 236 | Initialize a new Connection. |
@@ -283,6 +285,8 @@ def __init__( |
283 | 285 | self.set_parser(parser_class) |
284 | 286 | self._connect_callbacks = [] |
285 | 287 | self._buffer_cutoff = 6000 |
| 288 | + self._event_dispatcher = event_dispatcher |
| 289 | + self._init_auth_args = None |
286 | 290 | try: |
287 | 291 | p = int(protocol) |
288 | 292 | except TypeError: |
@@ -408,6 +412,7 @@ def on_connect(self): |
408 | 412 | or UsernamePasswordCredentialProvider(self.username, self.password) |
409 | 413 | ) |
410 | 414 | auth_args = cred_provider.get_credentials() |
| 415 | + self._init_auth_args = hash(auth_args) |
411 | 416 |
|
412 | 417 | # if resp version is specified and we have auth args, |
413 | 418 | # we need to send them via HELLO |
@@ -553,6 +558,10 @@ def send_packed_command(self, command, check_health=True): |
553 | 558 |
|
554 | 559 | def send_command(self, *args, **kwargs): |
555 | 560 | """Pack and send a command to the Redis server""" |
| 561 | + if isinstance(self.credential_provider, StreamingCredentialProvider): |
| 562 | + self._event_dispatcher.dispatch( |
| 563 | + BeforeCommandExecutionEvent(args, self._init_auth_args, self, self.credential_provider) |
| 564 | + ) |
556 | 565 | self.send_packed_command( |
557 | 566 | self._command_packer.pack(*args), |
558 | 567 | check_health=kwargs.get("check_health", True), |
@@ -1318,6 +1327,12 @@ def __init__( |
1318 | 1327 | connection_kwargs.pop("cache", None) |
1319 | 1328 | connection_kwargs.pop("cache_config", None) |
1320 | 1329 |
|
| 1330 | + cred_provider = connection_kwargs.get("credential_provider") |
| 1331 | + |
| 1332 | + if cred_provider is not None and isinstance(cred_provider, StreamingCredentialProvider): |
| 1333 | + cred_provider.on_next(self._re_auth) |
| 1334 | + |
| 1335 | + |
1321 | 1336 | # a lock to protect the critical section in _checkpid(). |
1322 | 1337 | # this lock is acquired when the process id changes, such as |
1323 | 1338 | # after a fork. during this time, multiple threads in the child |
@@ -1517,6 +1532,14 @@ def set_retry(self, retry: "Retry") -> None: |
1517 | 1532 | for conn in self._in_use_connections: |
1518 | 1533 | conn.retry = retry |
1519 | 1534 |
|
| 1535 | + def _re_auth(self, token): |
| 1536 | + with self._lock: |
| 1537 | + for conn in self._available_connections: |
| 1538 | + conn.send_command( |
| 1539 | + 'AUTH', token.try_get('oid'), token.get_value() |
| 1540 | + ) |
| 1541 | + conn.read_response() |
| 1542 | + |
1520 | 1543 |
|
1521 | 1544 | class BlockingConnectionPool(ConnectionPool): |
1522 | 1545 | """ |
|
0 commit comments