@@ -16,6 +16,7 @@ limitations under the License.
16
16
17
17
#include < algorithm>
18
18
#include < array>
19
+ #include < cassert>
19
20
#include < cmath>
20
21
#include < cstdint>
21
22
#include < cstring>
@@ -28,17 +29,16 @@ limitations under the License.
28
29
#include " tensorflow/core/framework/tensor.h"
29
30
#include " tensorflow/core/framework/tensor_shape.h"
30
31
#include " tensorflow/core/framework/tensor_types.h"
31
- #include " tensorflow/core/lib/core/errors.h"
32
- #include " tensorflow/core/lib/core/status.h"
33
32
#include " tensorflow/core/platform/logging.h"
34
- #include " tensorflow/core/platform/macros .h"
33
+ #include " tensorflow/core/platform/status .h"
35
34
#include " tensorflow/core/platform/types.h"
36
35
#include " tensorflow_compression/cc/lib/bit_coder.h"
37
36
38
37
namespace tensorflow_compression {
39
38
namespace {
40
39
namespace errors = tensorflow::errors;
41
40
using tensorflow::DEVICE_CPU;
41
+ using tensorflow::FromAbslStatus;
42
42
using tensorflow::OpKernel;
43
43
using tensorflow::OpKernelConstruction;
44
44
using tensorflow::OpKernelContext;
@@ -69,36 +69,38 @@ class RunLengthGammaEncodeOp : public OpKernel {
69
69
// any zeros were present in the input array, then the encoded size would be
70
70
// strictly smaller by kMaxGammaBits and bigger by the difference in
71
71
// encoding (the existing zero run length + 1).
72
- BitWriter enc;
73
- enc.Allocate (data.size () * (2 + enc.kMaxGammaBits ));
72
+ BitWriter enc (data.size () * (2 + enc.kMaxGammaBits ));
74
73
// Save number of zeros + 1 preceding next non-zero element.
75
74
uint32_t zero_ct = 1 ;
76
75
77
76
// Iterate through data tensor.
78
- for (size_t i = 0 ; i < data.size (); i++) {
77
+ for (int64_t i = 0 ; i < data.size (); i++) {
78
+ int32_t sample = data (i);
79
79
// Increment zero count.
80
- if (data (i) == 0 ) {
80
+ if (sample == 0 ) {
81
81
zero_ct += 1 ;
82
82
} else {
83
83
// Encode run length of zeros.
84
84
enc.WriteGamma (zero_ct);
85
85
// Encode sign of value.
86
- enc.WriteOneBit (data (i) > 0 );
86
+ enc.WriteOneBit (sample > 0 );
87
87
// Encode magnitude of value.
88
- DCHECK_NE (data (i), std::numeric_limits<int32_t >::min ());
89
- enc.WriteGamma (std::abs (data (i)));
90
- // Reset zero count (1 because Gamma cannot encode 0).
88
+ if (sample == std::numeric_limits<int32_t >::min ()) {
89
+ // We can't encode int32 minimum. Encode closest value instead.
90
+ sample += 1 ;
91
+ }
92
+ enc.WriteGamma (std::abs (sample));
93
+ // Reset zero count (1 because gamma cannot encode 0).
91
94
zero_ct = 1 ;
92
95
}
93
96
}
94
97
if (zero_ct > 1 ) {
95
98
enc.WriteGamma (zero_ct);
96
99
}
97
100
98
- // Pad any remaining bits in last byte with 0.
99
- enc.ZeroPadToByte ();
100
101
// Write encoded bitstring to code.
101
- code->assign (enc.GetData (), enc.GetBytesWritten ());
102
+ auto encoded = enc.GetData ();
103
+ code->assign (encoded.data (), encoded.size ());
102
104
}
103
105
};
104
106
@@ -137,32 +139,33 @@ class RunLengthGammaDecodeOp : public OpKernel {
137
139
// Fill data tensor with zeros.
138
140
std::memset (data.data (), 0 , data.size () * sizeof (data (0 )));
139
141
140
- for (size_t i = 0 ; i < data.size (); i++) {
142
+ for (int64_t i = 0 ; i < data.size (); i++) {
141
143
// Get number of zeros.
142
- uint32_t num_zeros = dec.ReadGamma ();
144
+ auto num_zeros = dec.ReadGamma ();
145
+ OP_REQUIRES (context, num_zeros.ok (), FromAbslStatus (num_zeros.status ()));
146
+
143
147
// Advance the index to the next non-zero element.
144
- i += num_zeros - 1 ;
148
+ i += * num_zeros - 1 ;
145
149
146
150
// Account for case where the last element is zero.
147
- if (i == data.size ()) {
151
+ // Check if past the last element.
152
+ if (i >= data.size ()) {
153
+ OP_REQUIRES (context, i == data.size (),
154
+ errors::DataLoss (" Decoded past end of tensor." ));
148
155
break ;
149
156
}
150
- // TODO(nicolemitchell): return error status instead of crashing
151
- DCHECK_LT (i, data.size ());
152
157
153
158
// Get sign of value.
154
- uint32_t positive = dec.ReadOneBit ();
159
+ auto positive = dec.ReadOneBit ();
160
+ OP_REQUIRES (context, positive.ok (), FromAbslStatus (positive.status ()));
155
161
156
- // Get value.
157
- uint32_t value = dec.ReadGamma ();
162
+ // Get magnitude.
163
+ auto magnitude = dec.ReadGamma ();
164
+ OP_REQUIRES (context, magnitude.ok (), FromAbslStatus (magnitude.status ()));
158
165
159
166
// Write value to data tensor element at index.
160
- DCHECK_LE (value, std::numeric_limits<int32_t >::max ());
161
- data (i) = positive ? value : -static_cast <int32_t >(value);
167
+ data (i) = *positive ? *magnitude : -*magnitude;
162
168
}
163
-
164
- OP_REQUIRES (context, dec.Close ().ok (),
165
- tensorflow::errors::DataLoss (" Decoding error." ));
166
169
}
167
170
};
168
171
0 commit comments