Skip to content

Commit 60b1db8

Browse files
Johannes Ballé authored and copybara-github committed
Simplified error handling for bit_coder and run_length_gamma_kernels.
PiperOrigin-RevId: 470074035 Change-Id: I9452966aef1f56a5b0c7570b24516ee60bd2b64c
1 parent 0689a6b commit 60b1db8

File tree

4 files changed

+134
-165
lines changed

4 files changed

+134
-165
lines changed

tensorflow_compression/cc/kernels/run_length_gamma_kernels.cc

Lines changed: 31 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616

1717
#include <algorithm>
1818
#include <array>
19+
#include <cassert>
1920
#include <cmath>
2021
#include <cstdint>
2122
#include <cstring>
@@ -28,17 +29,16 @@ limitations under the License.
2829
#include "tensorflow/core/framework/tensor.h"
2930
#include "tensorflow/core/framework/tensor_shape.h"
3031
#include "tensorflow/core/framework/tensor_types.h"
31-
#include "tensorflow/core/lib/core/errors.h"
32-
#include "tensorflow/core/lib/core/status.h"
3332
#include "tensorflow/core/platform/logging.h"
34-
#include "tensorflow/core/platform/macros.h"
33+
#include "tensorflow/core/platform/status.h"
3534
#include "tensorflow/core/platform/types.h"
3635
#include "tensorflow_compression/cc/lib/bit_coder.h"
3736

