Skip to content

Commit c4c8622

Browse files
Johannes Ballé and copybara-github
authored and committed
Adds Rice code and more general run length coding ops.
PiperOrigin-RevId: 484615653 Change-Id: I769d2cac08fa2afe16c5d59f1f6b7df574040dce
1 parent cf8bbca commit c4c8622

File tree

9 files changed

+732
-38
lines changed

9 files changed

+732
-38
lines changed

tensorflow_compression/cc/kernels/run_length_gamma_kernels.cc

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,7 @@ class RunLengthGammaEncodeOp : public OpKernel {
6767
context->allocate_output(0, TensorShape{}, &code_tensor));
6868
tstring* code = &code_tensor->scalar<tstring>()();
6969

70-
// Initialize bit encoder and ensure it allocates more than enough bits.
71-
// The maximum code length is achieved when there are no zeros in the input
72-
// array. The encoded size of each value is 2 + kMaxGammaBits (1 bit for
73-
// no leading zeros, 1 bit for sign and kMaxGammaBits for magnitude). If
74-
// any zeros were present in the input array, then the encoded size would be
75-
// strictly smaller by kMaxGammaBits and bigger by the difference in
76-
// encoding (the existing zero run length + 1).
77-
BitWriter enc(data.size() * (2 + enc.kMaxGammaBits));
70+
BitWriter enc;
7871
// Save number of zeros + 1 preceding next non-zero element.
7972
uint32_t zero_ct = 1;
8073

