Skip to content

Commit f1afde6

Browse files
Johannes Ballé and copybara-github
authored and committed
Optimizes code for case use_run_length_for_non_zeros == true.
PiperOrigin-RevId: 492286191 Change-Id: I0bbfc4549eb90ae59d627e9e8490284483c966d9
1 parent 09c6ac4 commit f1afde6

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

tensorflow_compression/cc/kernels/run_length_kernels.cc

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,22 +107,28 @@ class RunLengthEncodeOp : public OpKernel {
107107
const int32_t* const end = data.data() + data.size();
108108
const int32_t* p = data.data();
109109

110+
// If we encode both zeros and non-zeros with run-length encoding
111+
// (use_run_length_for_non_zeros == true), only the first zero run length
112+
// can possibly be zero. We can subtract 1 from all subsequent run lengths.
113+
int32_t run_length_offset = 0;
114+
110115
while (p < end) {
111116
// Find next non-zero.
112117
const int32_t* q = std::find_if_not(p, end,
113118
[](int32_t x) { return x == 0; });
114-
WriteRunLength(enc, q - p);
119+
WriteRunLength(enc, q - p - run_length_offset);
115120
p = q;
116121

117122
if (!(p < end)) break;
118123

119124
if (use_run_length_for_non_zeros_) {
120125
// Find next zero.
121126
q = std::find_if(p, end, [](int32_t x) { return x == 0; });
122-
WriteRunLength(enc, q - p);
127+
WriteRunLength(enc, q - p - 1);
123128
while (p < q) {
124129
WriteNonZero(enc, *p++);
125130
}
131+
run_length_offset = 1;
126132
} else {
127133
WriteNonZero(enc, *p++);
128134
}
@@ -211,12 +217,17 @@ class RunLengthDecodeOp : public OpKernel {
211217
int32_t* const end = data.data() + data.size();
212218
int32_t* p = data.data();
213219

220+
// If we encode both zeros and non-zeros with run-length encoding
221+
// (use_run_length_for_non_zeros == true), only the first zero run length
222+
// can possibly be zero. We can subtract 1 from all subsequent run lengths.
223+
int32_t run_length_offset = 0;
224+
214225
while (p < end) {
215226
// Skip to the next non-zero element.
216227
auto run_length = ReadRunLength(context, dec);
217228
OP_REQUIRES_OK_ABSL(context, run_length.status());
218229

219-
p += *run_length;
230+
p += *run_length + run_length_offset;
220231

221232
if (!(p < end)) {
222233
// Should not be past the last element.
@@ -228,14 +239,15 @@ class RunLengthDecodeOp : public OpKernel {
228239
if (use_run_length_for_non_zeros_) {
229240
run_length = ReadRunLength(context, dec);
230241
OP_REQUIRES_OK_ABSL(context, run_length.status());
231-
const int32_t* const next_zero = p + *run_length;
242+
const int32_t* const next_zero = p + *run_length + 1;
232243
OP_REQUIRES(context, next_zero <= end,
233244
errors::DataLoss("Decoded past end of tensor."));
234245
while (p < next_zero) {
235246
auto nonzero = ReadNonZero(context, dec);
236247
OP_REQUIRES_OK_ABSL(context, nonzero.status());
237248
*p++ = *nonzero;
238249
}
250+
run_length_offset = 1;
239251
} else {
240252
auto nonzero = ReadNonZero(context, dec);
241253
OP_REQUIRES_OK_ABSL(context, nonzero.status());

0 commit comments

Comments
 (0)