7
7
Authors: Toki Migimatsu
8
8
"""
9
9
10
- import typing
10
+ from typing import Optional , Union
11
11
12
- import redis
13
12
import numpy as np
13
+ import redis
14
14
15
15
16
- class StringStream :
16
+ class InputStringStream :
17
17
def __init__ (self , buffer : bytes ):
18
18
self ._buffer = buffer
19
19
self ._idx = 0
20
20
21
- def getbuffer (self ) -> bytes :
21
+ def peek_remaining (self ) -> bytes :
22
22
return self ._buffer [self ._idx :]
23
23
24
24
def read (self , num_bytes : int ) -> bytes :
@@ -27,13 +27,27 @@ def read(self, num_bytes: int) -> bytes:
27
27
return self ._buffer [idx_prev : self ._idx ]
28
28
29
29
def read_word (self ) -> str :
30
- len_word = self .getbuffer ().index (b" " )
30
+ len_word = self .peek_remaining ().index (b" " )
31
31
word = self .read (len_word )
32
32
self .read (1 ) # Consume space.
33
33
return word .decode ("utf8" )
34
34
35
35
36
- def decode_matlab (s : typing .Union [str , bytes ]) -> np .ndarray :
36
+ class OutputStringStream :
37
+ def __init__ (self , buffer : Optional [list [bytes ]] = None ) -> None :
38
+ self ._buffer = [] if buffer is None else buffer
39
+
40
+ def write (self , b : Union [bytes , str ]) -> None :
41
+ if isinstance (b , str ):
42
+ b = b .encode ("utf8" )
43
+ self ._buffer .append (b )
44
+
45
+ def flush (self ) -> bytes :
46
+ self ._buffer = [b"" .join (self ._buffer )]
47
+ return self ._buffer [0 ]
48
+
49
+
50
+ def decode_matlab (s : Union [str , bytes ]) -> np .ndarray :
37
51
if isinstance (s , bytes ):
38
52
s = s .decode ("utf8" )
39
53
s = s .strip ()
@@ -50,7 +64,7 @@ def encode_matlab(A: np.ndarray) -> str:
50
64
def decode_opencv (b : bytes ) -> np .ndarray :
51
65
import cv2
52
66
53
- ss = StringStream (b )
67
+ ss = InputStringStream (b )
54
68
55
69
mat_type = int (ss .read_word ())
56
70
if mat_type in {
@@ -68,18 +82,17 @@ def decode_opencv(b: bytes) -> np.ndarray:
68
82
cv2 .CV_32FC4 ,
69
83
}:
70
84
size = int (ss .read_word ())
71
- buffer = np .frombuffer (ss .getbuffer (), dtype = np .uint8 )
85
+ buffer = np .frombuffer (ss .peek_remaining (), dtype = np .uint8 )
72
86
img = cv2 .imdecode (buffer , cv2 .IMREAD_UNCHANGED )
73
87
else :
74
- rows = int (ss .read_word ())
75
- cols = int (ss .read_word ())
76
- buffer = np .frombuffer (ss .getbuffer (), dtype = np .uint8 )
77
- img = buffer .reshape ((rows , cols ))
88
+ raise ValueError (f"Unsupported image type { mat_type } ." )
78
89
79
90
return img
80
91
92
+
81
93
def encode_opencv (img : np .ndarray ) -> bytes :
82
94
import cv2
95
+
83
96
def np_to_cv_type (img : np .ndarray ):
84
97
if img .dtype == np .uint8 :
85
98
if len (img .shape ) == 2 or img .shape [2 ] == 1 :
@@ -108,68 +121,157 @@ def np_to_cv_type(img: np.ndarray):
108
121
return cv2 .CV_32FC3
109
122
elif img .shape [2 ] == 4 :
110
123
return cv2 .CV_32FC4
111
- raise ArgumentError ("Unsupported image type {img.dtype}, {img.shape[2]} channels" )
124
+ raise ValueError (
125
+ f"Unsupported image type { img .dtype } , { img .shape [2 ] if len (img .shape ) > 2 else 1 } channels"
126
+ )
112
127
113
- buffer = []
114
128
type_img = np_to_cv_type (img )
115
- buffer .append (f"{ type_img } " .encode ("utf8" ))
116
129
117
130
if img .dtype in (np .uint8 , np .uint16 ):
118
- _ , png = cv2 .imencode (".png" , img )
119
- buffer .append (f"{ len (png )} " .encode ("utf8" ))
120
- buffer .append (png .tobytes ())
131
+ _ , data = cv2 .imencode (".png" , img )
121
132
elif img .dtype == np .float32 :
122
- _ , exr = cv2 .imencode (".exr" , img )
123
- buffer .append (f"{ len (exr )} " .encode ("utf8" ))
124
- buffer .append (exr .tobytes ())
125
- else :
126
- buffer .append (f"{ img .shape [0 ]} { img .shape [1 ]} " .encode ("utf8" ))
127
- buffer .append (img .tobytes ())
133
+ _ , data = cv2 .imencode (".exr" , img )
134
+
135
+ ss = OutputStringStream ()
136
+ ss .write (f"{ type_img } { len (data )} " )
137
+ ss .write (data .tobytes ())
138
+
139
+ return ss .flush ()
140
+
141
+
142
+ def decode_tensor (b : bytes ) -> np .ndarray :
143
+ ss = InputStringStream (b )
128
144
129
- return b"" .join (buffer )
145
+ # Parse shape opening delimiter.
146
+ w = ss .read_word ()
147
+ if w != "(" :
148
+ raise ValueError (f"Expected '(' at index 0 but found { w } instead." )
149
+
150
+ # Parse shape.
151
+ shape = []
152
+ while True :
153
+ w = ss .read_word ()
154
+ if w == ")" :
155
+ break
156
+ shape .append (int (w ))
157
+
158
+ # Parse dtype
159
+ dtype = np .dtype (ss .read_word ())
160
+
161
+ # Parse data.
162
+ tensor = np .frombuffer (ss .peek_remaining (), dtype = dtype )
163
+ tensor = tensor .reshape (shape )
164
+
165
+ return tensor
166
+
167
+
168
+ def encode_tensor (tensor : np .ndarray ) -> bytes :
169
+ ss = OutputStringStream ()
170
+ shape = " " .join (map (str , tensor .shape ))
171
+ dtype = str (tensor .dtype )
172
+ ss .write (f"( { shape } ) { dtype } " )
173
+ ss .write (tensor .tobytes ())
174
+ return ss .flush ()
130
175
131
176
132
177
class RedisClient (redis .Redis ):
133
178
def __init__ (
134
179
self ,
135
180
host : str = "127.0.0.1" ,
136
181
port : int = 6379 ,
137
- password : typing . Optional [str ] = None ,
138
- ):
182
+ password : Optional [str ] = None ,
183
+ ) -> None :
139
184
super ().__init__ (host = host , port = port , password = password )
140
185
141
- def pipeline (self , transaction = True , shard_hint = None ):
186
+ def pipeline (self , transaction : bool = True , shard_hint = None ) -> "Pipeline" :
142
187
return Pipeline (
143
188
self .connection_pool , self .response_callbacks , transaction , shard_hint
144
189
)
145
190
191
+ def get (self , key : str , decode : Optional [str ] = None ) -> str :
192
+ val = super ().get (key )
193
+ if decode is not None :
194
+ return val .decode ("utf8" )
195
+ return val
196
+
146
197
def get_image (self , key : str ) -> np .ndarray :
147
198
"""Gets a cv::Mat image from Redis."""
148
- val = self .get (key )
149
- return decode_opencv (val )
199
+ b_val = super () .get (key )
200
+ return decode_opencv (b_val )
150
201
151
- def set_image (self , key : str , val : np .ndarray ):
202
+ def set_image (self , key : str , val : np .ndarray ) -> bool :
152
203
"""Sets a cv::Mat in Redis."""
153
- self .set (key , encode_opencv (val ))
204
+ return self .set (key , encode_opencv (val ))
154
205
155
206
def get_matrix (self , key : str ) -> np .ndarray :
156
207
"""Gets an Eigen::Matrix or Eigen::Vector from Redis."""
157
- val = self .get (key ). decode ( "utf8" )
158
- return decode_matlab (val )
208
+ b_val = self .get (key )
209
+ return decode_matlab (b_val )
159
210
160
- def set_matrix (self , key : str , val : np .ndarray ):
211
+ def set_matrix (self , key : str , val : np .ndarray ) -> bool :
161
212
"""Sets an Eigen::Matrix or Eigen::Vector in Redis."""
162
- self .set (key , encode_matlab (val ))
213
+ return self .set (key , encode_matlab (val ))
214
+
215
+ def get_tensor (self , key : str ) -> np .ndarray :
216
+ """Gets a np.ndarray from Redis."""
217
+ b_val = super ().get (key )
218
+ return decode_tensor (b_val )
219
+
220
+ def set_tensor (self , key : str , val : np .ndarray ) -> bool :
221
+ """Sets a np.ndarray in Redis."""
222
+ return self .set (key , encode_tensor (val ))
163
223
164
224
165
225
class Pipeline (redis .client .Pipeline ):
166
226
def __init__ (self , connection_pool , response_callbacks , transaction , shard_hint ):
167
227
super ().__init__ (connection_pool , response_callbacks , transaction , shard_hint )
228
+ self ._decode_fns = []
229
+
230
+ def get (self , key : str , decode : Optional [str ] = None ) -> "Pipeline" :
231
+ super ().get (key )
232
+ self ._decode_fns .append (None if decode is None else lambda b : b .decode (decode ))
233
+ return self
234
+
235
+ def set (self , key : str , val ) -> "Pipeline" :
236
+ super ().set (key , val )
237
+ self ._decode_fns .append (None )
238
+ return self
168
239
169
- def set_image (self , key : str , val : np .ndarray ):
240
+ def get_image (self , key : str ) -> "Pipeline" :
241
+ """Gets a cv::Mat from Redis."""
242
+ super ().get (key )
243
+ self ._decode_fns .append (decode_opencv )
244
+ return self
245
+
246
+ def set_image (self , key : str , val : np .ndarray ) -> "Pipeline" :
170
247
"""Sets a cv::Mat in Redis."""
171
- self .set (key , encode_opencv (val ))
248
+ return self .set (key , encode_opencv (val ))
249
+
250
+ def get_matrix (self , key : str ) -> "Pipeline" :
251
+ """Gets an Eigen::Matrix or Eigen::Vector from Redis."""
252
+ super ().get (key )
253
+ self ._decode_fns .append (decode_matlab )
254
+ return self
172
255
173
- def set_matrix (self , key : str , val : np .ndarray ):
256
+ def set_matrix (self , key : str , val : np .ndarray ) -> "Pipeline" :
174
257
"""Sets an Eigen::Matrix or Eigen::Vector in Redis."""
175
- self .set (key , encode_matlab (val ))
258
+ return self .set (key , encode_matlab (val ))
259
+
260
+ def get_tensor (self , key : str ) -> "Pipeline" :
261
+ """Gets a tensor from Redis."""
262
+ super ().get (key )
263
+ self ._decode_fns .append (decode_tensor )
264
+ return self
265
+
266
+ def set_tensor (self , key : str , val : np .ndarray ) -> "Pipeline" :
267
+ """Sets a tensor in Redis."""
268
+ return self .set (key , encode_tensor (val ))
269
+
270
+ def execute (self ) -> list :
271
+ responses = super ().execute ()
272
+ decoded_responses = [
273
+ decode_fn (response ) if decode_fn is not None else response
274
+ for response , decode_fn in zip (responses , self ._decode_fns )
275
+ ]
276
+ self ._decode_fns = []
277
+ return decoded_responses
0 commit comments