Skip to content

Commit 23c92ff

Browse files
tmigimatsutokifig
authored andcommitted
fix: pack bool tensors
1 parent d198e0d commit 23c92ff

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

ctrlutils/redis.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,11 @@ def decode_tensor(b: bytes) -> np.ndarray:
159159
dtype = np.dtype(ss.read_word())
160160

161161
# Parse data.
162-
tensor = np.frombuffer(ss.peek_remaining(), dtype=dtype)
162+
if dtype == "bool":
163+
tensor = np.frombuffer(ss.peek_remaining(), dtype=np.uint8)
164+
tensor = np.unpackbits(tensor, count=np.prod(shape)).astype(bool)
165+
else:
166+
tensor = np.frombuffer(ss.peek_remaining(), dtype=dtype)
163167
tensor = tensor.reshape(shape)
164168

165169
return tensor
@@ -170,6 +174,8 @@ def encode_tensor(tensor: np.ndarray) -> bytes:
170174
shape = " ".join(map(str, tensor.shape))
171175
dtype = str(tensor.dtype)
172176
ss.write(f"( {shape} ) {dtype} ")
177+
if dtype == "bool":
178+
tensor = np.packbits(tensor)
173179
ss.write(tensor.tobytes())
174180
return ss.flush()
175181

0 commit comments

Comments
 (0)