6
6
from .base import _AsyncRESPBase , _RESPBase
7
7
from .socket import SERVER_CLOSED_CONNECTION_ERROR
8
8
9
+ _INVALIDATION_MESSAGE = [b"invalidate" , "invalidate" ]
10
+
9
11
10
12
class _RESP3Parser (_RESPBase ):
11
13
"""RESP3 protocol implementation"""
12
14
13
15
def __init__ (self , socket_read_size ):
14
16
super ().__init__ (socket_read_size )
15
- self .push_handler_func = self .handle_push_response
17
+ self .pubsub_push_handler_func = self .handle_pubsub_push_response
18
+ self .invalidations_push_handler_func = None
16
19
17
- def handle_push_response (self , response ):
20
+ def handle_pubsub_push_response (self , response ):
18
21
logger = getLogger ("push_response" )
19
22
logger .info ("Push response: " + str (response ))
20
23
return response
@@ -96,8 +99,9 @@ def _read_response(self, disable_decoding=False, push_request=False):
96
99
pass
97
100
# map response
98
101
elif byte == b"%" :
99
- # we use this approach and not dict comprehension here
100
- # because this dict comprehension fails in python 3.7
102
+ # We cannot use a dict-comprehension to parse stream.
103
+ # Evaluation order of key:val expression in dict comprehension only
104
+ # became defined to be left-right in version 3.8
101
105
resp_dict = {}
102
106
for _ in range (int (response )):
103
107
key = self ._read_response (disable_decoding = disable_decoding )
@@ -113,30 +117,40 @@ def _read_response(self, disable_decoding=False, push_request=False):
113
117
)
114
118
for _ in range (int (response ))
115
119
]
116
- res = self .push_handler_func (response )
117
- if not push_request :
118
- return self ._read_response (
119
- disable_decoding = disable_decoding , push_request = push_request
120
- )
121
- else :
122
- return res
120
+ self .handle_push_response (response , disable_decoding , push_request )
123
121
else :
124
122
raise InvalidResponse (f"Protocol Error: { raw !r} " )
125
123
126
124
if isinstance (response , bytes ) and disable_decoding is False :
127
125
response = self .encoder .decode (response )
128
126
return response
129
127
130
- def set_push_handler (self , push_handler_func ):
131
- self .push_handler_func = push_handler_func
128
+ def handle_push_response (self , response , disable_decoding , push_request ):
129
+ if response [0 ] in _INVALIDATION_MESSAGE :
130
+ res = self .invalidation_push_handler_func (response )
131
+ else :
132
+ res = self .pubsub_push_handler_func (response )
133
+ if not push_request :
134
+ return self ._read_response (
135
+ disable_decoding = disable_decoding , push_request = push_request
136
+ )
137
+ else :
138
+ return res
139
+
140
+ def set_pubsub_push_handler (self , pubsub_push_handler_func ):
141
+ self .pubsub_push_handler_func = pubsub_push_handler_func
142
+
143
+ def set_invalidation_push_handler (self , invalidations_push_handler_func ):
144
+ self .invalidation_push_handler_func = invalidations_push_handler_func
132
145
133
146
134
147
class _AsyncRESP3Parser (_AsyncRESPBase ):
135
148
def __init__ (self , socket_read_size ):
136
149
super ().__init__ (socket_read_size )
137
- self .push_handler_func = self .handle_push_response
150
+ self .pubsub_push_handler_func = self .handle_pubsub_push_response
151
+ self .invalidations_push_handler_func = None
138
152
139
- def handle_push_response (self , response ):
153
+ def handle_pubsub_push_response (self , response ):
140
154
logger = getLogger ("push_response" )
141
155
logger .info ("Push response: " + str (response ))
142
156
return response
@@ -225,12 +239,16 @@ async def _read_response(
225
239
pass
226
240
# map response
227
241
elif byte == b"%" :
228
- response = {
229
- (await self ._read_response (disable_decoding = disable_decoding )): (
230
- await self ._read_response (disable_decoding = disable_decoding )
242
+ # We cannot use a dict-comprehension to parse stream.
243
+ # Evaluation order of key:val expression in dict comprehension only
244
+ # became defined to be left-right in version 3.8
245
+ resp_dict = {}
246
+ for _ in range (int (response )):
247
+ key = await self ._read_response (disable_decoding = disable_decoding )
248
+ resp_dict [key ] = await self ._read_response (
249
+ disable_decoding = disable_decoding , push_request = push_request
231
250
)
232
- for _ in range (int (response ))
233
- }
251
+ response = resp_dict
234
252
# push response
235
253
elif byte == b">" :
236
254
response = [
@@ -241,19 +259,28 @@ async def _read_response(
241
259
)
242
260
for _ in range (int (response ))
243
261
]
244
- res = self .push_handler_func (response )
245
- if not push_request :
246
- return await self ._read_response (
247
- disable_decoding = disable_decoding , push_request = push_request
248
- )
249
- else :
250
- return res
262
+ await self .handle_push_response (response , disable_decoding , push_request )
251
263
else :
252
264
raise InvalidResponse (f"Protocol Error: { raw !r} " )
253
265
254
266
if isinstance (response , bytes ) and disable_decoding is False :
255
267
response = self .encoder .decode (response )
256
268
return response
257
269
258
- def set_push_handler (self , push_handler_func ):
259
- self .push_handler_func = push_handler_func
270
+ async def handle_push_response (self , response , disable_decoding , push_request ):
271
+ if response [0 ] in _INVALIDATION_MESSAGE :
272
+ res = self .invalidation_push_handler_func (response )
273
+ else :
274
+ res = self .pubsub_push_handler_func (response )
275
+ if not push_request :
276
+ return await self ._read_response (
277
+ disable_decoding = disable_decoding , push_request = push_request
278
+ )
279
+ else :
280
+ return res
281
+
282
+ def set_pubsub_push_handler (self , pubsub_push_handler_func ):
283
+ self .pubsub_push_handler_func = pubsub_push_handler_func
284
+
285
+ def set_invalidation_push_handler (self , invalidations_push_handler_func ):
286
+ self .invalidation_push_handler_func = invalidations_push_handler_func
0 commit comments