Skip to content

Commit 08578b4

Browse files
nicolemitchellcopybara-github
authored and committed
Adds run-length gamma encode and decode ops.
PiperOrigin-RevId: 431955768 Change-Id: I3cc6f58e85d2882e55493dc6f0142d2f787f7a1d
1 parent cbe6ce9 commit 08578b4

File tree

6 files changed

+757
-0
lines changed

6 files changed

+757
-0
lines changed
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
/* Copyright 2022 Google LLC. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
#define EIGEN_USE_THREADS
16+
17+
#include <algorithm>
18+
#include <array>
19+
#include <cmath>
20+
#include <cstdint>
21+
#include <cstring>
22+
#include <limits>
23+
#include <type_traits>
24+
#include <vector>
25+
26+
#include "absl/types/span.h"
27+
#include "tensorflow/core/framework/op_kernel.h"
28+
#include "tensorflow/core/framework/tensor.h"
29+
#include "tensorflow/core/framework/tensor_shape.h"
30+
#include "tensorflow/core/framework/tensor_types.h"
31+
#include "tensorflow/core/lib/core/errors.h"
32+
#include "tensorflow/core/lib/core/status.h"
33+
#include "tensorflow/core/platform/logging.h"
34+
#include "tensorflow/core/platform/macros.h"
35+
#include "tensorflow/core/platform/types.h"
36+
#include "tensorflow_compression/cc/lib/bit_coder.h"
37+
38+
namespace tensorflow_compression {
39+
namespace {
40+
namespace errors = tensorflow::errors;
41+
using tensorflow::DEVICE_CPU;
42+
using tensorflow::OpKernel;
43+
using tensorflow::OpKernelConstruction;
44+
using tensorflow::OpKernelContext;
45+
using tensorflow::string;
46+
using tensorflow::Tensor;
47+
using tensorflow::TensorShape;
48+
using tensorflow::TensorShapeUtils;
49+
using tensorflow::tstring;
50+
51+
class RunLengthGammaEncodeOp : public OpKernel {
52+
public:
53+
explicit RunLengthGammaEncodeOp(OpKernelConstruction* context)
54+
: OpKernel(context) {}
55+
56+
void Compute(OpKernelContext* context) override {
57+
const Tensor& data_tensor = context->input(0);
58+
auto data = data_tensor.flat<int32_t>();
59+
60+
Tensor* code_tensor;
61+
OP_REQUIRES_OK(context,
62+
context->allocate_output(0, TensorShape{}, &code_tensor));
63+
tstring* code = &code_tensor->scalar<tstring>()();
64+
65+
// Initialize bit encoder and ensure it allocates more than enough bits.
66+
// The maximum code length is achieved when there are no zeros in the input
67+
// array. The encoded size of each value is 2 + kMaxGammaBits (1 bit for
68+
// no leading zeros, 1 bit for sign and kMaxGammaBits for magnitude). If
69+
// any zeros were present in the input array, then the encoded size would be
70+
// strictly smaller by kMaxGammaBits and bigger by the difference in
71+
// encoding (the existing zero run length + 1).
72+
BitWriter enc;
73+
enc.Allocate(data.size() * (2 + enc.kMaxGammaBits));
74+
// Save number of zeros + 1 preceding next non-zero element.
75+
uint32_t zero_ct = 1;
76+
77+
// Iterate through data tensor.
78+
for (size_t i = 0; i < data.size(); i++) {
79+
// Increment zero count.
80+
if (data(i) == 0) {
81+
zero_ct += 1;
82+
} else {
83+
// Encode run length of zeros.
84+
enc.WriteGamma(zero_ct);
85+
// Encode sign of value.
86+
enc.WriteOneBit(data(i) > 0);
87+
// Encode magnitude of value.
88+
DCHECK_NE(data(i), std::numeric_limits<int32_t>::min());
89+
enc.WriteGamma(std::abs(data(i)));
90+
// Reset zero count (1 because Gamma cannot encode 0).
91+
zero_ct = 1;
92+
}
93+
}
94+
if (zero_ct > 1) {
95+
enc.WriteGamma(zero_ct);
96+
}
97+
98+
// Pad any remaining bits in last byte with 0.
99+
enc.ZeroPadToByte();
100+
// Write encoded bitstring to code.
101+
code->assign(enc.GetData(), enc.GetBytesWritten());
102+
}
103+
};
104+
105+
REGISTER_KERNEL_BUILDER(Name("RunLengthGammaEncode").Device(DEVICE_CPU),
106+
RunLengthGammaEncodeOp);
107+
108+
class RunLengthGammaDecodeOp : public OpKernel {
109+
public:
110+
explicit RunLengthGammaDecodeOp(OpKernelConstruction* context)
111+
: OpKernel(context) {}
112+
113+
void Compute(OpKernelContext* context) override {
114+
const Tensor& code_tensor = context->input(0);
115+
const Tensor& shape_tensor = context->input(1);
116+
117+
OP_REQUIRES(
118+
context, TensorShapeUtils::IsScalar(code_tensor.shape()),
119+
errors::InvalidArgument("Invalid `code` shape: ", code_tensor.shape()));
120+
OP_REQUIRES(context, TensorShapeUtils::IsVector(shape_tensor.shape()),
121+
errors::InvalidArgument("Invalid `shape` shape: ",
122+
shape_tensor.shape()));
123+
124+
const tstring& code = code_tensor.scalar<tstring>()();
125+
126+
TensorShape data_shape;
127+
OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
128+
shape_tensor.vec<int32_t>(), &data_shape));
129+
Tensor* data_tensor;
130+
OP_REQUIRES_OK(context,
131+
context->allocate_output(0, data_shape, &data_tensor));
132+
auto data = data_tensor->flat<int32_t>();
133+
134+
// Initialize bit decoder to point at the code and expect code size bytes.
135+
BitReader dec(code);
136+
137+
// Fill data tensor with zeros.
138+
std::memset(data.data(), 0, data.size() * sizeof(data(0)));
139+
140+
for (size_t i = 0; i < data.size(); i++) {
141+
// Get number of zeros.
142+
uint32_t num_zeros = dec.ReadGamma();
143+
// Advance the index to the next non-zero element.
144+
i += num_zeros - 1;
145+
146+
// Account for case where the last element is zero.
147+
if (i == data.size()) {
148+
break;
149+
}
150+
// TODO(nicolemitchell): return error status instead of crashing
151+
DCHECK_LT(i, data.size());
152+
153+
// Get sign of value.
154+
uint32_t positive = dec.ReadOneBit();
155+
156+
// Get value.
157+
uint32_t value = dec.ReadGamma();
158+
159+
// Write value to data tensor element at index.
160+
DCHECK_LE(value, std::numeric_limits<int32_t>::max());
161+
data(i) = positive ? value : -static_cast<int32_t>(value);
162+
}
163+
164+
OP_REQUIRES(context, dec.Close().ok(),
165+
tensorflow::errors::DataLoss("Decoding error."));
166+
}
167+
};
168+
169+
REGISTER_KERNEL_BUILDER(Name("RunLengthGammaDecode").Device(DEVICE_CPU),
170+
RunLengthGammaDecodeOp);
171+
172+
} // namespace
173+
} // namespace tensorflow_compression

0 commit comments

Comments
 (0)