@@ -760,59 +760,56 @@ def check_health(self):
760
760
self ._conn .check_health ()
761
761
762
762
def send_packed_command (self , command , check_health = True ):
763
- cache_key = hashkey (command )
764
-
765
- if self ._cache .get (cache_key ):
766
- self ._current_command_hash = cache_key
767
- return
768
-
769
- self ._current_command_hash = None
763
+ self ._process_pending_invalidations ()
764
+ # TODO: Investigate if it's possible to unpack command or extract keys from packed command
770
765
self ._conn .send_packed_command (command )
771
766
772
767
def send_command (self , * args , ** kwargs ):
768
+ self ._process_pending_invalidations ()
769
+
770
+ # If command is write command or not allowed to cache skip it.
773
771
if not self ._conf .is_allowed_to_cache (args [0 ]):
774
772
self ._current_command_hash = None
775
773
self ._current_command_keys = None
776
774
self ._conn .send_command (* args , ** kwargs )
777
775
return
778
776
777
+ # Create hash representation of current executed command.
779
778
self ._current_command_hash = hashkey (* args )
780
779
780
+ # Extract keys from current command.
781
781
if kwargs .get ("keys" ):
782
782
self ._current_command_keys = kwargs ["keys" ]
783
783
784
784
if not isinstance (self ._current_command_keys , list ):
785
785
raise TypeError ("Cache keys must be a list." )
786
786
787
+ # If current command reply already cached prevent sending data over socket.
787
788
if self ._cache .get (self ._current_command_hash ):
788
789
return
789
790
791
+ # Send command over socket only if it's read-only command that not yet cached.
790
792
self ._conn .send_command (* args , ** kwargs )
791
793
792
794
def can_read (self , timeout = 0 ):
793
795
return self ._conn .can_read (timeout )
794
796
795
797
def read_response (self , disable_decoding = False , * , disconnect_on_error = True , push_request = False ):
798
+ # Check if command response exists in a cache.
799
+ if self ._current_command_hash in self ._cache :
800
+ return self ._cache [self ._current_command_hash ]
801
+
796
802
response = self ._conn .read_response (
797
803
disable_decoding = disable_decoding ,
798
804
disconnect_on_error = disconnect_on_error ,
799
805
push_request = push_request
800
806
)
801
807
802
- if isinstance (response , List ) and len (response ) > 0 and response [0 ] == 'invalidate' :
803
- self ._on_invalidation_callback (response )
804
- self .read_response (
805
- disable_decoding = disable_decoding ,
806
- disconnect_on_error = disconnect_on_error ,
807
- push_request = push_request
808
- )
809
-
808
+ # Check if command that was sent is write command to prevent caching of write replies.
810
809
if response is None or self ._current_command_hash is None :
811
810
return response
812
811
813
- if self ._current_command_hash in self ._cache :
814
- return self ._cache [self ._current_command_hash ]
815
-
812
+ # Create separate mapping for keys or add current response to associated keys.
816
813
for key in self ._current_command_keys :
817
814
if key in self ._keys_mapping :
818
815
if self ._current_command_hash not in self ._keys_mapping [key ]:
@@ -824,10 +821,10 @@ def read_response(self, disable_decoding=False, *, disconnect_on_error=True, pus
824
821
return response
825
822
826
823
def pack_command (self , * args ):
827
- pass
824
+ return self . _conn . pack_command ( * args )
828
825
829
826
def pack_commands (self , commands ):
830
- pass
827
+ return self . _conn . pack_commands ( commands )
831
828
832
829
def _connect (self ):
833
830
self ._conn ._connect ()
@@ -838,24 +835,27 @@ def _host_error(self):
838
835
def _enable_tracking_callback (self , conn : ConnectionInterface ) -> None :
839
836
conn .send_command ('CLIENT' , 'TRACKING' , 'ON' )
840
837
conn .read_response ()
838
+ conn ._parser .set_invalidation_push_handler (self ._on_invalidation_callback )
841
839
842
840
def _process_pending_invalidations (self ):
843
- print (f'connection { self } { id (self )} process invalidations' )
844
841
while self .can_read ():
845
- self .read_response (push_request = True )
842
+ self ._conn . read_response (push_request = True )
846
843
847
844
def _on_invalidation_callback (
848
845
self , data : List [Union [str , Optional [List [str ]]]]
849
846
):
847
+ # Flush cache when DB flushed on server-side
850
848
if data [1 ] is None :
851
849
self ._cache .clear ()
852
850
else :
853
851
for key in data [1 ]:
854
852
normalized_key = ensure_string (key )
855
853
if normalized_key in self ._keys_mapping :
854
+ # Make sure that all command responses associated with this key will be deleted
856
855
for cache_key in self ._keys_mapping [normalized_key ]:
857
856
self ._cache .pop (cache_key )
858
-
857
+ # Removes key from mapping cache
858
+ self ._keys_mapping .pop (normalized_key )
859
859
860
860
861
861
class SSLConnection (Connection ):
@@ -1235,7 +1235,6 @@ def __init__(
1235
1235
connection_kwargs .pop ("cache_ttl" , None )
1236
1236
connection_kwargs .pop ("cache" , None )
1237
1237
1238
-
1239
1238
# a lock to protect the critical section in _checkpid().
1240
1239
# this lock is acquired when the process id changes, such as
1241
1240
# after a fork. during this time, multiple threads in the child
@@ -1343,7 +1342,7 @@ def get_connection(self, command_name: str, *keys, **options) -> "Connection":
1343
1342
# pool before all data has been read or the socket has been
1344
1343
# closed. either way, reconnect and verify everything is good.
1345
1344
try :
1346
- if connection .can_read ():
1345
+ if connection .can_read () and self . _cache is None :
1347
1346
raise ConnectionError ("Connection has data" )
1348
1347
except (ConnectionError , OSError ):
1349
1348
connection .disconnect ()
0 commit comments