@@ -16,6 +16,7 @@ limitations under the License.
16
16
#include " tensorflow_compression/cc/kernels/range_coder_kernels.h"
17
17
18
18
#include < cstdint>
19
+ #include < memory>
19
20
#include < string>
20
21
#include < utility>
21
22
@@ -114,24 +115,9 @@ Status IndexCDFMatrix(const TTypes<int32_t>::ConstMatrix& table,
114
115
115
116
class RangeEncoderInterface final : public EntropyEncoderInterface {
116
117
public:
117
- static Status MakeShared (const Tensor lookup,
118
- std::shared_ptr<EntropyEncoderInterface>* ptr) {
119
- Status status;
120
- RangeEncoderInterface* re = new RangeEncoderInterface (lookup);
121
- if (lookup.dims () == 1 ) {
122
- status = IndexCDFVector (lookup.flat <int32_t >(), &re->lookup_ );
123
- } else if (lookup.dims () == 2 ) {
124
- status = IndexCDFMatrix (lookup.matrix <int32_t >(), &re->lookup_ );
125
- } else {
126
- status = errors::InvalidArgument (" `lookup` must be rank 1 or 2." );
127
- }
128
- if (status.ok ()) {
129
- ptr->reset (re);
130
- } else {
131
- delete re;
132
- }
133
- return status;
134
- }
118
+ RangeEncoderInterface (absl::Span<const absl::Span<const int32_t >> lookup,
119
+ Tensor hold)
120
+ : lookup_(lookup.begin(), lookup.end()), hold_(std::move(hold)) {}
135
121
136
122
Status Encode (int32_t index, int32_t value) override {
137
123
TF_RETURN_IF_ERROR (CheckInRange (" index" , index, 0 , lookup_.size ()));
@@ -153,8 +139,6 @@ class RangeEncoderInterface final : public EntropyEncoderInterface {
153
139
}
154
140
155
141
private:
156
- explicit RangeEncoderInterface (Tensor lookup) : hold_(std::move(lookup)) {}
157
-
158
142
void OverflowEncode (const absl::Span<const int32_t > row, int32_t value) {
159
143
const int32_t max_value = row.size () - 3 ;
160
144
const int32_t sign = value < 0 ;
@@ -193,24 +177,12 @@ class RangeEncoderInterface final : public EntropyEncoderInterface {
193
177
194
178
class RangeDecoderInterface final : public EntropyDecoderInterface {
195
179
public:
196
- static Status MakeShared (absl::string_view encoded, const Tensor lookup,
197
- std::shared_ptr<EntropyDecoderInterface>* ptr) {
198
- Status status;
199
- RangeDecoderInterface* rd = new RangeDecoderInterface (encoded, lookup);
200
- if (lookup.dims () == 1 ) {
201
- status = IndexCDFVector (lookup.flat <int32_t >(), &rd->lookup_ );
202
- } else if (lookup.dims () == 2 ) {
203
- status = IndexCDFMatrix (lookup.matrix <int32_t >(), &rd->lookup_ );
204
- } else {
205
- status = errors::InvalidArgument (" `lookup` must be rank 1 or 2." );
206
- }
207
- if (status.ok ()) {
208
- ptr->reset (rd);
209
- } else {
210
- delete rd;
211
- }
212
- return status;
213
- }
180
+ RangeDecoderInterface (absl::string_view encoded,
181
+ absl::Span<const absl::Span<const int32_t >> lookup,
182
+ Tensor hold)
183
+ : lookup_(lookup.begin(), lookup.end()),
184
+ decoder_ (encoded),
185
+ hold_(std::move(hold)) {}
214
186
215
187
Status Decode (int32_t index, int32_t * output) override {
216
188
TF_RETURN_IF_ERROR (CheckInRange (" index" , index, 0 , lookup_.size ()));
@@ -232,9 +204,6 @@ class RangeDecoderInterface final : public EntropyDecoderInterface {
232
204
}
233
205
234
206
private:
235
- RangeDecoderInterface (absl::string_view encoded, Tensor lookup)
236
- : decoder_(encoded), hold_(std::move(lookup)) {}
237
-
238
207
int32_t OverflowDecode (const absl::Span<const int32_t > row) {
239
208
constexpr int32_t binary_uniform_cdf[] = {0 , 1 , 2 };
240
209
const int32_t max_value = row.size () - 3 ;
@@ -313,11 +282,21 @@ class CreateRangeEncoderOp : public tensorflow::OpKernel {
313
282
context->allocate_output (0 , handle_shape, &output_tensor));
314
283
315
284
const Tensor& lookup = context->input (1 );
285
+ OP_REQUIRES (context, lookup.dims () == 1 || lookup.dims () == 2 ,
286
+ errors::InvalidArgument (" `lookup` must be rank 1 or 2." ));
287
+
288
+ std::vector<absl::Span<const int32_t >> table;
289
+ if (lookup.dims () == 1 ) {
290
+ OP_REQUIRES_OK (context, IndexCDFVector (lookup.flat <int32_t >(), &table));
291
+ } else {
292
+ DCHECK_EQ (lookup.dims (), 2 );
293
+ OP_REQUIRES_OK (context, IndexCDFMatrix (lookup.matrix <int32_t >(), &table));
294
+ }
295
+
316
296
auto output = output_tensor->flat <Variant>();
317
297
for (int64_t i = 0 ; i < output.size (); ++i) {
318
298
EntropyEncoderVariant wrap;
319
- OP_REQUIRES_OK (context,
320
- RangeEncoderInterface::MakeShared (lookup, &wrap.encoder ));
299
+ wrap.encoder = std::make_shared<RangeEncoderInterface>(table, lookup);
321
300
output (i) = std::move (wrap);
322
301
}
323
302
}
@@ -388,10 +367,10 @@ class EntropyEncodeChannelOp : public tensorflow::OpKernel {
388
367
context->SetStatus (status); \
389
368
return ; \
390
369
}
391
- #define REQUIRES_OK (status ) \
392
- { \
393
- auto s = (status); \
394
- REQUIRES (s.ok (), s); \
370
+ #define REQUIRES_OK (status ) \
371
+ { \
372
+ auto s = (status); \
373
+ REQUIRES (s.ok (), s); \
395
374
}
396
375
397
376
const int64_t num_elements = value.dimension (1 );
@@ -484,10 +463,10 @@ class EntropyEncodeIndexOp : public tensorflow::OpKernel {
484
463
context->SetStatus (status); \
485
464
return ; \
486
465
}
487
- #define REQUIRES_OK (status ) \
488
- { \
489
- auto s = (status); \
490
- REQUIRES (s.ok (), s); \
466
+ #define REQUIRES_OK (status ) \
467
+ { \
468
+ auto s = (status); \
469
+ REQUIRES (s.ok (), s); \
491
470
}
492
471
493
472
const int64_t num_elements = value.dimension (1 );
@@ -560,11 +539,22 @@ class CreateRangeDecoderOp : public tensorflow::OpKernel {
560
539
&output_tensor));
561
540
562
541
const Tensor& lookup = context->input (1 );
542
+ OP_REQUIRES (context, lookup.dims () == 1 || lookup.dims () == 2 ,
543
+ errors::InvalidArgument (" `lookup` must be rank 1 or 2." ));
544
+
545
+ std::vector<absl::Span<const int32_t >> table;
546
+ if (lookup.dims () == 1 ) {
547
+ OP_REQUIRES_OK (context, IndexCDFVector (lookup.flat <int32_t >(), &table));
548
+ } else {
549
+ DCHECK_EQ (lookup.dims (), 2 );
550
+ OP_REQUIRES_OK (context, IndexCDFMatrix (lookup.matrix <int32_t >(), &table));
551
+ }
552
+
563
553
auto output = output_tensor->flat <Variant>();
564
554
for (int64_t i = 0 ; i < output.size (); ++i) {
565
555
EntropyDecoderVariant wrap;
566
- OP_REQUIRES_OK (context, RangeDecoderInterface::MakeShared (
567
- encoded (i), lookup, &wrap. decoder ) );
556
+ wrap. decoder =
557
+ std::make_shared<RangeDecoderInterface>( encoded (i), table, lookup );
568
558
wrap.holder = encoded_tensor;
569
559
output (i) = std::move (wrap);
570
560
}
@@ -636,10 +626,10 @@ class EntropyDecodeChannelOp : public tensorflow::OpKernel {
636
626
context->SetStatus (status); \
637
627
return ; \
638
628
}
639
- #define REQUIRES_OK (status ) \
640
- { \
641
- auto s = (status); \
642
- REQUIRES (s.ok (), s); \
629
+ #define REQUIRES_OK (status ) \
630
+ { \
631
+ auto s = (status); \
632
+ REQUIRES (s.ok (), s); \
643
633
}
644
634
645
635
const int64_t num_elements = output.dimension (1 );
@@ -736,10 +726,10 @@ class EntropyDecodeIndexOp : public tensorflow::OpKernel {
736
726
context->SetStatus (status); \
737
727
return ; \
738
728
}
739
- #define REQUIRES_OK (status ) \
740
- { \
741
- auto s = (status); \
742
- REQUIRES (s.ok (), s); \
729
+ #define REQUIRES_OK (status ) \
730
+ { \
731
+ auto s = (status); \
732
+ REQUIRES (s.ok (), s); \
743
733
}
744
734
745
735
const int64_t num_elements = output.dimension (1 );
0 commit comments