3837
namespace tensorflow_compression {
3938
namespace {
4039
namespace errors = tensorflow::errors;
4140
using tensorflow::DEVICE_CPU;
41+
using tensorflow::FromAbslStatus;
4242
using tensorflow::OpKernel;
4343
using tensorflow::OpKernelConstruction;
4444
using tensorflow::OpKernelContext;
@@ -69,36 +69,38 @@ class RunLengthGammaEncodeOp : public OpKernel {
6969
// any zeros were present in the input array, then the encoded size would be
7070
// strictly smaller by kMaxGammaBits and bigger by the difference in
7171
// encoding (the existing zero run length + 1).
72-
BitWriter enc;
73-
enc.Allocate(data.size() * (2 + enc.kMaxGammaBits));
72+
BitWriter enc(data.size() * (2 + enc.kMaxGammaBits));
7473
// Save number of zeros + 1 preceding next non-zero element.
7574
uint32_t zero_ct = 1;
7675

7776
// Iterate through data tensor.
78-
for (size_t i = 0; i < data.size(); i++) {
77+
for (int64_t i = 0; i < data.size(); i++) {
78+
int32_t sample = data(i);
7979
// Increment zero count.
80-
if (data(i) == 0) {
80+
if (sample == 0) {
8181
zero_ct += 1;
8282
} else {
8383
// Encode run length of zeros.
8484
enc.WriteGamma(zero_ct);
8585
// Encode sign of value.
86-
enc.WriteOneBit(data(i) > 0);
86+
enc.WriteOneBit(sample > 0);
8787
// Encode magnitude of value.
88-
DCHECK_NE(data(i), std::numeric_limits<int32_t>::min());
89-
enc.WriteGamma(std::abs(data(i)));
90-
// Reset zero count (1 because Gamma cannot encode 0).
88+
if (sample == std::numeric_limits<int32_t>::min()) {
89+
// We can't encode int32 minimum. Encode closest value instead.
90+
sample += 1;
91+
}
92+
enc.WriteGamma(std::abs(sample));
93+
// Reset zero count (1 because gamma cannot encode 0).
9194
zero_ct = 1;
9295
}
9396
}
9497
if (zero_ct > 1) {
9598
enc.WriteGamma(zero_ct);
9699
}
97100

98-
// Pad any remaining bits in last byte with 0.
99-
enc.ZeroPadToByte();
100101
// Write encoded bitstring to code.
101-
code->assign(enc.GetData(), enc.GetBytesWritten());
102+
auto encoded = enc.GetData();
103+
code->assign(encoded.data(), encoded.size());
102104
}
103105
};
104106

@@ -137,32 +139,33 @@ class RunLengthGammaDecodeOp : public OpKernel {
137139
// Fill data tensor with zeros.
138140
std::memset(data.data(), 0, data.size() * sizeof(data(0)));
139141

140-
for (size_t i = 0; i < data.size(); i++) {
142+
for (int64_t i = 0; i < data.size(); i++) {
141143
// Get number of zeros.
142-
uint32_t num_zeros = dec.ReadGamma();
144+
auto num_zeros = dec.ReadGamma();
145+
OP_REQUIRES(context, num_zeros.ok(), FromAbslStatus(num_zeros.status()));
146+
143147
// Advance the index to the next non-zero element.
144-
i += num_zeros - 1;
148+
i += *num_zeros - 1;
145149

146150
// Account for case where the last element is zero.
147-
if (i == data.size()) {
151+
// Check if past the last element.
152+
if (i >= data.size()) {
153+
OP_REQUIRES(context, i == data.size(),
154+
errors::DataLoss("Decoded past end of tensor."));
148155
break;
149156
}
150-
// TODO(nicolemitchell): return error status instead of crashing
151-
DCHECK_LT(i, data.size());
152157

153158
// Get sign of value.
154-
uint32_t positive = dec.ReadOneBit();
159+
auto positive = dec.ReadOneBit();
160+
OP_REQUIRES(context, positive.ok(), FromAbslStatus(positive.status()));
155161

156-
// Get value.
157-
uint32_t value = dec.ReadGamma();
162+
// Get magnitude.
163+
auto magnitude = dec.ReadGamma();
164+
OP_REQUIRES(context, magnitude.ok(), FromAbslStatus(magnitude.status()));
158165

159166
// Write value to data tensor element at index.
160-
DCHECK_LE(value, std::numeric_limits<int32_t>::max());
161-
data(i) = positive ? value : -static_cast<int32_t>(value);
167+
data(i) = *positive ? *magnitude : -*magnitude;
162168
}
163-
164-
OP_REQUIRES(context, dec.Close().ok(),
165-
tensorflow::errors::DataLoss("Decoding error."));
166169
}
167170
};
168171

tensorflow_compression/cc/kernels/run_length_gamma_kernels_test.cc

Lines changed: 15 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -35,13 +35,9 @@ limitations under the License.
3535
#include "tensorflow/core/graph/node_builder.h"
3636
#include "tensorflow/core/graph/testlib.h"
3737
#include "tensorflow/core/kernels/ops_testutil.h"
38-
#include "tensorflow/core/lib/core/bits.h"
3938
#include "tensorflow/core/lib/core/status_test_util.h"
40-
#include "tensorflow/core/lib/random/simple_philox.h"
4139
#include "tensorflow/core/platform/stacktrace_handler.h"
4240
#include "tensorflow/core/platform/test.h"
43-
#include "tensorflow/core/public/session.h"
44-
#include "tensorflow/core/public/session_options.h"
4541
#include "tensorflow_compression/cc/lib/bit_coder.h"
4642

4743
namespace tensorflow_compression {
@@ -198,36 +194,35 @@ TEST_F(BitCodingOpsTest, ManualEncodeWithBitcodingLibrary) {
198194
TF_ASSERT_OK(RunEncodeOp({data_tensor}, &code_tensor));
199195

200196
// Use bitcoding library to encode data.
201-
BitWriter enc_ = BitWriter();
202-
enc_.Allocate(16);
197+
BitWriter enc_ = BitWriter(16);
203198
enc_.WriteGamma(2); // one zero
204199
enc_.WriteOneBit(0); // negative
205200
enc_.WriteGamma(3); // 3
206201
enc_.WriteGamma(1); // no zeros
207202
enc_.WriteOneBit(1); // positive
208203
enc_.WriteGamma(1); // 1
209-
enc_.ZeroPadToByte();
210204
Tensor expected_code_tensor(DT_STRING, {});
211-
expected_code_tensor.scalar<tstring>()().assign(enc_.GetData(), 2);
205+
auto encoded = enc_.GetData();
206+
expected_code_tensor.scalar<tstring>()().assign(encoded.data(),
207+
encoded.size());
212208

213209
// Check that code_tensor has expected value.
214210
test::ExpectTensorEqual<tstring>(code_tensor, expected_code_tensor);
215211
}
216212

217213
TEST_F(BitCodingOpsTest, ManualDecodeWithBitcodingLibrary) {
218214
// Use bitcoding library to manually encode [-3, 1, 0, 0] into code.
219-
BitWriter enc_ = BitWriter();
220-
enc_.Allocate(16);
215+
BitWriter enc_ = BitWriter(16);
221216
enc_.WriteGamma(1); // no zeros
222217
enc_.WriteOneBit(0); // negative
223218
enc_.WriteGamma(3); // 3
224219
enc_.WriteGamma(1); // no zeros
225220
enc_.WriteOneBit(1); // positive
226221
enc_.WriteGamma(1); // 1
227222
enc_.WriteGamma(3); // two zeros
228-
enc_.ZeroPadToByte();
229223
Tensor code_tensor(DT_STRING, {});
230-
code_tensor.scalar<tstring>()().assign(enc_.GetData(), 2);
224+
auto encoded = enc_.GetData();
225+
code_tensor.scalar<tstring>()().assign(encoded.data(), encoded.size());
231226

232227
Tensor shape_tensor(DT_INT32, {1});
233228
shape_tensor.flat<int32_t>().setValues({4});
@@ -242,7 +237,6 @@ TEST_F(BitCodingOpsTest, ManualDecodeWithBitcodingLibrary) {
242237
test::ExpectTensorEqual<int32_t>(data_tensor, expected_data_tensor);
243238
}
244239

245-
// TODO(nicolemitchell) Strengthen these consistency checks.
246240
TEST_F(BitCodingOpsTest, EncodeConsistent) {
247241
Tensor data_tensor(DT_INT32, {4});
248242
data_tensor.flat<int32_t>().setValues({-6, 3, 0, 0});
@@ -259,7 +253,6 @@ TEST_F(BitCodingOpsTest, EncodeConsistent) {
259253
}
260254

261255
TEST_F(BitCodingOpsTest, DecodeConsistent) {
262-
// Manually encode some data into code.
263256
char code[] = {0b11010001, 0b01101101}; // [-6, 3, 0, 0]
264257

265258
Tensor code_tensor(DT_STRING, {});
@@ -277,6 +270,14 @@ TEST_F(BitCodingOpsTest, DecodeConsistent) {
277270
// Check that decoded data has expected values.
278271
test::ExpectTensorEqual<int32_t>(data_tensor, expected_data_tensor);
279272
}
273+
274+
// TODO(nicolemitchell,jonycgn) Add more corner cases to unit tests.
275+
// Examples: decode empty string (null pointer), decode strings that end
276+
// prematurely, decode long string of zeros that causes overflow in ReadGamma,
277+
// decode incorrect run length that exceeds tensor size, encode int32::min
278+
// tensor, encode tensor with very large values to ensure it doesn't exceed
279+
// allocated buffer, encode gamma values <= 0, ...
280+
280281
} // namespace
281282
} // namespace tensorflow_compression
282283

0 commit comments

Comments
 (0)