@@ -19,8 +19,6 @@ limitations under the License.
19
19
#include < string>
20
20
#include < utility>
21
21
22
- #include " absl/base/integral_types.h"
23
- #include " tensorflow_compression/cc/lib/range_coder.h"
24
22
#include " absl/strings/string_view.h"
25
23
#include " absl/strings/substitute.h"
26
24
#include " absl/types/span.h"
@@ -36,6 +34,7 @@ limitations under the License.
36
34
#include " tensorflow/core/framework/variant_tensor_data.h"
37
35
#include " tensorflow/core/lib/core/errors.h"
38
36
#include " tensorflow/core/lib/core/threadpool.h"
37
+ #include " tensorflow_compression/cc/lib/range_coder.h"
39
38
40
39
namespace tensorflow_compression {
41
40
namespace {
@@ -367,11 +366,12 @@ class EntropyEncodeChannelOp : public tensorflow::OpKernel {
367
366
tensorflow::thread::ThreadPool* workers =
368
367
context->device ()->tensorflow_cpu_worker_threads ()->workers ;
369
368
tensorflow::mutex mu;
370
- workers->ParallelFor (
371
- handle.size (), cost_per_unit,
372
- [&handle, &mu, context, value, index_stride](int64 start, int64 limit) {
373
- PerShard (handle, value, index_stride, context, &mu, start, limit);
374
- });
369
+ workers->ParallelFor (handle.size (), cost_per_unit,
370
+ [&handle, &mu, context, value, index_stride](
371
+ int64_t start, int64_t limit) {
372
+ PerShard (handle, value, index_stride, context, &mu,
373
+ start, limit);
374
+ });
375
375
376
376
context->set_output (0 , handle_tensor);
377
377
}
@@ -382,19 +382,17 @@ class EntropyEncodeChannelOp : public tensorflow::OpKernel {
382
382
const int64_t index_stride,
383
383
tensorflow::OpKernelContext* context,
384
384
tensorflow::mutex* mu, int64_t start, int64_t limit) {
385
- #define REQUIRES_OK (status ) \
386
- if (auto s = (status); ABSL_PREDICT_FALSE (!s.ok ())) { \
387
- tensorflow::mutex_lock lock (*mu); \
388
- context->SetStatus (s); \
389
- return ; \
390
- }
391
-
392
385
#define REQUIRES (cond, status ) \
393
386
if (!ABSL_PREDICT_TRUE (cond)) { \
394
387
tensorflow::mutex_lock lock (*mu); \
395
388
context->SetStatus (status); \
396
389
return ; \
397
390
}
391
+ #define REQUIRES_OK (status ) \
392
+ { \
393
+ auto s = (status); \
394
+ REQUIRES (s.ok (), s); \
395
+ }
398
396
399
397
const int64_t num_elements = value.dimension (1 );
400
398
auto * p_value = &value (start, 0 );
@@ -467,7 +465,7 @@ class EntropyEncodeIndexOp : public tensorflow::OpKernel {
467
465
tensorflow::mutex mu;
468
466
workers->ParallelFor (
469
467
handle.size (), cost_per_unit,
470
- [&handle, &mu, context, value, index](int64 start, int64 limit) {
468
+ [&handle, &mu, context, value, index](int64_t start, int64_t limit) {
471
469
PerShard (handle, index, value, context, &mu, start, limit);
472
470
});
473
471
@@ -480,19 +478,17 @@ class EntropyEncodeIndexOp : public tensorflow::OpKernel {
480
478
TTypes<int32_t >::ConstMatrix value,
481
479
tensorflow::OpKernelContext* context,
482
480
tensorflow::mutex* mu, int64_t start, int64_t limit) {
483
- #define REQUIRES_OK (status ) \
484
- if (auto s = (status); ABSL_PREDICT_FALSE (!s.ok ())) { \
485
- tensorflow::mutex_lock lock (*mu); \
486
- context->SetStatus (s); \
487
- return ; \
488
- }
489
-
490
481
#define REQUIRES (cond, status ) \
491
482
if (!ABSL_PREDICT_TRUE (cond)) { \
492
483
tensorflow::mutex_lock lock (*mu); \
493
484
context->SetStatus (status); \
494
485
return ; \
495
486
}
487
+ #define REQUIRES_OK (status ) \
488
+ { \
489
+ auto s = (status); \
490
+ REQUIRES (s.ok (), s); \
491
+ }
496
492
497
493
const int64_t num_elements = value.dimension (1 );
498
494
const int32_t * p_value = &value (start, 0 );
@@ -621,7 +617,7 @@ class EntropyDecodeChannelOp : public tensorflow::OpKernel {
621
617
tensorflow::mutex mu;
622
618
workers->ParallelFor (handle.size (), cost_per_unit,
623
619
[&handle, &mu, context, index_stride, &output](
624
- int64 start, int64 limit) {
620
+ int64_t start, int64_t limit) {
625
621
PerShard (handle, index_stride, output, context, &mu,
626
622
start, limit);
627
623
});
@@ -634,19 +630,17 @@ class EntropyDecodeChannelOp : public tensorflow::OpKernel {
634
630
TTypes<int32_t >::Matrix output,
635
631
tensorflow::OpKernelContext* context,
636
632
tensorflow::mutex* mu, int64_t start, int64_t limit) {
637
- #define REQUIRES_OK (status ) \
638
- if (auto s = (status); ABSL_PREDICT_FALSE (!s.ok ())) { \
639
- tensorflow::mutex_lock lock (*mu); \
640
- context->SetStatus (s); \
641
- return ; \
642
- }
643
-
644
633
#define REQUIRES (cond, status ) \
645
634
if (!ABSL_PREDICT_TRUE (cond)) { \
646
635
tensorflow::mutex_lock lock (*mu); \
647
636
context->SetStatus (status); \
648
637
return ; \
649
638
}
639
+ #define REQUIRES_OK (status ) \
640
+ { \
641
+ auto s = (status); \
642
+ REQUIRES (s.ok (), s); \
643
+ }
650
644
651
645
const int64_t num_elements = output.dimension (1 );
652
646
auto * p_output = &output (start, 0 );
@@ -723,7 +717,7 @@ class EntropyDecodeIndexOp : public tensorflow::OpKernel {
723
717
tensorflow::mutex mu;
724
718
workers->ParallelFor (
725
719
handle.size (), cost_per_unit,
726
- [&handle, &mu, context, index, &output](int64 start, int64 limit) {
720
+ [&handle, &mu, context, index, &output](int64_t start, int64_t limit) {
727
721
PerShard (handle, index, output, context, &mu, start, limit);
728
722
});
729
723
@@ -736,19 +730,17 @@ class EntropyDecodeIndexOp : public tensorflow::OpKernel {
736
730
TTypes<int32_t >::Matrix output,
737
731
tensorflow::OpKernelContext* context,
738
732
tensorflow::mutex* mu, int64_t start, int64_t limit) {
739
- #define REQUIRES_OK (status ) \
740
- if (auto s = (status); ABSL_PREDICT_FALSE (!s.ok ())) { \
741
- tensorflow::mutex_lock lock (*mu); \
742
- context->SetStatus (s); \
743
- return ; \
744
- }
745
-
746
733
#define REQUIRES (cond, status ) \
747
734
if (!ABSL_PREDICT_TRUE (cond)) { \
748
735
tensorflow::mutex_lock lock (*mu); \
749
736
context->SetStatus (status); \
750
737
return ; \
751
738
}
739
+ #define REQUIRES_OK (status ) \
740
+ { \
741
+ auto s = (status); \
742
+ REQUIRES (s.ok (), s); \
743
+ }
752
744
753
745
const int64_t num_elements = output.dimension (1 );
754
746
const int32_t * p_index = &index (start, 0 );
0 commit comments