@@ -31,7 +31,6 @@ limitations under the License.
31
31
#include " tensorflow/core/platform/logging.h"
32
32
#include " tensorflow/core/platform/macros.h"
33
33
#include " tensorflow/core/platform/types.h"
34
-
35
34
#include " tensorflow_compression/cc/kernels/range_coder.h"
36
35
#include " tensorflow_compression/cc/kernels/range_coding_kernels_util.h"
37
36
@@ -40,21 +39,21 @@ namespace {
40
39
namespace errors = tensorflow::errors;
41
40
namespace gtl = tensorflow::gtl;
42
41
using tensorflow::DEVICE_CPU;
42
+ using tensorflow::int16;
43
+ using tensorflow::int32;
44
+ using tensorflow::int64;
43
45
using tensorflow::OpKernel;
44
46
using tensorflow::OpKernelConstruction;
45
47
using tensorflow::OpKernelContext;
46
48
using tensorflow::Status;
49
+ using tensorflow::string;
47
50
using tensorflow::Tensor;
48
51
using tensorflow::TensorShape;
49
52
using tensorflow::TensorShapeUtils;
50
53
using tensorflow::TTypes;
51
- using tensorflow::int16;
52
- using tensorflow::int32;
53
- using tensorflow::int64;
54
- using tensorflow::string;
55
- using tensorflow::uint8;
56
54
using tensorflow::uint32;
57
55
using tensorflow::uint64;
56
+ using tensorflow::uint8;
58
57
59
58
// A helper class to iterate over data and cdf simultaneously, while cdf is
60
59
// broadcasted to data.
@@ -151,14 +150,42 @@ Status CheckCdfShape(const TensorShape& data_shape,
151
150
return Status::OK ();
152
151
}
153
152
154
- // Non-incremental encoder op -------------------------------------------------
153
+ tensorflow::Status CheckCdfValues (int precision,
154
+ const tensorflow::Tensor& cdf_tensor) {
155
+ const auto cdf = cdf_tensor.flat_inner_dims <int32, 2 >();
156
+ const auto size = cdf.dimension (1 );
157
+ if (size <= 2 ) {
158
+ return errors::InvalidArgument (" CDF size should be > 2: " , size);
159
+ }
160
+
161
+ const int32 upper_bound = 1 << precision;
162
+ for (int64 i = 0 ; i < cdf.dimension (0 ); ++i) {
163
+ auto slice = tensorflow::gtl::ArraySlice<int32>(&cdf (i, 0 ), size);
164
+ if (slice[0 ] != 0 || slice[size - 1 ] != upper_bound) {
165
+ return errors::InvalidArgument (" CDF should start from 0 and end at " ,
166
+ upper_bound, " : cdf[0]=" , slice[0 ],
167
+ " , cdf[^1]=" , slice[size - 1 ]);
168
+ }
169
+ for (int64 j = 0 ; j + 1 < size; ++j) {
170
+ if (slice[j + 1 ] <= slice[j]) {
171
+ return errors::InvalidArgument (" CDF is not monotonic" );
172
+ }
173
+ }
174
+ }
175
+ return tensorflow::Status::OK ();
176
+ }
177
+
155
178
class RangeEncodeOp : public OpKernel {
156
179
public:
157
180
explicit RangeEncodeOp (OpKernelConstruction* context) : OpKernel(context) {
158
181
OP_REQUIRES_OK (context, context->GetAttr (" precision" , &precision_));
159
182
OP_REQUIRES (context, 0 < precision_ && precision_ <= 16 ,
160
183
errors::InvalidArgument (" `precision` must be in [1, 16]: " ,
161
184
precision_));
185
+ OP_REQUIRES_OK (context, context->GetAttr (" debug_level" , &debug_level_));
186
+ OP_REQUIRES (context, debug_level_ == 0 || debug_level_ == 1 ,
187
+ errors::InvalidArgument (" `debug_level` must be 0 or 1: " ,
188
+ debug_level_));
162
189
}
163
190
164
191
void Compute (OpKernelContext* context) override {
@@ -167,6 +194,10 @@ class RangeEncodeOp : public OpKernel {
167
194
168
195
OP_REQUIRES_OK (context, CheckCdfShape (data.shape (), cdf.shape ()));
169
196
197
+ if (debug_level_ > 0 ) {
198
+ OP_REQUIRES_OK (context, CheckCdfValues (precision_, cdf));
199
+ }
200
+
170
201
std::vector<int64> data_shape, cdf_shape;
171
202
OP_REQUIRES_OK (
172
203
context, MergeAxes (data.shape (), cdf.shape (), &data_shape, &cdf_shape));
@@ -177,10 +208,12 @@ class RangeEncodeOp : public OpKernel {
177
208
string* output = &output_tensor->scalar <string>()();
178
209
179
210
switch (data_shape.size ()) {
180
- #define RANGE_ENCODE_CASE (dims ) \
181
- case dims: { \
182
- RangeEncodeImpl<dims>(data.flat <int16>(), data_shape, \
183
- cdf.flat_inner_dims <int32, 2 >(), cdf_shape, output); \
211
+ #define RANGE_ENCODE_CASE (dims ) \
212
+ case dims: { \
213
+ OP_REQUIRES_OK (context, \
214
+ RangeEncodeImpl<dims>(data.flat <int16>(), data_shape, \
215
+ cdf.flat_inner_dims <int32, 2 >(), \
216
+ cdf_shape, output)); \
184
217
} break
185
218
RANGE_ENCODE_CASE (1 );
186
219
RANGE_ENCODE_CASE (2 );
@@ -199,10 +232,11 @@ class RangeEncodeOp : public OpKernel {
199
232
200
233
private:
201
234
template <int N>
202
- void RangeEncodeImpl (TTypes<int16>::ConstFlat data,
203
- gtl::ArraySlice<int64> data_shape,
204
- TTypes<int32>::ConstMatrix cdf,
205
- gtl::ArraySlice<int64> cdf_shape, string* output) const {
235
+ tensorflow::Status RangeEncodeImpl (TTypes<int16>::ConstFlat data,
236
+ gtl::ArraySlice<int64> data_shape,
237
+ TTypes<int32>::ConstMatrix cdf,
238
+ gtl::ArraySlice<int64> cdf_shape,
239
+ string* output) const {
206
240
const int64 data_size = data.size ();
207
241
const int64 cdf_size = cdf.size ();
208
242
const int64 chip_size = cdf.dimension (1 );
@@ -214,8 +248,15 @@ class RangeEncodeOp : public OpKernel {
214
248
const auto pair = view.Next ();
215
249
216
250
const int64 index = *pair.first ;
217
- DCHECK_GE (index, 0 );
218
- DCHECK_LT (index + 1 , chip_size);
251
+ if (debug_level_ > 0 ) {
252
+ if (index < 0 || chip_size <= index + 1 ) {
253
+ return errors::InvalidArgument (" 'data' value not in [0, " ,
254
+ chip_size - 1 , " ): value=" , index);
255
+ }
256
+ } else {
257
+ DCHECK_GE (index, 0 );
258
+ DCHECK_LT (index + 1 , chip_size);
259
+ }
219
260
220
261
const int32* cdf_slice = pair.second ;
221
262
DCHECK_LE (cdf_slice + chip_size, cdf.data () + cdf_size);
@@ -226,21 +267,26 @@ class RangeEncodeOp : public OpKernel {
226
267
}
227
268
228
269
encoder.Finalize (output);
270
+ return tensorflow::Status::OK ();
229
271
}
230
272
231
273
int precision_;
274
+ int debug_level_;
232
275
};
233
276
234
277
REGISTER_KERNEL_BUILDER (Name(" RangeEncode" ).Device(DEVICE_CPU), RangeEncodeOp);
235
278
236
- // Non-incremental decoder op -------------------------------------------------
237
279
class RangeDecodeOp : public OpKernel {
238
280
public:
239
281
explicit RangeDecodeOp (OpKernelConstruction* context) : OpKernel(context) {
240
282
OP_REQUIRES_OK (context, context->GetAttr (" precision" , &precision_));
241
283
OP_REQUIRES (context, 0 < precision_ && precision_ <= 16 ,
242
284
errors::InvalidArgument (" `precision` must be in [1, 16]: " ,
243
285
precision_));
286
+ OP_REQUIRES_OK (context, context->GetAttr (" debug_level" , &debug_level_));
287
+ OP_REQUIRES (context, debug_level_ == 0 || debug_level_ == 1 ,
288
+ errors::InvalidArgument (" `debug_level` must be 0 or 1: " ,
289
+ debug_level_));
244
290
}
245
291
246
292
void Compute (OpKernelContext* context) override {
@@ -254,11 +300,16 @@ class RangeDecodeOp : public OpKernel {
254
300
OP_REQUIRES (context, TensorShapeUtils::IsVector (shape.shape ()),
255
301
errors::InvalidArgument (" Invalid `shape` shape: " ,
256
302
shape.shape ().DebugString ()));
303
+
257
304
TensorShape output_shape;
258
305
OP_REQUIRES_OK (context, TensorShapeUtils::MakeShape (shape.vec <int32>(),
259
306
&output_shape));
260
307
OP_REQUIRES_OK (context, CheckCdfShape (output_shape, cdf.shape ()));
261
308
309
+ if (debug_level_ > 0 ) {
310
+ OP_REQUIRES_OK (context, CheckCdfValues (precision_, cdf));
311
+ }
312
+
262
313
std::vector<int64> data_shape, cdf_shape;
263
314
OP_REQUIRES_OK (
264
315
context, MergeAxes (output_shape, cdf.shape (), &data_shape, &cdf_shape));
@@ -269,10 +320,12 @@ class RangeDecodeOp : public OpKernel {
269
320
OP_REQUIRES_OK (context, context->allocate_output (0 , output_shape, &output));
270
321
271
322
switch (data_shape.size ()) {
272
- #define RANGE_DECODE_CASE (dim ) \
273
- case dim: { \
274
- RangeDecodeImpl<dim>(output->flat <int16>(), data_shape, \
275
- cdf.flat_inner_dims <int32>(), cdf_shape, encoded); \
323
+ #define RANGE_DECODE_CASE (dim ) \
324
+ case dim: { \
325
+ OP_REQUIRES_OK ( \
326
+ context, RangeDecodeImpl<dim>(output->flat <int16>(), data_shape, \
327
+ cdf.flat_inner_dims <int32>(), cdf_shape, \
328
+ encoded)); \
276
329
} break
277
330
RANGE_DECODE_CASE (1 );
278
331
RANGE_DECODE_CASE (2 );
@@ -291,11 +344,11 @@ class RangeDecodeOp : public OpKernel {
291
344
292
345
private:
293
346
template <int N>
294
- void RangeDecodeImpl (TTypes<int16>::Flat output,
295
- gtl::ArraySlice<int64> output_shape,
296
- TTypes<int32>::ConstMatrix cdf,
297
- gtl::ArraySlice<int64> cdf_shape,
298
- const string& encoded) const {
347
+ tensorflow::Status RangeDecodeImpl (TTypes<int16>::Flat output,
348
+ gtl::ArraySlice<int64> output_shape,
349
+ TTypes<int32>::ConstMatrix cdf,
350
+ gtl::ArraySlice<int64> cdf_shape,
351
+ const string& encoded) const {
299
352
BroadcastRange<int16, int32, N> view{output.data (), output_shape,
300
353
cdf.data (), cdf_shape};
301
354
@@ -315,11 +368,13 @@ class RangeDecodeOp : public OpKernel {
315
368
const int32* cdf_slice = pair.second ;
316
369
DCHECK_LE (cdf_slice + chip_size, cdf.data () + cdf_size);
317
370
318
- *data = decoder.Decode (gtl::ArraySlice<int32> {cdf_slice, chip_size});
371
+ *data = decoder.Decode ({cdf_slice, chip_size});
319
372
}
373
+ return tensorflow::Status::OK ();
320
374
}
321
375
322
376
int precision_;
377
+ int debug_level_;
323
378
};
324
379
325
380
REGISTER_KERNEL_BUILDER (Name(" RangeDecode" ).Device(DEVICE_CPU), RangeDecodeOp);
0 commit comments