@@ -19,15 +19,15 @@ limitations under the License.
19
19
// a digitised message", presented to the Video & Data Recording Conference,
20
20
// held in Southampton, July 24-27, 1979.
21
21
//
22
+ #include " tensorflow_compression/cc/kernels/range_coder.h"
23
+
22
24
#include < limits>
23
25
#include < string>
24
26
25
27
#include " tensorflow/core/lib/gtl/array_slice.h"
26
28
#include " tensorflow/core/platform/logging.h"
27
29
#include " tensorflow/core/platform/types.h"
28
30
29
- #include " tensorflow_compression/cc/kernels/range_coder.h"
30
-
31
31
namespace tensorflow_compression {
32
32
namespace gtl = tensorflow::gtl;
33
33
using tensorflow::int32;
@@ -36,16 +36,16 @@ using tensorflow::uint32;
36
36
using tensorflow::uint64;
37
37
using tensorflow::uint8;
38
38
39
- RangeEncoder::RangeEncoder (int precision) : precision_(precision) {
40
- CHECK_GT (precision, 0 );
41
- CHECK_LE (precision, 16 );
42
- }
39
+ void RangeEncoder::Encode (int32 lower, int32 upper, int precision,
40
+ string* sink) {
41
+ // Input requirement: 0 < precision < 16.
42
+ DCHECK_GT (precision, 0 );
43
+ DCHECK_LE (precision, 16 );
43
44
44
- void RangeEncoder::Encode (int32 lower, int32 upper, string* sink) {
45
45
// Input requirement: 0 <= lower < upper <= 2^precision.
46
46
DCHECK_LE (0 , lower);
47
47
DCHECK_LT (lower, upper);
48
- DCHECK_LE (upper, 1 << precision_ );
48
+ DCHECK_LE (upper, 1 << precision );
49
49
50
50
// `base` and `size` represent a half-open interval [base, base + size).
51
51
// Loop invariant: 2^16 <= size <= 2^32.
@@ -69,8 +69,8 @@ void RangeEncoder::Encode(int32 lower, int32 upper, string* sink) {
69
69
// NOTE: The max value of `size` is 2^32 and size > 0. Therefore `size * u`
70
70
// can be rewritten as `(size - 1) * u + u` and all the computation can be
71
71
// done in 32-bit mode. If 32-bit multiply is faster, then rewrite.
72
- const uint32 a = (size * static_cast <uint64>(lower)) >> precision_ ;
73
- const uint32 b = ((size * static_cast <uint64>(upper)) >> precision_ ) - 1 ;
72
+ const uint32 a = (size * static_cast <uint64>(lower)) >> precision ;
73
+ const uint32 b = ((size * static_cast <uint64>(upper)) >> precision ) - 1 ;
74
74
DCHECK_LE (a, b);
75
75
76
76
// Let's confirm the RHS of a, b fit in uint32 type.
@@ -301,23 +301,21 @@ void RangeEncoder::Finalize(string* sink) {
301
301
delay_ = 0 ;
302
302
}
303
303
304
- RangeDecoder::RangeDecoder (const string& source, int precision)
305
- : current_(source.begin()),
306
- begin_ (source.begin()),
307
- end_(source.end()),
308
- precision_(precision) {
309
- CHECK_LE (precision, 16 );
310
-
304
+ RangeDecoder::RangeDecoder (const string& source)
305
+ : current_(source.begin()), end_(source.end()) {
311
306
Read16BitValue ();
312
307
Read16BitValue ();
313
308
}
314
309
315
- int32 RangeDecoder::Decode (gtl::ArraySlice<int32> cdf) {
310
+ int32 RangeDecoder::Decode (gtl::ArraySlice<int32> cdf, int precision) {
311
+ // Input requirement: 0 < precision < 16.
312
+ DCHECK_GT (precision, 0 );
313
+ DCHECK_LE (precision, 16 );
314
+
316
315
const uint64 size = static_cast <uint64>(size_minus1_) + 1 ;
317
316
const uint64 offset =
318
- ((static_cast <uint64>(value_ - base_) + 1 ) << precision_ ) - 1 ;
317
+ ((static_cast <uint64>(value_ - base_) + 1 ) << precision ) - 1 ;
319
318
320
- // This is similar to std::lower_range() with std::less_equal as comparison.
321
319
// After the binary search, `pv` points to the smallest number v that
322
320
// satisfies offset < (size * v) / 2^precision.
323
321
@@ -333,7 +331,7 @@ int32 RangeDecoder::Decode(gtl::ArraySlice<int32> cdf) {
333
331
const auto half = len / 2 ;
334
332
const int32* mid = pv + half;
335
333
DCHECK_GE (*mid, 0 );
336
- DCHECK_LE (*mid, 1 << precision_ );
334
+ DCHECK_LE (*mid, 1 << precision );
337
335
if (size * static_cast <uint64>(*mid) <= offset) {
338
336
pv = mid + 1 ;
339
337
len -= half + 1 ;
@@ -349,10 +347,10 @@ int32 RangeDecoder::Decode(gtl::ArraySlice<int32> cdf) {
349
347
// cdf.size() - 2 instead and give up detecting this error.
350
348
CHECK_LT (pv, cdf.data () + cdf.size ());
351
349
352
- const uint32 a = (size * static_cast <uint64>(*(pv - 1 ))) >> precision_ ;
353
- const uint32 b = ((size * static_cast <uint64>(*pv)) >> precision_ ) - 1 ;
354
- DCHECK_LE (a, offset >> precision_ );
355
- DCHECK_LE (offset >> precision_ , b);
350
+ const uint32 a = (size * static_cast <uint64>(*(pv - 1 ))) >> precision ;
351
+ const uint32 b = ((size * static_cast <uint64>(*pv)) >> precision ) - 1 ;
352
+ DCHECK_LE (a, offset >> precision );
353
+ DCHECK_LE (offset >> precision , b);
356
354
357
355
base_ += a;
358
356
size_minus1_ = b - a;
@@ -378,5 +376,4 @@ void RangeDecoder::Read16BitValue() {
378
376
value_ |= static_cast <uint8>(*current_++);
379
377
}
380
378
}
381
-
382
379
} // namespace tensorflow_compression
0 commit comments