Skip to content

Commit 5391ab5

Browse files
GooglerSung Jin Hwang
authored andcommitted
Project import generated by Copybara.
PiperOrigin-RevId: 245129345 Change-Id: Ib979ca7a93ae98bffe571a244ec78ac40b59a7fc
1 parent f1d9824 commit 5391ab5

File tree

6 files changed

+73
-89
lines changed

6 files changed

+73
-89
lines changed

cc/kernels/range_coder.cc

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,15 @@ limitations under the License.
1919
// a digitised message", presented to the Video & Data Recording Conference,
2020
// held in Southampton, July 24-27, 1979.
2121
//
22+
#include "tensorflow_compression/cc/kernels/range_coder.h"
23+
2224
#include <limits>
2325
#include <string>
2426

2527
#include "tensorflow/core/lib/gtl/array_slice.h"
2628
#include "tensorflow/core/platform/logging.h"
2729
#include "tensorflow/core/platform/types.h"
2830

29-
#include "tensorflow_compression/cc/kernels/range_coder.h"
30-
3131
namespace tensorflow_compression {
3232
namespace gtl = tensorflow::gtl;
3333
using tensorflow::int32;
@@ -36,16 +36,16 @@ using tensorflow::uint32;
3636
using tensorflow::uint64;
3737
using tensorflow::uint8;
3838

39-
RangeEncoder::RangeEncoder(int precision) : precision_(precision) {
40-
CHECK_GT(precision, 0);
41-
CHECK_LE(precision, 16);
42-
}
39+
void RangeEncoder::Encode(int32 lower, int32 upper, int precision,
40+
string* sink) {
41+
// Input requirement: 0 < precision < 16.
42+
DCHECK_GT(precision, 0);
43+
DCHECK_LE(precision, 16);
4344

44-
void RangeEncoder::Encode(int32 lower, int32 upper, string* sink) {
4545
// Input requirement: 0 <= lower < upper <= 2^precision.
4646
DCHECK_LE(0, lower);
4747
DCHECK_LT(lower, upper);
48-
DCHECK_LE(upper, 1 << precision_);
48+
DCHECK_LE(upper, 1 << precision);
4949

5050
// `base` and `size` represent a half-open interval [base, base + size).
5151
// Loop invariant: 2^16 <= size <= 2^32.
@@ -69,8 +69,8 @@ void RangeEncoder::Encode(int32 lower, int32 upper, string* sink) {
6969
// NOTE: The max value of `size` is 2^32 and size > 0. Therefore `size * u`
7070
// can be rewritten as `(size - 1) * u + u` and all the computation can be
7171
// done in 32-bit mode. If 32-bit multiply is faster, then rewrite.
72-
const uint32 a = (size * static_cast<uint64>(lower)) >> precision_;
73-
const uint32 b = ((size * static_cast<uint64>(upper)) >> precision_) - 1;
72+
const uint32 a = (size * static_cast<uint64>(lower)) >> precision;
73+
const uint32 b = ((size * static_cast<uint64>(upper)) >> precision) - 1;
7474
DCHECK_LE(a, b);
7575

7676
// Let's confirm the RHS of a, b fit in uint32 type.
@@ -301,23 +301,21 @@ void RangeEncoder::Finalize(string* sink) {
301301
delay_ = 0;
302302
}
303303

304-
RangeDecoder::RangeDecoder(const string& source, int precision)
305-
: current_(source.begin()),
306-
begin_(source.begin()),
307-
end_(source.end()),
308-
precision_(precision) {
309-
CHECK_LE(precision, 16);
310-
304+
RangeDecoder::RangeDecoder(const string& source)
305+
: current_(source.begin()), end_(source.end()) {
311306
Read16BitValue();
312307
Read16BitValue();
313308
}
314309

315-
int32 RangeDecoder::Decode(gtl::ArraySlice<int32> cdf) {
310+
int32 RangeDecoder::Decode(gtl::ArraySlice<int32> cdf, int precision) {
311+
// Input requirement: 0 < precision < 16.
312+
DCHECK_GT(precision, 0);
313+
DCHECK_LE(precision, 16);
314+
316315
const uint64 size = static_cast<uint64>(size_minus1_) + 1;
317316
const uint64 offset =
318-
((static_cast<uint64>(value_ - base_) + 1) << precision_) - 1;
317+
((static_cast<uint64>(value_ - base_) + 1) << precision) - 1;
319318

320-
// This is similar to std::lower_range() with std::less_equal as comparison.
321319
// After the binary search, `pv` points to the smallest number v that
322320
// satisfies offset < (size * v) / 2^precision.
323321

@@ -333,7 +331,7 @@ int32 RangeDecoder::Decode(gtl::ArraySlice<int32> cdf) {
333331
const auto half = len / 2;
334332
const int32* mid = pv + half;
335333
DCHECK_GE(*mid, 0);
336-
DCHECK_LE(*mid, 1 << precision_);
334+
DCHECK_LE(*mid, 1 << precision);
337335
if (size * static_cast<uint64>(*mid) <= offset) {
338336
pv = mid + 1;
339337
len -= half + 1;
@@ -349,10 +347,10 @@ int32 RangeDecoder::Decode(gtl::ArraySlice<int32> cdf) {
349347
// cdf.size() - 2 instead and give up detecting this error.
350348
CHECK_LT(pv, cdf.data() + cdf.size());
351349

352-
const uint32 a = (size * static_cast<uint64>(*(pv - 1))) >> precision_;
353-
const uint32 b = ((size * static_cast<uint64>(*pv)) >> precision_) - 1;
354-
DCHECK_LE(a, offset >> precision_);
355-
DCHECK_LE(offset >> precision_, b);
350+
const uint32 a = (size * static_cast<uint64>(*(pv - 1))) >> precision;
351+
const uint32 b = ((size * static_cast<uint64>(*pv)) >> precision) - 1;
352+
DCHECK_LE(a, offset >> precision);
353+
DCHECK_LE(offset >> precision, b);
356354

357355
base_ += a;
358356
size_minus1_ = b - a;
@@ -378,5 +376,4 @@ void RangeDecoder::Read16BitValue() {
378376
value_ |= static_cast<uint8>(*current_++);
379377
}
380378
}
381-
382379
} // namespace tensorflow_compression

cc/kernels/range_coder.h

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,7 @@ namespace tensorflow_compression {
2626

2727
class RangeEncoder {
2828
public:
29-
// `precision` determines the granularity of probability masses passed to
30-
// Encode() function below.
31-
//
32-
// REQUIRES: 0 < precision <= 16.
33-
explicit RangeEncoder(int precision);
29+
RangeEncoder() = default;
3430

3531
// Encodes a half-open interval [lower / 2^precision, upper / 2^precision).
3632
// Suppose each character to be encoded is from an integer-valued
@@ -47,7 +43,8 @@ class RangeEncoder {
4743
// ...
4844
//
4945
// REQUIRES: 0 <= lower < upper <= 2^precision.
50-
void Encode(tensorflow::int32 lower, tensorflow::int32 upper,
46+
// REQUIRES: 0 < precision <= 16.
47+
void Encode(tensorflow::int32 lower, tensorflow::int32 upper, int precision,
5148
tensorflow::string* sink);
5249

5350
// The encode may contain some under-determined values from previous encoding.
@@ -60,18 +57,13 @@ class RangeEncoder {
6057
tensorflow::uint32 size_minus1_ =
6158
std::numeric_limits<tensorflow::uint32>::max();
6259
tensorflow::uint64 delay_ = 0;
63-
64-
const int precision_;
6560
};
6661

6762
class RangeDecoder {
6863
public:
6964
// Holds a reference to `source`. The caller has to make sure that `source`
7065
// outlives the decoder object.
71-
//
72-
// REQUIRES: `precision` must be the same as the encoder's precision.
73-
// REQUIRES: 0 < precision <= 16.
74-
RangeDecoder(const tensorflow::string& source, int precision);
66+
explicit RangeDecoder(const tensorflow::string& source);
7567

7668
// Decodes a character from `source` using CDF. The size of `cdf` should be
7769
// one more than the number of the character in the alphabet.
@@ -90,9 +82,11 @@ class RangeDecoder {
9082
// REQUIRES: cdf.size() > 1.
9183
// REQUIRES: cdf[i] <= cdf[i + 1] for i = 0, 1, ..., cdf.size() - 2.
9284
// REQUIRES: cdf[cdf.size() - 1] <= 2^precision.
85+
// REQUIRES: 0 < precision <= 16.
9386
//
9487
// In practice the last element of `cdf` should equal to 2^precision.
95-
tensorflow::int32 Decode(tensorflow::gtl::ArraySlice<tensorflow::int32> cdf);
88+
tensorflow::int32 Decode(tensorflow::gtl::ArraySlice<tensorflow::int32> cdf,
89+
int precision);
9690

9791
private:
9892
void Read16BitValue();
@@ -103,10 +97,7 @@ class RangeDecoder {
10397
tensorflow::uint32 value_ = 0;
10498

10599
tensorflow::string::const_iterator current_;
106-
const tensorflow::string::const_iterator begin_;
107100
const tensorflow::string::const_iterator end_;
108-
109-
const int precision_;
110101
};
111102

112103
} // namespace tensorflow_compression

cc/kernels/range_coder_test.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,11 @@ void RangeEncodeDecodeTest(int precision, random::SimplePhilox* gen) {
6767
ideal_code_length[i] = -std::log2((cdf[i + 1] - cdf[i]) / normalizer);
6868
}
6969

70-
RangeEncoder encoder(precision);
70+
RangeEncoder encoder;
7171
string encoded;
7272
double ideal_length = 0.0;
7373
for (uint8 x : data) {
74-
encoder.Encode(cdf[x], cdf[x + 1], &encoded);
74+
encoder.Encode(cdf[x], cdf[x + 1], precision, &encoded);
7575
ideal_length += ideal_code_length[x];
7676
}
7777
encoder.Finalize(&encoded);
@@ -82,9 +82,9 @@ void RangeEncodeDecodeTest(int precision, random::SimplePhilox* gen) {
8282
<< " (ideal compression rate " << ideal_length / (8 * data.size())
8383
<< ")";
8484

85-
RangeDecoder decoder(encoded, precision);
85+
RangeDecoder decoder(encoded);
8686
for (int i = 0; i < data.size(); ++i) {
87-
const int32 decoded = decoder.Decode(cdf);
87+
const int32 decoded = decoder.Decode(cdf, precision);
8888
ASSERT_EQ(decoded, static_cast<int32>(data[i])) << i;
8989
}
9090
}
@@ -110,12 +110,12 @@ TEST(RangeCoderTest, FinalizeState0) {
110110
constexpr int kPrecision = 2;
111111

112112
string output;
113-
RangeEncoder encoder(kPrecision);
114-
encoder.Encode(0, 2, &output);
113+
RangeEncoder encoder;
114+
encoder.Encode(0, 2, kPrecision, &output);
115115
encoder.Finalize(&output);
116116

117-
RangeDecoder decoder(output, kPrecision);
118-
EXPECT_EQ(decoder.Decode({0, 2, 4}), 0);
117+
RangeDecoder decoder(output);
118+
EXPECT_EQ(decoder.Decode({0, 2, 4}, kPrecision), 0);
119119
}
120120

121121
} // namespace

cc/kernels/range_coding_kernels.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ class RangeEncodeOp : public OpKernel {
243243

244244
BroadcastRange<const int16, int32, N> view{data.data(), data_shape,
245245
cdf.data(), cdf_shape};
246-
RangeEncoder encoder{precision_};
246+
RangeEncoder encoder;
247247
for (int64 linear = 0; linear < data_size; ++linear) {
248248
const auto pair = view.Next();
249249

@@ -263,7 +263,7 @@ class RangeEncodeOp : public OpKernel {
263263

264264
const int32 lower = cdf_slice[index];
265265
const int32 upper = cdf_slice[index + 1];
266-
encoder.Encode(lower, upper, output);
266+
encoder.Encode(lower, upper, precision_, output);
267267
}
268268

269269
encoder.Finalize(output);
@@ -352,7 +352,7 @@ class RangeDecodeOp : public OpKernel {
352352
BroadcastRange<int16, int32, N> view{output.data(), output_shape,
353353
cdf.data(), cdf_shape};
354354

355-
RangeDecoder decoder{encoded, precision_};
355+
RangeDecoder decoder(encoded);
356356

357357
const int64 output_size = output.size();
358358
const int64 cdf_size = cdf.size();
@@ -368,7 +368,7 @@ class RangeDecodeOp : public OpKernel {
368368
const int32* cdf_slice = pair.second;
369369
DCHECK_LE(cdf_slice + chip_size, cdf.data() + cdf_size);
370370

371-
*data = decoder.Decode({cdf_slice, chip_size});
371+
*data = decoder.Decode({cdf_slice, chip_size}, precision_);
372372
}
373373
return tensorflow::Status::OK();
374374
}

cc/kernels/range_coding_kernels_test.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -469,10 +469,10 @@ TEST_F(RangeCoderOpsTest, EncoderDebug) {
469469
}
470470

471471
TEST_F(RangeCoderOpsTest, DecoderDebug) {
472-
RangeEncoder encoder(5);
472+
RangeEncoder encoder;
473473

474474
string encoded_string;
475-
encoder.Encode(16, 18, &encoded_string);
475+
encoder.Encode(16, 18, 5, &encoded_string);
476476
encoder.Finalize(&encoded_string);
477477

478478
Tensor encoded(DT_STRING, {});

0 commit comments

Comments
 (0)