File tree Expand file tree Collapse file tree 1 file changed +7
-1
lines changed Expand file tree Collapse file tree 1 file changed +7
-1
lines changed Original file line number Diff line number Diff line change @@ -159,7 +159,11 @@ def decode_tensor(b: bytes) -> np.ndarray:
159
159
dtype = np .dtype (ss .read_word ())
160
160
161
161
# Parse data.
162
- tensor = np .frombuffer (ss .peek_remaining (), dtype = dtype )
162
+ if dtype == "bool" :
163
+ tensor = np .frombuffer (ss .peek_remaining (), dtype = np .uint8 )
164
+ tensor = np .unpackbits (tensor , count = np .prod (shape )).astype (bool )
165
+ else :
166
+ tensor = np .frombuffer (ss .peek_remaining (), dtype = dtype )
163
167
tensor = tensor .reshape (shape )
164
168
165
169
return tensor
@@ -170,6 +174,8 @@ def encode_tensor(tensor: np.ndarray) -> bytes:
170
174
shape = " " .join (map (str , tensor .shape ))
171
175
dtype = str (tensor .dtype )
172
176
ss .write (f"( { shape } ) { dtype } " )
177
+ if dtype == "bool" :
178
+ tensor = np .packbits (tensor )
173
179
ss .write (tensor .tobytes ())
174
180
return ss .flush ()
175
181
You can’t perform that action at this time.
0 commit comments