11from  abc  import  ABC , abstractmethod 
2+ from  enum  import  Enum 
23from  typing  import  List , Union , Optional 
34
45from  redis .credentials  import  StreamingCredentialProvider , CredentialProvider 
@@ -13,6 +14,15 @@ def listen(self, event: object):
1314        pass 
1415
1516
17+ class  AsyncEventListenerInterface (ABC ):
18+     """ 
19+     Represents an async listener for given event object. 
20+     """ 
21+     @abstractmethod  
22+     async  def  listen (self , event : object ):
23+         pass 
24+ 
25+ 
1626class  EventDispatcherInterface (ABC ):
1727    """ 
1828    Represents a dispatcher that dispatches events to listeners associated with given event. 
@@ -21,6 +31,10 @@ class EventDispatcherInterface(ABC):
2131    def  dispatch (self , event : object ):
2232        pass 
2333
34+     @abstractmethod  
35+     async  def  dispatch_async (self , event : object ):
36+         pass 
37+ 
2438
2539class  EventDispatcher (EventDispatcherInterface ):
2640    # TODO: Make dispatcher to accept external mappings. 
@@ -34,7 +48,10 @@ def __init__(self):
3448            ],
3549            AfterPooledConnectionsInstantiationEvent : [
3650                RegisterReAuthForPooledConnections ()
37-             ]
51+             ],
52+             AsyncBeforeCommandExecutionEvent : [
53+                 AsyncReAuthBeforeCommandExecutionListener (),
54+             ],
3855        }
3956
4057    def  dispatch (self , event : object ):
@@ -43,6 +60,12 @@ def dispatch(self, event: object):
4360        for  listener  in  listeners :
4461            listener .listen (event )
4562
63+     async  def  dispatch_async (self , event : object ):
64+         listeners  =  self ._event_listeners_mapping .get (type (event ))
65+ 
66+         for  listener  in  listeners :
67+             await  listener .listen (event )
68+ 
4669
4770class  BeforeCommandExecutionEvent :
4871    """ 
@@ -71,6 +94,10 @@ def credential_provider(self) -> StreamingCredentialProvider:
7194        return  self ._credential_provider 
7295
7396
97+ class  AsyncBeforeCommandExecutionEvent (BeforeCommandExecutionEvent ):
98+     pass 
99+ 
100+ 
74101class  AfterPooledConnectionsInstantiationEvent :
75102    """ 
76103    Event that will be fired after pooled connection instances was created. 
@@ -111,14 +138,33 @@ def listen(self, event: BeforeCommandExecutionEvent):
111138            event .connection .read_response ()
112139
113140
114- class  RegisterReAuthForPooledConnections (EventListenerInterface ):
141+ class  AsyncReAuthBeforeCommandExecutionListener (AsyncEventListenerInterface ):
142+     """ 
143+     Async listener that performs re-authentication (if needed) for StreamingCredentialProviders before command execution 
144+     """ 
115145    def  __init__ (self ):
116-         self ._event  =  None 
146+         self ._current_cred  =  None 
147+ 
148+     async  def  listen (self , event : AsyncBeforeCommandExecutionEvent ):
149+         if  self ._current_cred  is  None :
150+             self ._current_cred  =  event .initial_cred 
151+ 
152+         credentials  =  await  event .credential_provider .get_credentials_async ()
117153
154+         if  hash (credentials ) !=  self ._current_cred :
155+             self ._current_cred  =  hash (credentials )
156+             await  event .connection .send_command ('AUTH' , credentials [0 ], credentials [1 ])
157+             await  event .connection .read_response ()
158+ 
159+ 
160+ class  RegisterReAuthForPooledConnections (EventListenerInterface ):
118161    """ 
119162    Listener that registers a re-authentication callback for pooled connections. 
120163    Required by :class:`StreamingCredentialProvider`. 
121164    """ 
165+     def  __init__ (self ):
166+         self ._event  =  None 
167+ 
122168    def  listen (self , event : AfterPooledConnectionsInstantiationEvent ):
123169        if  isinstance (event .credential_provider , StreamingCredentialProvider ):
124170            self ._event  =  event 
0 commit comments