Skip to content

Commit 3b70187

Browse files
ssjhvcopybara-github
authored and committed
Improved CreateRangeEncoder/Decoder ops to reduce CDF scanning when there is
a large number of batch items. PiperOrigin-RevId: 448377181 Change-Id: Icc440c0af2825b2c9ef81b7209685b498fdfff2a
1 parent 7bd5363 commit 3b70187

File tree

1 file changed

+51
-61
lines changed

1 file changed

+51
-61
lines changed

tensorflow_compression/cc/kernels/range_coder_kernels.cc

Lines changed: 51 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616
#include "tensorflow_compression/cc/kernels/range_coder_kernels.h"
1717

1818
#include <cstdint>
19+
#include <memory>
1920
#include <string>
2021
#include <utility>
2122

@@ -114,24 +115,9 @@ Status IndexCDFMatrix(const TTypes<int32_t>::ConstMatrix& table,
114115

115116
class RangeEncoderInterface final : public EntropyEncoderInterface {
116117
public:
117-
static Status MakeShared(const Tensor lookup,
118-
std::shared_ptr<EntropyEncoderInterface>* ptr) {
119-
Status status;
120-
RangeEncoderInterface* re = new RangeEncoderInterface(lookup);
121-
if (lookup.dims() == 1) {
122-
status = IndexCDFVector(lookup.flat<int32_t>(), &re->lookup_);
123-
} else if (lookup.dims() == 2) {
124-
status = IndexCDFMatrix(lookup.matrix<int32_t>(), &re->lookup_);
125-
} else {
126-
status = errors::InvalidArgument("`lookup` must be rank 1 or 2.");
127-
}
128-
if (status.ok()) {
129-
ptr->reset(re);
130-
} else {
131-
delete re;
132-
}
133-
return status;
134-
}
118+
RangeEncoderInterface(absl::Span<const absl::Span<const int32_t>> lookup,
119+
Tensor hold)
120+
: lookup_(lookup.begin(), lookup.end()), hold_(std::move(hold)) {}
135121

136122
Status Encode(int32_t index, int32_t value) override {
137123
TF_RETURN_IF_ERROR(CheckInRange("index", index, 0, lookup_.size()));
@@ -153,8 +139,6 @@ class RangeEncoderInterface final : public EntropyEncoderInterface {
153139
}
154140

155141
private:
156-
explicit RangeEncoderInterface(Tensor lookup) : hold_(std::move(lookup)) {}
157-
158142
void OverflowEncode(const absl::Span<const int32_t> row, int32_t value) {
159143
const int32_t max_value = row.size() - 3;
160144
const int32_t sign = value < 0;
@@ -193,24 +177,12 @@ class RangeEncoderInterface final : public EntropyEncoderInterface {
193177

194178
class RangeDecoderInterface final : public EntropyDecoderInterface {
195179
public:
196-
static Status MakeShared(absl::string_view encoded, const Tensor lookup,
197-
std::shared_ptr<EntropyDecoderInterface>* ptr) {
198-
Status status;
199-
RangeDecoderInterface* rd = new RangeDecoderInterface(encoded, lookup);
200-
if (lookup.dims() == 1) {
201-
status = IndexCDFVector(lookup.flat<int32_t>(), &rd->lookup_);
202-
} else if (lookup.dims() == 2) {
203-
status = IndexCDFMatrix(lookup.matrix<int32_t>(), &rd->lookup_);
204-
} else {
205-
status = errors::InvalidArgument("`lookup` must be rank 1 or 2.");
206-
}
207-
if (status.ok()) {
208-
ptr->reset(rd);
209-
} else {
210-
delete rd;
211-
}
212-
return status;
213-
}
180+
RangeDecoderInterface(absl::string_view encoded,
181+
absl::Span<const absl::Span<const int32_t>> lookup,
182+
Tensor hold)
183+
: lookup_(lookup.begin(), lookup.end()),
184+
decoder_(encoded),
185+
hold_(std::move(hold)) {}
214186

215187
Status Decode(int32_t index, int32_t* output) override {
216188
TF_RETURN_IF_ERROR(CheckInRange("index", index, 0, lookup_.size()));
@@ -232,9 +204,6 @@ class RangeDecoderInterface final : public EntropyDecoderInterface {
232204
}
233205

234206
private:
235-
RangeDecoderInterface(absl::string_view encoded, Tensor lookup)
236-
: decoder_(encoded), hold_(std::move(lookup)) {}
237-
238207
int32_t OverflowDecode(const absl::Span<const int32_t> row) {
239208
constexpr int32_t binary_uniform_cdf[] = {0, 1, 2};
240209
const int32_t max_value = row.size() - 3;
@@ -313,11 +282,21 @@ class CreateRangeEncoderOp : public tensorflow::OpKernel {
313282
context->allocate_output(0, handle_shape, &output_tensor));
314283

315284
const Tensor& lookup = context->input(1);
285+
OP_REQUIRES(context, lookup.dims() == 1 || lookup.dims() == 2,
286+
errors::InvalidArgument("`lookup` must be rank 1 or 2."));
287+
288+
std::vector<absl::Span<const int32_t>> table;
289+
if (lookup.dims() == 1) {
290+
OP_REQUIRES_OK(context, IndexCDFVector(lookup.flat<int32_t>(), &table));
291+
} else {
292+
DCHECK_EQ(lookup.dims(), 2);
293+
OP_REQUIRES_OK(context, IndexCDFMatrix(lookup.matrix<int32_t>(), &table));
294+
}
295+
316296
auto output = output_tensor->flat<Variant>();
317297
for (int64_t i = 0; i < output.size(); ++i) {
318298
EntropyEncoderVariant wrap;
319-
OP_REQUIRES_OK(context,
320-
RangeEncoderInterface::MakeShared(lookup, &wrap.encoder));
299+
wrap.encoder = std::make_shared<RangeEncoderInterface>(table, lookup);
321300
output(i) = std::move(wrap);
322301
}
323302
}
@@ -388,10 +367,10 @@ class EntropyEncodeChannelOp : public tensorflow::OpKernel {
388367
context->SetStatus(status); \
389368
return; \
390369
}
391-
#define REQUIRES_OK(status) \
392-
{ \
393-
auto s = (status); \
394-
REQUIRES(s.ok(), s); \
370+
#define REQUIRES_OK(status) \
371+
{ \
372+
auto s = (status); \
373+
REQUIRES(s.ok(), s); \
395374
}
396375

397376
const int64_t num_elements = value.dimension(1);
@@ -484,10 +463,10 @@ class EntropyEncodeIndexOp : public tensorflow::OpKernel {
484463
context->SetStatus(status); \
485464
return; \
486465
}
487-
#define REQUIRES_OK(status) \
488-
{ \
489-
auto s = (status); \
490-
REQUIRES(s.ok(), s); \
466+
#define REQUIRES_OK(status) \
467+
{ \
468+
auto s = (status); \
469+
REQUIRES(s.ok(), s); \
491470
}
492471

493472
const int64_t num_elements = value.dimension(1);
@@ -560,11 +539,22 @@ class CreateRangeDecoderOp : public tensorflow::OpKernel {
560539
&output_tensor));
561540

562541
const Tensor& lookup = context->input(1);
542+
OP_REQUIRES(context, lookup.dims() == 1 || lookup.dims() == 2,
543+
errors::InvalidArgument("`lookup` must be rank 1 or 2."));
544+
545+
std::vector<absl::Span<const int32_t>> table;
546+
if (lookup.dims() == 1) {
547+
OP_REQUIRES_OK(context, IndexCDFVector(lookup.flat<int32_t>(), &table));
548+
} else {
549+
DCHECK_EQ(lookup.dims(), 2);
550+
OP_REQUIRES_OK(context, IndexCDFMatrix(lookup.matrix<int32_t>(), &table));
551+
}
552+
563553
auto output = output_tensor->flat<Variant>();
564554
for (int64_t i = 0; i < output.size(); ++i) {
565555
EntropyDecoderVariant wrap;
566-
OP_REQUIRES_OK(context, RangeDecoderInterface::MakeShared(
567-
encoded(i), lookup, &wrap.decoder));
556+
wrap.decoder =
557+
std::make_shared<RangeDecoderInterface>(encoded(i), table, lookup);
568558
wrap.holder = encoded_tensor;
569559
output(i) = std::move(wrap);
570560
}
@@ -636,10 +626,10 @@ class EntropyDecodeChannelOp : public tensorflow::OpKernel {
636626
context->SetStatus(status); \
637627
return; \
638628
}
639-
#define REQUIRES_OK(status) \
640-
{ \
641-
auto s = (status); \
642-
REQUIRES(s.ok(), s); \
629+
#define REQUIRES_OK(status) \
630+
{ \
631+
auto s = (status); \
632+
REQUIRES(s.ok(), s); \
643633
}
644634

645635
const int64_t num_elements = output.dimension(1);
@@ -736,10 +726,10 @@ class EntropyDecodeIndexOp : public tensorflow::OpKernel {
736726
context->SetStatus(status); \
737727
return; \
738728
}
739-
#define REQUIRES_OK(status) \
740-
{ \
741-
auto s = (status); \
742-
REQUIRES(s.ok(), s); \
729+
#define REQUIRES_OK(status) \
730+
{ \
731+
auto s = (status); \
732+
REQUIRES(s.ok(), s); \
743733
}
744734

745735
const int64_t num_elements = output.dimension(1);

0 commit comments

Comments
 (0)