
Commit 0e46883

ssjhv authored and Johannes Ballé committed
Add debug_level attributes to range-coding ops.
The default value is 1 (on). Note that debug_level=1 slows down the whole process; to measure actual performance, use debug_level=0.

PiperOrigin-RevId: 242809688
1 parent ae364a0 commit 0e46883
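For context, debug_level is an ordinary op attribute, so it is set wherever the RangeEncode/RangeDecode node is built. Below is a minimal sketch of disabling the checks for a benchmark run, using the same NodeDefBuilder pattern as the updated tests in this commit; the node name "encode", the input node names "data"/"cdf", and the helper function are illustrative, not part of the commit.

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"

// Builds a RangeEncode node with validation disabled (debug_level=0),
// e.g. when benchmarking encoder throughput. Node and input names are
// placeholders for whatever graph the node is wired into.
tensorflow::Status BuildFastRangeEncodeNode(tensorflow::NodeDef* node_def) {
  return tensorflow::NodeDefBuilder("encode", "RangeEncode")
      .Input("data", 0, tensorflow::DT_INT16)
      .Input("cdf", 0, tensorflow::DT_INT32)
      .Attr("precision", 16)
      .Attr("debug_level", 0)  // skip the CDF/data checks added in this commit
      .Finalize(node_def);
}

When debug_level is 0, the per-symbol data check falls back to DCHECKs, which compile to no-ops in optimized builds.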

File tree: 6 files changed, +420 -93 lines changed


cc/kernels/range_coding_kernels.cc

Lines changed: 83 additions & 28 deletions
@@ -31,7 +31,6 @@ limitations under the License.
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/types.h"
-
 #include "tensorflow_compression/cc/kernels/range_coder.h"
 #include "tensorflow_compression/cc/kernels/range_coding_kernels_util.h"
 
@@ -40,21 +39,21 @@ namespace {
 namespace errors = tensorflow::errors;
 namespace gtl = tensorflow::gtl;
 using tensorflow::DEVICE_CPU;
+using tensorflow::int16;
+using tensorflow::int32;
+using tensorflow::int64;
 using tensorflow::OpKernel;
 using tensorflow::OpKernelConstruction;
 using tensorflow::OpKernelContext;
 using tensorflow::Status;
+using tensorflow::string;
 using tensorflow::Tensor;
 using tensorflow::TensorShape;
 using tensorflow::TensorShapeUtils;
 using tensorflow::TTypes;
-using tensorflow::int16;
-using tensorflow::int32;
-using tensorflow::int64;
-using tensorflow::string;
-using tensorflow::uint8;
 using tensorflow::uint32;
 using tensorflow::uint64;
+using tensorflow::uint8;
 
 // A helper class to iterate over data and cdf simultaneously, while cdf is
 // broadcasted to data.
@@ -151,14 +150,42 @@ Status CheckCdfShape(const TensorShape& data_shape,
   return Status::OK();
 }
 
-// Non-incremental encoder op -------------------------------------------------
+tensorflow::Status CheckCdfValues(int precision,
+                                  const tensorflow::Tensor& cdf_tensor) {
+  const auto cdf = cdf_tensor.flat_inner_dims<int32, 2>();
+  const auto size = cdf.dimension(1);
+  if (size <= 2) {
+    return errors::InvalidArgument("CDF size should be > 2: ", size);
+  }
+
+  const int32 upper_bound = 1 << precision;
+  for (int64 i = 0; i < cdf.dimension(0); ++i) {
+    auto slice = tensorflow::gtl::ArraySlice<int32>(&cdf(i, 0), size);
+    if (slice[0] != 0 || slice[size - 1] != upper_bound) {
+      return errors::InvalidArgument("CDF should start from 0 and end at ",
+                                     upper_bound, ": cdf[0]=", slice[0],
+                                     ", cdf[^1]=", slice[size - 1]);
+    }
+    for (int64 j = 0; j + 1 < size; ++j) {
+      if (slice[j + 1] <= slice[j]) {
+        return errors::InvalidArgument("CDF is not monotonic");
+      }
+    }
+  }
+  return tensorflow::Status::OK();
+}
+
 class RangeEncodeOp : public OpKernel {
  public:
   explicit RangeEncodeOp(OpKernelConstruction* context) : OpKernel(context) {
     OP_REQUIRES_OK(context, context->GetAttr("precision", &precision_));
     OP_REQUIRES(context, 0 < precision_ && precision_ <= 16,
                 errors::InvalidArgument("`precision` must be in [1, 16]: ",
                                         precision_));
+    OP_REQUIRES_OK(context, context->GetAttr("debug_level", &debug_level_));
+    OP_REQUIRES(context, debug_level_ == 0 || debug_level_ == 1,
+                errors::InvalidArgument("`debug_level` must be 0 or 1: ",
+                                        debug_level_));
   }
 
   void Compute(OpKernelContext* context) override {
@@ -167,6 +194,10 @@ class RangeEncodeOp : public OpKernel {
 
     OP_REQUIRES_OK(context, CheckCdfShape(data.shape(), cdf.shape()));
 
+    if (debug_level_ > 0) {
+      OP_REQUIRES_OK(context, CheckCdfValues(precision_, cdf));
+    }
+
     std::vector<int64> data_shape, cdf_shape;
     OP_REQUIRES_OK(
         context, MergeAxes(data.shape(), cdf.shape(), &data_shape, &cdf_shape));
@@ -177,10 +208,12 @@ class RangeEncodeOp : public OpKernel {
     string* output = &output_tensor->scalar<string>()();
 
     switch (data_shape.size()) {
-#define RANGE_ENCODE_CASE(dims)                                     \
-  case dims: {                                                      \
-    RangeEncodeImpl<dims>(data.flat<int16>(), data_shape,           \
-                          cdf.flat_inner_dims<int32, 2>(), cdf_shape, output); \
+#define RANGE_ENCODE_CASE(dims)                                      \
+  case dims: {                                                       \
+    OP_REQUIRES_OK(context,                                          \
+                   RangeEncodeImpl<dims>(data.flat<int16>(), data_shape, \
+                                         cdf.flat_inner_dims<int32, 2>(), \
+                                         cdf_shape, output));        \
   } break
       RANGE_ENCODE_CASE(1);
       RANGE_ENCODE_CASE(2);
@@ -199,10 +232,11 @@ class RangeEncodeOp : public OpKernel {
 
  private:
   template <int N>
-  void RangeEncodeImpl(TTypes<int16>::ConstFlat data,
-                       gtl::ArraySlice<int64> data_shape,
-                       TTypes<int32>::ConstMatrix cdf,
-                       gtl::ArraySlice<int64> cdf_shape, string* output) const {
+  tensorflow::Status RangeEncodeImpl(TTypes<int16>::ConstFlat data,
+                                     gtl::ArraySlice<int64> data_shape,
+                                     TTypes<int32>::ConstMatrix cdf,
+                                     gtl::ArraySlice<int64> cdf_shape,
+                                     string* output) const {
     const int64 data_size = data.size();
     const int64 cdf_size = cdf.size();
     const int64 chip_size = cdf.dimension(1);
@@ -214,8 +248,15 @@ class RangeEncodeOp : public OpKernel {
       const auto pair = view.Next();
 
      const int64 index = *pair.first;
-      DCHECK_GE(index, 0);
-      DCHECK_LT(index + 1, chip_size);
+      if (debug_level_ > 0) {
+        if (index < 0 || chip_size <= index + 1) {
+          return errors::InvalidArgument("'data' value not in [0, ",
+                                         chip_size - 1, "): value=", index);
+        }
+      } else {
+        DCHECK_GE(index, 0);
+        DCHECK_LT(index + 1, chip_size);
+      }
 
       const int32* cdf_slice = pair.second;
       DCHECK_LE(cdf_slice + chip_size, cdf.data() + cdf_size);
@@ -226,21 +267,26 @@ class RangeEncodeOp : public OpKernel {
     }
 
     encoder.Finalize(output);
+    return tensorflow::Status::OK();
   }
 
   int precision_;
+  int debug_level_;
 };
 
 REGISTER_KERNEL_BUILDER(Name("RangeEncode").Device(DEVICE_CPU), RangeEncodeOp);
 
-// Non-incremental decoder op -------------------------------------------------
 class RangeDecodeOp : public OpKernel {
  public:
   explicit RangeDecodeOp(OpKernelConstruction* context) : OpKernel(context) {
     OP_REQUIRES_OK(context, context->GetAttr("precision", &precision_));
     OP_REQUIRES(context, 0 < precision_ && precision_ <= 16,
                 errors::InvalidArgument("`precision` must be in [1, 16]: ",
                                         precision_));
+    OP_REQUIRES_OK(context, context->GetAttr("debug_level", &debug_level_));
+    OP_REQUIRES(context, debug_level_ == 0 || debug_level_ == 1,
+                errors::InvalidArgument("`debug_level` must be 0 or 1: ",
+                                        debug_level_));
   }
 
   void Compute(OpKernelContext* context) override {
@@ -254,11 +300,16 @@ class RangeDecodeOp : public OpKernel {
     OP_REQUIRES(context, TensorShapeUtils::IsVector(shape.shape()),
                 errors::InvalidArgument("Invalid `shape` shape: ",
                                         shape.shape().DebugString()));
+
     TensorShape output_shape;
     OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(shape.vec<int32>(),
                                                         &output_shape));
     OP_REQUIRES_OK(context, CheckCdfShape(output_shape, cdf.shape()));
 
+    if (debug_level_ > 0) {
+      OP_REQUIRES_OK(context, CheckCdfValues(precision_, cdf));
+    }
+
     std::vector<int64> data_shape, cdf_shape;
     OP_REQUIRES_OK(
         context, MergeAxes(output_shape, cdf.shape(), &data_shape, &cdf_shape));
@@ -269,10 +320,12 @@ class RangeDecodeOp : public OpKernel {
     OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
 
     switch (data_shape.size()) {
-#define RANGE_DECODE_CASE(dim)                                     \
-  case dim: {                                                      \
-    RangeDecodeImpl<dim>(output->flat<int16>(), data_shape,        \
-                         cdf.flat_inner_dims<int32>(), cdf_shape, encoded); \
+#define RANGE_DECODE_CASE(dim)                                      \
+  case dim: {                                                       \
+    OP_REQUIRES_OK(                                                 \
+        context, RangeDecodeImpl<dim>(output->flat<int16>(), data_shape, \
+                                      cdf.flat_inner_dims<int32>(), cdf_shape, \
+                                      encoded));                    \
   } break
       RANGE_DECODE_CASE(1);
       RANGE_DECODE_CASE(2);
@@ -291,11 +344,11 @@ class RangeDecodeOp : public OpKernel {
 
  private:
   template <int N>
-  void RangeDecodeImpl(TTypes<int16>::Flat output,
-                       gtl::ArraySlice<int64> output_shape,
-                       TTypes<int32>::ConstMatrix cdf,
-                       gtl::ArraySlice<int64> cdf_shape,
-                       const string& encoded) const {
+  tensorflow::Status RangeDecodeImpl(TTypes<int16>::Flat output,
+                                     gtl::ArraySlice<int64> output_shape,
+                                     TTypes<int32>::ConstMatrix cdf,
+                                     gtl::ArraySlice<int64> cdf_shape,
+                                     const string& encoded) const {
     BroadcastRange<int16, int32, N> view{output.data(), output_shape,
                                          cdf.data(), cdf_shape};
 
@@ -315,11 +368,13 @@ class RangeDecodeOp : public OpKernel {
       const int32* cdf_slice = pair.second;
       DCHECK_LE(cdf_slice + chip_size, cdf.data() + cdf_size);
 
-      *data = decoder.Decode(gtl::ArraySlice<int32>{cdf_slice, chip_size});
+      *data = decoder.Decode({cdf_slice, chip_size});
     }
+    return tensorflow::Status::OK();
   }
 
   int precision_;
+  int debug_level_;
 };
 
 REGISTER_KERNEL_BUILDER(Name("RangeDecode").Device(DEVICE_CPU), RangeDecodeOp);

cc/kernels/range_coding_kernels_test.cc

Lines changed: 86 additions & 1 deletion
@@ -37,7 +37,6 @@ limitations under the License.
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/public/session.h"
 #include "tensorflow/core/public/session_options.h"
-
 #include "tensorflow_compression/cc/kernels/range_coder.h"
 
 namespace tensorflow_compression {
@@ -100,10 +99,21 @@ class RangeCoderOpsTest : public OpsTestBase {
  protected:
   Status RunEncodeOp(int precision, gtl::ArraySlice<Tensor> input,
                      Tensor* output) {
+    return RunEncodeOpImpl(precision, input, 0, output);
+  }
+
+  Status RunEncodeOpDebug(int precision, gtl::ArraySlice<Tensor> input,
+                          Tensor* output) {
+    return RunEncodeOpImpl(precision, input, 1, output);
+  }
+
+  Status RunEncodeOpImpl(int precision, gtl::ArraySlice<Tensor> input,
+                         int debug_level, Tensor* output) {
     TF_RETURN_IF_ERROR(NodeDefBuilder("encode", "RangeEncode")
                            .Input(tensorflow::FakeInput(DT_INT16))
                            .Input(tensorflow::FakeInput(DT_INT32))
                            .Attr("precision", precision)
+                           .Attr("debug_level", debug_level)
                            .Finalize(node_def()));
     TF_RETURN_IF_ERROR(InitOp());
 
@@ -124,11 +134,22 @@ class RangeCoderOpsTest : public OpsTestBase {
 
   Status RunDecodeOp(int precision, gtl::ArraySlice<Tensor> input,
                      Tensor* output) {
+    return RunDecodeOpImpl(precision, input, 0, output);
+  }
+
+  Status RunDecodeOpDebug(int precision, gtl::ArraySlice<Tensor> input,
+                          Tensor* output) {
+    return RunDecodeOpImpl(precision, input, 1, output);
+  }
+
+  Status RunDecodeOpImpl(int precision, gtl::ArraySlice<Tensor> input,
+                         int debug_level, Tensor* output) {
     TF_RETURN_IF_ERROR(NodeDefBuilder("decode", "RangeDecode")
                            .Input(tensorflow::FakeInput(DT_STRING))
                            .Input(tensorflow::FakeInput(DT_INT32))
                            .Input(tensorflow::FakeInput(DT_INT32))
                            .Attr("precision", precision)
+                           .Attr("debug_level", debug_level)
                            .Finalize(node_def()));
     TF_RETURN_IF_ERROR(InitOp());
 
@@ -419,6 +440,70 @@ TEST_F(RangeCoderOpsTest, InvalidBroadcast) {
   }
 }
 
+#define EXPECT_STATUS_SUBSTR(status_expr, message)                    \
+  {                                                                    \
+    auto status = (status_expr);                                       \
+    EXPECT_FALSE(status.ok());                                         \
+    EXPECT_NE(status.error_message().find((message)), string::npos)    \
+        << status.error_message();                                     \
+  }
+
+TEST_F(RangeCoderOpsTest, EncoderDebug) {
+  Tensor data(DT_INT16, {});
+  data.scalar<int16>()() = 1;
+
+  Tensor cdf(DT_INT32, {4});
+  cdf.vec<int32>().setValues({0, 16, 18, 32});
+
+  Tensor unused;
+  auto status = RunEncodeOpDebug(5, {data, cdf}, &unused);
+  EXPECT_TRUE(status.ok());
+
+  data.scalar<int16>()() = -1;
+  EXPECT_STATUS_SUBSTR(RunEncodeOpDebug(5, {data, cdf}, &unused),
+                       "value not in [0, 3)");
+
+  data.scalar<int16>()() = 5;
+  EXPECT_STATUS_SUBSTR(RunEncodeOpDebug(5, {data, cdf}, &unused),
+                       "value not in [0, 3)");
+}
+
+TEST_F(RangeCoderOpsTest, DecoderDebug) {
+  RangeEncoder encoder(5);
+
+  string encoded_string;
+  encoder.Encode(16, 18, &encoded_string);
+  encoder.Finalize(&encoded_string);
+
+  Tensor encoded(DT_STRING, {});
+  encoded.scalar<string>()().swap(encoded_string);
+
+  Tensor shape(DT_INT32, {0});
+
+  Tensor cdf(DT_INT32, {4});
+  cdf.vec<int32>().setValues({0, 16, 18, 32});
+
+  Tensor unused;
+  auto status = RunDecodeOpDebug(5, {encoded, shape, cdf}, &unused);
+  EXPECT_TRUE(status.ok());
+
+  cdf.vec<int32>().setValues({1, 16, 18, 32});
+  EXPECT_STATUS_SUBSTR(RunDecodeOpDebug(5, {encoded, shape, cdf}, &unused),
+                       "cdf[0]=1");
+
+  cdf.vec<int32>().setValues({0, 16, 18, 31});
+  EXPECT_STATUS_SUBSTR(RunDecodeOpDebug(5, {encoded, shape, cdf}, &unused),
+                       "cdf[^1]=31");
+
+  cdf.vec<int32>().setValues({0, 18, 16, 32});
+  EXPECT_STATUS_SUBSTR(RunDecodeOpDebug(5, {encoded, shape, cdf}, &unused),
+                       "monotonic");
+
+  cdf = Tensor(DT_INT32, {2});
+  cdf.vec<int32>().setValues({0, 32});
+  EXPECT_STATUS_SUBSTR(RunDecodeOpDebug(5, {encoded, shape, cdf}, &unused),
+                       "CDF size");
+}
 }  // namespace
 }  // namespace tensorflow_compression