Skip to content

Commit d9dbba5

Browse files
author
jballe
committed
Creates unbounded index range coder.
PiperOrigin-RevId: 241829940
1 parent aa565b9 commit d9dbba5

File tree

3 files changed

+805
-0
lines changed

3 files changed

+805
-0
lines changed
Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,298 @@
1+
/* Copyright 2019 Google LLC. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#define EIGEN_USE_THREADS
17+
18+
#include <algorithm>
19+
#include <array>
20+
#include <limits>
21+
#include <type_traits>
22+
#include <vector>
23+
24+
#include "tensorflow/core/framework/op_kernel.h"
25+
#include "tensorflow/core/framework/tensor.h"
26+
#include "tensorflow/core/framework/tensor_shape.h"
27+
#include "tensorflow/core/framework/tensor_types.h"
28+
#include "tensorflow/core/lib/core/errors.h"
29+
#include "tensorflow/core/lib/core/status.h"
30+
#include "tensorflow/core/lib/gtl/array_slice.h"
31+
#include "tensorflow/core/platform/logging.h"
32+
#include "tensorflow/core/platform/macros.h"
33+
#include "tensorflow/core/platform/types.h"
34+
35+
#include "tensorflow_compression/cc/kernels/range_coder.h"
36+
37+
namespace tensorflow_compression {
38+
namespace {
39+
namespace errors = tensorflow::errors;
40+
namespace gtl = tensorflow::gtl;
41+
using tensorflow::DEVICE_CPU;
42+
using tensorflow::OpKernel;
43+
using tensorflow::OpKernelConstruction;
44+
using tensorflow::OpKernelContext;
45+
using tensorflow::Status;
46+
using tensorflow::Tensor;
47+
using tensorflow::TensorShape;
48+
using tensorflow::TensorShapeUtils;
49+
using tensorflow::TTypes;
50+
51+
// Non-incremental encoder op -------------------------------------------------
52+
class UnboundedIndexRangeEncodeOp : public OpKernel {
53+
public:
54+
explicit UnboundedIndexRangeEncodeOp(OpKernelConstruction* context)
55+
: OpKernel(context) {
56+
OP_REQUIRES_OK(context, context->GetAttr("precision", &precision_));
57+
OP_REQUIRES_OK(context,
58+
context->GetAttr("overflow_width", &overflow_width_));
59+
OP_REQUIRES(context, 0 < precision_ && precision_ <= 16,
60+
errors::InvalidArgument("`precision` must be in [1, 16]: ",
61+
precision_));
62+
OP_REQUIRES(
63+
context, 0 < overflow_width_ && overflow_width_ <= precision_,
64+
errors::InvalidArgument("`overflow_width` must be in [1, precision]: ",
65+
overflow_width_));
66+
}
67+
68+
void Compute(OpKernelContext* context) override {
69+
const Tensor& data = context->input(0);
70+
const Tensor& index = context->input(1);
71+
const Tensor& cdf = context->input(2);
72+
const Tensor& cdf_size = context->input(3);
73+
const Tensor& offset = context->input(4);
74+
75+
OP_REQUIRES(context, data.shape() == index.shape(),
76+
errors::InvalidArgument(
77+
"`data` and `index` should have the same shape"));
78+
79+
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(cdf.shape()),
80+
errors::InvalidArgument("`cdf` should be 2-D."));
81+
OP_REQUIRES(
82+
context,
83+
TensorShapeUtils::IsVector(cdf_size.shape()) &&
84+
cdf_size.dim_size(0) == cdf.dim_size(0),
85+
errors::InvalidArgument("`cdf_size` should be 1-D and its length "
86+
"should match the number of rows in `cdf`."));
87+
OP_REQUIRES(
88+
context,
89+
TensorShapeUtils::IsVector(offset.shape()) &&
90+
offset.dim_size(0) == cdf.dim_size(0),
91+
errors::InvalidArgument("`offset` should be 1-D and its length "
92+
"should match the number of rows in `cdf`."));
93+
94+
Tensor* output;
95+
OP_REQUIRES_OK(context,
96+
context->allocate_output(0, TensorShape{}, &output));
97+
98+
RangeEncodeImpl(data.flat<int32>(), index.flat<int32>(),
99+
cdf.matrix<int32>(), cdf_size.vec<int32>(),
100+
offset.vec<int32>(), &output->flat<string>()(0));
101+
}
102+
103+
private:
104+
void RangeEncodeImpl(TTypes<int32>::ConstFlat data,
105+
TTypes<int32>::ConstFlat index,
106+
TTypes<int32>::ConstMatrix cdf,
107+
TTypes<int32>::ConstVec cdf_size,
108+
TTypes<int32>::ConstVec offset, string* output) const {
109+
RangeEncoder encoder{precision_};
110+
111+
DCHECK_GE(cdf.dimension(1), 2);
112+
DCHECK_LE(cdf.dimension(1), std::numeric_limits<int16>::max());
113+
DCHECK_EQ(cdf.dimension(0), cdf_size.size());
114+
115+
const uint32 max_overflow = (1 << overflow_width_) - 1;
116+
const uint32 overflow_shift = precision_ - overflow_width_;
117+
118+
const int64 data_size = data.size();
119+
for (int64 i = 0; i < data_size; ++i) {
120+
const int32 cdf_index = index(i);
121+
122+
DCHECK_GE(cdf_index, 0);
123+
DCHECK_LT(cdf_index, cdf.dimension(0));
124+
125+
const int32 max_value = cdf_size(cdf_index) - 2;
126+
DCHECK_GE(max_value, 0);
127+
DCHECK_LT(max_value + 1, cdf.dimension(1));
128+
129+
int32 value = data(i);
130+
// Map values with tracked probabilities to 0..max_value range.
131+
value -= offset(cdf_index);
132+
// If outside of this range, map value to non-negative integer overflow.
133+
uint32 overflow;
134+
if (value < 0) {
135+
overflow = -2 * value - 1;
136+
value = max_value;
137+
} else if (value >= max_value) {
138+
overflow = 2 * (value - max_value);
139+
value = max_value;
140+
}
141+
142+
const int32* cdf_slice = &cdf(cdf_index, 0);
143+
encoder.Encode(cdf_slice[value], cdf_slice[value + 1], output);
144+
145+
// Encode overflow using variable length code.
146+
if (value == max_value) {
147+
int32 widths = 0;
148+
while (overflow >> (widths * overflow_width_)) {
149+
++widths;
150+
}
151+
uint32 val = widths;
152+
while (val >= max_overflow) {
153+
encoder.Encode(max_overflow << overflow_shift,
154+
(max_overflow + 1) << overflow_shift, output);
155+
val -= max_overflow;
156+
}
157+
encoder.Encode(val << overflow_shift, (val + 1) << overflow_shift,
158+
output);
159+
for (int32 j = 0; j < widths; ++j) {
160+
const uint32 val = (overflow >> (j * overflow_width_)) & max_overflow;
161+
encoder.Encode(val << overflow_shift, (val + 1) << overflow_shift,
162+
output);
163+
}
164+
}
165+
}
166+
encoder.Finalize(output);
167+
}
168+
169+
int precision_;
170+
int overflow_width_;
171+
};
172+
173+
REGISTER_KERNEL_BUILDER(Name("UnboundedIndexRangeEncode").Device(DEVICE_CPU),
174+
UnboundedIndexRangeEncodeOp);
175+
176+
// Non-incremental decoder op -------------------------------------------------
177+
class UnboundedIndexRangeDecodeOp : public OpKernel {
178+
public:
179+
explicit UnboundedIndexRangeDecodeOp(OpKernelConstruction* context)
180+
: OpKernel(context) {
181+
OP_REQUIRES_OK(context, context->GetAttr("precision", &precision_));
182+
OP_REQUIRES_OK(context,
183+
context->GetAttr("overflow_width", &overflow_width_));
184+
OP_REQUIRES(context, 0 < precision_ && precision_ <= 16,
185+
errors::InvalidArgument("`precision` must be in [1, 16]: ",
186+
precision_));
187+
}
188+
189+
void Compute(OpKernelContext* context) override {
190+
const Tensor& encoded = context->input(0);
191+
const Tensor& index = context->input(1);
192+
const Tensor& cdf = context->input(2);
193+
const Tensor& cdf_size = context->input(3);
194+
const Tensor& offset = context->input(4);
195+
196+
OP_REQUIRES(context, encoded.shape() == TensorShape{},
197+
errors::InvalidArgument("Invalid `encoded` shape: ",
198+
encoded.shape().DebugString()));
199+
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(cdf.shape()),
200+
errors::InvalidArgument("`cdf` should be 2-D."));
201+
OP_REQUIRES(
202+
context,
203+
TensorShapeUtils::IsVector(cdf_size.shape()) &&
204+
cdf_size.dim_size(0) == cdf.dim_size(0),
205+
errors::InvalidArgument("`cdf_size` should be 1-D and its length "
206+
"should match the number of rows in `cdf`."));
207+
OP_REQUIRES(
208+
context,
209+
TensorShapeUtils::IsVector(offset.shape()) &&
210+
offset.dim_size(0) == cdf.dim_size(0),
211+
errors::InvalidArgument("`offset` should be 1-D and its length "
212+
"should match the number of rows in `cdf`."));
213+
214+
Tensor* output;
215+
OP_REQUIRES_OK(context,
216+
context->allocate_output(0, index.shape(), &output));
217+
218+
OP_REQUIRES_OK(
219+
context, RangeDecodeImpl(output->flat<int32>(), index.flat<int32>(),
220+
cdf.matrix<int32>(), cdf_size.vec<int32>(),
221+
offset.vec<int32>(), encoded.flat<string>()));
222+
}
223+
224+
private:
225+
tensorflow::Status RangeDecodeImpl(TTypes<int32>::Flat output,
226+
TTypes<int32>::ConstFlat index,
227+
TTypes<int32>::ConstMatrix cdf,
228+
TTypes<int32>::ConstVec cdf_size,
229+
TTypes<int32>::ConstVec offset,
230+
TTypes<string>::ConstFlat encoded) const {
231+
RangeDecoder decoder{encoded(0), precision_};
232+
233+
DCHECK_GE(cdf.dimension(1), 2);
234+
DCHECK_LE(cdf.dimension(1), std::numeric_limits<int16>::max());
235+
236+
const uint32 max_overflow = (1 << overflow_width_) - 1;
237+
const int32 overflow_cdf_size = (1 << overflow_width_) + 1;
238+
std::vector<int32> overflow_cdf(overflow_cdf_size);
239+
for (int32 i = 0; i < overflow_cdf_size; ++i) {
240+
overflow_cdf[i] = i << (precision_ - overflow_width_);
241+
}
242+
243+
const int64 output_size = output.size();
244+
for (int64 i = 0; i < output_size; ++i) {
245+
const int32 cdf_index = index(i);
246+
247+
DCHECK_GE(cdf_index, 0);
248+
DCHECK_LT(cdf_index, cdf.dimension(0));
249+
250+
const int32 max_value = cdf_size(cdf_index) - 2;
251+
DCHECK_GE(max_value, 0);
252+
DCHECK_LT(max_value + 1, cdf.dimension(1));
253+
254+
const int32* cdf_slice = &cdf(cdf_index, 0);
255+
int32 value =
256+
decoder.Decode(gtl::ArraySlice<int32>(cdf_slice, max_value + 2));
257+
258+
// Decode overflow using variable length code.
259+
if (value == max_value) {
260+
int32 widths = 0;
261+
uint32 val;
262+
do {
263+
val = decoder.Decode(overflow_cdf);
264+
widths += val;
265+
} while (val == max_overflow);
266+
uint32 overflow = 0;
267+
for (int32 j = 0; j < widths; ++j) {
268+
const uint32 val = decoder.Decode(overflow_cdf);
269+
DCHECK_LE(val, max_overflow);
270+
overflow |= val << (j * overflow_width_);
271+
}
272+
// Map positive values back to integer values.
273+
value = overflow >> 1;
274+
if (overflow & 1) {
275+
value = -value - 1;
276+
} else {
277+
value += max_value;
278+
}
279+
}
280+
281+
// Map values in 0..max_range range back to original integer range.
282+
value += offset(cdf_index);
283+
284+
output(i) = value;
285+
}
286+
287+
return tensorflow::Status::OK();
288+
}
289+
290+
int precision_;
291+
int overflow_width_;
292+
};
293+
294+
REGISTER_KERNEL_BUILDER(Name("UnboundedIndexRangeDecode").Device(DEVICE_CPU),
295+
UnboundedIndexRangeDecodeOp);
296+
297+
} // namespace
298+
} // namespace tensorflow_compression

0 commit comments

Comments
 (0)