tensorflow_compression/cc/kernels/run_length_gamma_kernels_test.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ TEST_F(BitCodingOpsTest, ManualEncodeWithBitcodingLibrary) {
194194
TF_ASSERT_OK(RunEncodeOp({data_tensor}, &code_tensor));
195195

196196
// Use bitcoding library to encode data.
197-
BitWriter enc_ = BitWriter(16);
197+
BitWriter enc_;
198198
enc_.WriteGamma(2); // one zero
199199
enc_.WriteOneBit(0); // negative
200200
enc_.WriteGamma(3); // 3
@@ -212,7 +212,7 @@ TEST_F(BitCodingOpsTest, ManualEncodeWithBitcodingLibrary) {
212212

213213
TEST_F(BitCodingOpsTest, ManualDecodeWithBitcodingLibrary) {
214214
// Use bitcoding library to manually encode [-3, 1, 0, 0] into code.
215-
BitWriter enc_ = BitWriter(16);
215+
BitWriter enc_;
216216
enc_.WriteGamma(1); // no zeros
217217
enc_.WriteOneBit(0); // negative
218218
enc_.WriteGamma(3); // 3
Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
/* Copyright 2022 Google LLC. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
#define EIGEN_USE_THREADS
16+
17+
#include <algorithm>
18+
#include <array>
19+
#include <cassert>
20+
#include <cmath>
21+
#include <cstdint>
22+
#include <cstring>
23+
#include <limits>
24+
#include <type_traits>
25+
#include <vector>
26+
27+
#include "absl/types/span.h"
28+
#include "tensorflow/core/framework/op_kernel.h"
29+
#include "tensorflow/core/framework/tensor.h"
30+
#include "tensorflow/core/framework/tensor_shape.h"
31+
#include "tensorflow/core/framework/tensor_types.h"
32+
#include "tensorflow/core/platform/status.h"
33+
#include "tensorflow/core/platform/types.h"
34+
#include "tensorflow_compression/cc/lib/bit_coder.h"
35+
36+
namespace tensorflow_compression {
37+
namespace {
38+
namespace errors = tensorflow::errors;
39+
using tensorflow::DEVICE_CPU;
40+
using tensorflow::OpKernel;
41+
using tensorflow::OpKernelConstruction;
42+
using tensorflow::OpKernelContext;
43+
using tensorflow::string;
44+
using tensorflow::Tensor;
45+
using tensorflow::TensorShape;
46+
using tensorflow::TensorShapeUtils;
47+
using tensorflow::tstring;
48+
49+
// Like OP_REQUIRES_OK, but for a status object that exposes ok(), code() and
// message() (e.g. the status of an absl::StatusOr). On failure the status is
// converted to a tensorflow::Status and handed to OP_REQUIRES.
//
// The body is wrapped in do { ... } while (false) so that the macro followed
// by a semicolon behaves as exactly one statement (safe inside unbraced
// if/else). The local has a long, macro-specific name so that a `status`
// argument mentioning a short identifier such as `s` is not captured by it.
#define OP_REQUIRES_OK_ABSL(context, status)                               \
  do {                                                                     \
    const auto op_requires_ok_absl_status = (status);                      \
    OP_REQUIRES(context, op_requires_ok_absl_status.ok(),                  \
                tensorflow::Status(static_cast<tensorflow::error::Code>(   \
                                       op_requires_ok_absl_status.code()), \
                                   op_requires_ok_absl_status.message())); \
  } while (false)
55+
56+
// TODO(jonycgn): Try to avoid in-loop branches based on attributes.
57+
58+
class RunLengthEncodeOp : public OpKernel {
59+
public:
60+
explicit RunLengthEncodeOp(OpKernelConstruction* context)
61+
: OpKernel(context) {
62+
OP_REQUIRES_OK(context,
63+
context->GetAttr("run_length_code", &run_length_code_));
64+
OP_REQUIRES_OK(context,
65+
context->GetAttr("magnitude_code", &magnitude_code_));
66+
OP_REQUIRES_OK(context,
67+
context->GetAttr("use_run_length_for_non_zeros",
68+
&use_run_length_for_non_zeros_));
69+
}
70+
71+
inline void WriteRunLength(BitWriter& enc, const int32_t run_length) {
72+
if (run_length_code_ >= 0) {
73+
enc.WriteRice(run_length, run_length_code_);
74+
} else {
75+
enc.WriteGamma(run_length + 1);
76+
}
77+
}
78+
79+
inline void WriteNonZero(BitWriter& enc, const int32_t sample) {
80+
assert(sample != 0);
81+
const int32_t sign = sample > 0;
82+
enc.WriteOneBit(sign);
83+
if (magnitude_code_ >= 0) {
84+
enc.WriteRice(sign ? sample - 1 : -(sample + 1),
85+
magnitude_code_);
86+
} else {
87+
if (sample == std::numeric_limits<int32_t>::min()) {
88+
// We can't encode int32 minimum. Encode closest value instead.
89+
enc.WriteGamma(-(std::numeric_limits<int32_t>::min() + 1));
90+
} else {
91+
enc.WriteGamma(sign ? sample : -sample);
92+
}
93+
}
94+
}
95+
96+
void Compute(OpKernelContext* context) override {
97+
const Tensor& data_tensor = context->input(0);
98+
auto data = data_tensor.flat<int32_t>();
99+
100+
Tensor* code_tensor;
101+
OP_REQUIRES_OK(context,
102+
context->allocate_output(0, TensorShape{}, &code_tensor));
103+
tstring* code = &code_tensor->scalar<tstring>()();
104+
105+
BitWriter enc;
106+
107+
const int32_t* const end = data.data() + data.size();
108+
const int32_t* p = data.data();
109+
110+
while (p < end) {
111+
// Find next non-zero.
112+
const int32_t* q = std::find_if_not(p, end,
113+
[](int32_t x) { return x == 0; });
114+
WriteRunLength(enc, q - p);
115+
p = q;
116+
117+
if (!(p < end)) break;
118+
119+
if (use_run_length_for_non_zeros_) {
120+
// Find next zero.
121+
q = std::find_if(p, end, [](int32_t x) { return x == 0; });
122+
WriteRunLength(enc, q - p);
123+
while (p < q) {
124+
WriteNonZero(enc, *p++);
125+
}
126+
} else {
127+
WriteNonZero(enc, *p++);
128+
}
129+
}
130+
131+
// Write encoded bitstring to code.
132+
auto encoded = enc.GetData();
133+
code->assign(encoded.data(), encoded.size());
134+
}
135+
136+
private:
137+
int run_length_code_;
138+
int magnitude_code_;
139+
bool use_run_length_for_non_zeros_;
140+
};
141+
142+
// Registers RunLengthEncodeOp as the CPU kernel of the "RunLengthEncode" op.
REGISTER_KERNEL_BUILDER(
    Name("RunLengthEncode").Device(DEVICE_CPU), RunLengthEncodeOp);
144+
145+
class RunLengthDecodeOp : public OpKernel {
146+
public:
147+
explicit RunLengthDecodeOp(OpKernelConstruction* context)
148+
: OpKernel(context) {
149+
OP_REQUIRES_OK(context,
150+
context->GetAttr("run_length_code", &run_length_code_));
151+
OP_REQUIRES_OK(context,
152+
context->GetAttr("magnitude_code", &magnitude_code_));
153+
OP_REQUIRES_OK(context,
154+
context->GetAttr("use_run_length_for_non_zeros",
155+
&use_run_length_for_non_zeros_));
156+
}
157+
158+
inline absl::StatusOr<int32_t> ReadRunLength(OpKernelContext* context,
159+
BitReader& dec) {
160+
if (run_length_code_ >= 0) {
161+
return dec.ReadRice(run_length_code_);
162+
} else {
163+
auto gamma = dec.ReadGamma();
164+
if (!gamma.ok()) return gamma;
165+
return *gamma - 1;
166+
}
167+
}
168+
169+
inline absl::StatusOr<int32_t> ReadNonZero(OpKernelContext* context,
170+
BitReader& dec) {
171+
auto positive = dec.ReadOneBit();
172+
if (!positive.ok()) return positive;
173+
if (magnitude_code_ >= 0) {
174+
auto rice = dec.ReadRice(magnitude_code_);
175+
if (!rice.ok()) return rice;
176+
return *positive ? *rice + 1 : -*rice - 1;
177+
} else {
178+
auto gamma = dec.ReadGamma();
179+
if (!gamma.ok()) return gamma;
180+
return *positive ? *gamma : -*gamma;
181+
}
182+
}
183+
184+
void Compute(OpKernelContext* context) override {
185+
const Tensor& code_tensor = context->input(0);
186+
const Tensor& shape_tensor = context->input(1);
187+
188+
OP_REQUIRES(
189+
context, TensorShapeUtils::IsScalar(code_tensor.shape()),
190+
errors::InvalidArgument("Invalid `code` shape: ", code_tensor.shape()));
191+
OP_REQUIRES(context, TensorShapeUtils::IsVector(shape_tensor.shape()),
192+
errors::InvalidArgument("Invalid `shape` shape: ",
193+
shape_tensor.shape()));
194+
195+
const tstring& code = code_tensor.scalar<tstring>()();
196+
197+
TensorShape data_shape;
198+
OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
199+
shape_tensor.vec<int32_t>(), &data_shape));
200+
Tensor* data_tensor;
201+
OP_REQUIRES_OK(context,
202+
context->allocate_output(0, data_shape, &data_tensor));
203+
auto data = data_tensor->flat<int32_t>();
204+
205+
// Initialize bit decoder to point at the code.
206+
BitReader dec(code);
207+
208+
// Fill data tensor with zeros.
209+
std::memset(data.data(), 0, data.size() * sizeof(data(0)));
210+
211+
int32_t* const end = data.data() + data.size();
212+
int32_t* p = data.data();
213+
214+
while (p < end) {
215+
// Skip to the next non-zero element.
216+
auto run_length = ReadRunLength(context, dec);
217+
OP_REQUIRES_OK_ABSL(context, run_length.status());
218+
219+
p += *run_length;
220+
221+
if (!(p < end)) {
222+
// Should not be past the last element.
223+
OP_REQUIRES(context, p == end,
224+
errors::DataLoss("Decoded past end of tensor."));
225+
break;
226+
}
227+
228+
if (use_run_length_for_non_zeros_) {
229+
run_length = ReadRunLength(context, dec);
230+
OP_REQUIRES_OK_ABSL(context, run_length.status());
231+
const int32_t* const next_zero = p + *run_length;
232+
OP_REQUIRES(context, next_zero <= end,
233+
errors::DataLoss("Decoded past end of tensor."));
234+
while (p < next_zero) {
235+
auto nonzero = ReadNonZero(context, dec);
236+
OP_REQUIRES_OK_ABSL(context, nonzero.status());
237+
*p++ = *nonzero;
238+
}
239+
} else {
240+
auto nonzero = ReadNonZero(context, dec);
241+
OP_REQUIRES_OK_ABSL(context, nonzero.status());
242+
*p++ = *nonzero;
243+
}
244+
}
245+
}
246+
247+
private:
248+
int run_length_code_;
249+
int magnitude_code_;
250+
bool use_run_length_for_non_zeros_;
251+
};
252+
253+
// Registers RunLengthDecodeOp as the CPU kernel of the "RunLengthDecode" op.
REGISTER_KERNEL_BUILDER(
    Name("RunLengthDecode").Device(DEVICE_CPU), RunLengthDecodeOp);
255+
256+
#undef OP_REQUIRES_OK_ABSL
257+
258+
} // namespace
259+
} // namespace tensorflow_compression

0 commit comments

Comments
 (0)