Skip to content

Commit 440bf5b

Browse files
tmigimatsutokifig
authored andcommitted
feat: encode tensors in Redis
1 parent b3918ef commit 440bf5b

File tree

1 file changed

+142
-40
lines changed

1 file changed

+142
-40
lines changed

ctrlutils/redis.py

Lines changed: 142 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,18 @@
77
Authors: Toki Migimatsu
88
"""
99

10-
import typing
10+
from typing import Optional, Union
1111

12-
import redis
1312
import numpy as np
13+
import redis
1414

1515

16-
class StringStream:
16+
class InputStringStream:
1717
def __init__(self, buffer: bytes):
1818
self._buffer = buffer
1919
self._idx = 0
2020

21-
def getbuffer(self) -> bytes:
21+
def peek_remaining(self) -> bytes:
2222
return self._buffer[self._idx :]
2323

2424
def read(self, num_bytes: int) -> bytes:
@@ -27,13 +27,27 @@ def read(self, num_bytes: int) -> bytes:
2727
return self._buffer[idx_prev : self._idx]
2828

2929
def read_word(self) -> str:
30-
len_word = self.getbuffer().index(b" ")
30+
len_word = self.peek_remaining().index(b" ")
3131
word = self.read(len_word)
3232
self.read(1) # Consume space.
3333
return word.decode("utf8")
3434

3535

36-
def decode_matlab(s: typing.Union[str, bytes]) -> np.ndarray:
36+
class OutputStringStream:
37+
def __init__(self, buffer: Optional[list[bytes]] = None) -> None:
38+
self._buffer = [] if buffer is None else buffer
39+
40+
def write(self, b: Union[bytes, str]) -> None:
41+
if isinstance(b, str):
42+
b = b.encode("utf8")
43+
self._buffer.append(b)
44+
45+
def flush(self) -> bytes:
46+
self._buffer = [b"".join(self._buffer)]
47+
return self._buffer[0]
48+
49+
50+
def decode_matlab(s: Union[str, bytes]) -> np.ndarray:
3751
if isinstance(s, bytes):
3852
s = s.decode("utf8")
3953
s = s.strip()
@@ -50,7 +64,7 @@ def encode_matlab(A: np.ndarray) -> str:
5064
def decode_opencv(b: bytes) -> np.ndarray:
5165
import cv2
5266

53-
ss = StringStream(b)
67+
ss = InputStringStream(b)
5468

5569
mat_type = int(ss.read_word())
5670
if mat_type in {
@@ -68,18 +82,17 @@ def decode_opencv(b: bytes) -> np.ndarray:
6882
cv2.CV_32FC4,
6983
}:
7084
size = int(ss.read_word())
71-
buffer = np.frombuffer(ss.getbuffer(), dtype=np.uint8)
85+
buffer = np.frombuffer(ss.peek_remaining(), dtype=np.uint8)
7286
img = cv2.imdecode(buffer, cv2.IMREAD_UNCHANGED)
7387
else:
74-
rows = int(ss.read_word())
75-
cols = int(ss.read_word())
76-
buffer = np.frombuffer(ss.getbuffer(), dtype=np.uint8)
77-
img = buffer.reshape((rows, cols))
88+
raise ValueError(f"Unsupported image type {mat_type}.")
7889

7990
return img
8091

92+
8193
def encode_opencv(img: np.ndarray) -> bytes:
8294
import cv2
95+
8396
def np_to_cv_type(img: np.ndarray):
8497
if img.dtype == np.uint8:
8598
if len(img.shape) == 2 or img.shape[2] == 1:
@@ -108,68 +121,157 @@ def np_to_cv_type(img: np.ndarray):
108121
return cv2.CV_32FC3
109122
elif img.shape[2] == 4:
110123
return cv2.CV_32FC4
111-
raise ArgumentError("Unsupported image type {img.dtype}, {img.shape[2]} channels")
124+
raise ValueError(
125+
f"Unsupported image type {img.dtype}, {img.shape[2] if len(img.shape) > 2 else 1} channels"
126+
)
112127

113-
buffer = []
114128
type_img = np_to_cv_type(img)
115-
buffer.append(f"{type_img} ".encode("utf8"))
116129

117130
if img.dtype in (np.uint8, np.uint16):
118-
_, png = cv2.imencode(".png", img)
119-
buffer.append(f"{len(png)} ".encode("utf8"))
120-
buffer.append(png.tobytes())
131+
_, data = cv2.imencode(".png", img)
121132
elif img.dtype == np.float32:
122-
_, exr = cv2.imencode(".exr", img)
123-
buffer.append(f"{len(exr)} ".encode("utf8"))
124-
buffer.append(exr.tobytes())
125-
else:
126-
buffer.append(f"{img.shape[0]} {img.shape[1]} ".encode("utf8"))
127-
buffer.append(img.tobytes())
133+
_, data = cv2.imencode(".exr", img)
134+
135+
ss = OutputStringStream()
136+
ss.write(f"{type_img} {len(data)} ")
137+
ss.write(data.tobytes())
138+
139+
return ss.flush()
140+
141+
142+
def decode_tensor(b: bytes) -> np.ndarray:
143+
ss = InputStringStream(b)
128144

129-
return b"".join(buffer)
145+
# Parse shape opening delimiter.
146+
w = ss.read_word()
147+
if w != "(":
148+
raise ValueError(f"Expected '(' at index 0 but found {w} instead.")
149+
150+
# Parse shape.
151+
shape = []
152+
while True:
153+
w = ss.read_word()
154+
if w == ")":
155+
break
156+
shape.append(int(w))
157+
158+
# Parse dtype
159+
dtype = np.dtype(ss.read_word())
160+
161+
# Parse data.
162+
tensor = np.frombuffer(ss.peek_remaining(), dtype=dtype)
163+
tensor = tensor.reshape(shape)
164+
165+
return tensor
166+
167+
168+
def encode_tensor(tensor: np.ndarray) -> bytes:
169+
ss = OutputStringStream()
170+
shape = " ".join(map(str, tensor.shape))
171+
dtype = str(tensor.dtype)
172+
ss.write(f"( {shape} ) {dtype} ")
173+
ss.write(tensor.tobytes())
174+
return ss.flush()
130175

131176

132177
class RedisClient(redis.Redis):
133178
def __init__(
134179
self,
135180
host: str = "127.0.0.1",
136181
port: int = 6379,
137-
password: typing.Optional[str] = None,
138-
):
182+
password: Optional[str] = None,
183+
) -> None:
139184
super().__init__(host=host, port=port, password=password)
140185

141-
def pipeline(self, transaction=True, shard_hint=None):
186+
def pipeline(self, transaction: bool = True, shard_hint=None) -> "Pipeline":
142187
return Pipeline(
143188
self.connection_pool, self.response_callbacks, transaction, shard_hint
144189
)
145190

191+
def get(self, key: str, decode: Optional[str] = None) -> str:
192+
val = super().get(key)
193+
if decode is not None:
194+
return val.decode("utf8")
195+
return val
196+
146197
def get_image(self, key: str) -> np.ndarray:
147198
"""Gets a cv::Mat image from Redis."""
148-
val = self.get(key)
149-
return decode_opencv(val)
199+
b_val = super().get(key)
200+
return decode_opencv(b_val)
150201

151-
def set_image(self, key: str, val: np.ndarray):
202+
def set_image(self, key: str, val: np.ndarray) -> bool:
152203
"""Sets a cv::Mat in Redis."""
153-
self.set(key, encode_opencv(val))
204+
return self.set(key, encode_opencv(val))
154205

155206
def get_matrix(self, key: str) -> np.ndarray:
156207
"""Gets an Eigen::Matrix or Eigen::Vector from Redis."""
157-
val = self.get(key).decode("utf8")
158-
return decode_matlab(val)
208+
b_val = self.get(key)
209+
return decode_matlab(b_val)
159210

160-
def set_matrix(self, key: str, val: np.ndarray):
211+
def set_matrix(self, key: str, val: np.ndarray) -> bool:
161212
"""Sets an Eigen::Matrix or Eigen::Vector in Redis."""
162-
self.set(key, encode_matlab(val))
213+
return self.set(key, encode_matlab(val))
214+
215+
def get_tensor(self, key: str) -> np.ndarray:
216+
"""Gets a np.ndarray from Redis."""
217+
b_val = super().get(key)
218+
return decode_tensor(b_val)
219+
220+
def set_tensor(self, key: str, val: np.ndarray) -> bool:
221+
"""Sets a np.ndarray in Redis."""
222+
return self.set(key, encode_tensor(val))
163223

164224

165225
class Pipeline(redis.client.Pipeline):
166226
def __init__(self, connection_pool, response_callbacks, transaction, shard_hint):
167227
super().__init__(connection_pool, response_callbacks, transaction, shard_hint)
228+
self._decode_fns = []
229+
230+
def get(self, key: str, decode: Optional[str] = None) -> "Pipeline":
231+
super().get(key)
232+
self._decode_fns.append(None if decode is None else lambda b: b.decode(decode))
233+
return self
234+
235+
def set(self, key: str, val) -> "Pipeline":
236+
super().set(key, val)
237+
self._decode_fns.append(None)
238+
return self
168239

169-
def set_image(self, key: str, val: np.ndarray):
240+
def get_image(self, key: str) -> "Pipeline":
241+
"""Gets a cv::Mat from Redis."""
242+
super().get(key)
243+
self._decode_fns.append(decode_opencv)
244+
return self
245+
246+
def set_image(self, key: str, val: np.ndarray) -> "Pipeline":
170247
"""Sets a cv::Mat in Redis."""
171-
self.set(key, encode_opencv(val))
248+
return self.set(key, encode_opencv(val))
249+
250+
def get_matrix(self, key: str) -> "Pipeline":
251+
"""Gets an Eigen::Matrix or Eigen::Vector from Redis."""
252+
super().get(key)
253+
self._decode_fns.append(decode_matlab)
254+
return self
172255

173-
def set_matrix(self, key: str, val: np.ndarray):
256+
def set_matrix(self, key: str, val: np.ndarray) -> "Pipeline":
174257
"""Sets an Eigen::Matrix or Eigen::Vector in Redis."""
175-
self.set(key, encode_matlab(val))
258+
return self.set(key, encode_matlab(val))
259+
260+
def get_tensor(self, key: str) -> "Pipeline":
261+
"""Gets a tensor from Redis."""
262+
super().get(key)
263+
self._decode_fns.append(decode_tensor)
264+
return self
265+
266+
def set_tensor(self, key: str, val: np.ndarray) -> "Pipeline":
267+
"""Sets a tensor in Redis."""
268+
return self.set(key, encode_tensor(val))
269+
270+
def execute(self) -> list:
271+
responses = super().execute()
272+
decoded_responses = [
273+
decode_fn(response) if decode_fn is not None else response
274+
for response, decode_fn in zip(responses, self._decode_fns)
275+
]
276+
self._decode_fns = []
277+
return decoded_responses

0 commit comments

Comments
 (0)