@@ -107,22 +107,28 @@ class RunLengthEncodeOp : public OpKernel {
107
107
const int32_t * const end = data.data () + data.size ();
108
108
const int32_t * p = data.data ();
109
109
110
+ // If we encode both zeros and non-zeros with run-length encoding
111
+ // (use_run_length_for_non_zeros == true), only the first zero run length
112
+ // can possibly be zero. We can subtract 1 from all subsequent run lengths.
113
+ int32_t run_length_offset = 0 ;
114
+
110
115
while (p < end) {
111
116
// Find next non-zero.
112
117
const int32_t * q = std::find_if_not (p, end,
113
118
[](int32_t x) { return x == 0 ; });
114
- WriteRunLength (enc, q - p);
119
+ WriteRunLength (enc, q - p - run_length_offset );
115
120
p = q;
116
121
117
122
if (!(p < end)) break ;
118
123
119
124
if (use_run_length_for_non_zeros_) {
120
125
// Find next zero.
121
126
q = std::find_if (p, end, [](int32_t x) { return x == 0 ; });
122
- WriteRunLength (enc, q - p);
127
+ WriteRunLength (enc, q - p - 1 );
123
128
while (p < q) {
124
129
WriteNonZero (enc, *p++);
125
130
}
131
+ run_length_offset = 1 ;
126
132
} else {
127
133
WriteNonZero (enc, *p++);
128
134
}
@@ -211,12 +217,17 @@ class RunLengthDecodeOp : public OpKernel {
211
217
int32_t * const end = data.data () + data.size ();
212
218
int32_t * p = data.data ();
213
219
220
+ // If we encode both zeros and non-zeros with run-length encoding
221
+ // (use_run_length_for_non_zeros == true), only the first zero run length
222
+ // can possibly be zero. We can subtract 1 from all subsequent run lengths.
223
+ int32_t run_length_offset = 0 ;
224
+
214
225
while (p < end) {
215
226
// Skip to the next non-zero element.
216
227
auto run_length = ReadRunLength (context, dec);
217
228
OP_REQUIRES_OK_ABSL (context, run_length.status ());
218
229
219
- p += *run_length;
230
+ p += *run_length + run_length_offset ;
220
231
221
232
if (!(p < end)) {
222
233
// Should not be past the last element.
@@ -228,14 +239,15 @@ class RunLengthDecodeOp : public OpKernel {
228
239
if (use_run_length_for_non_zeros_) {
229
240
run_length = ReadRunLength (context, dec);
230
241
OP_REQUIRES_OK_ABSL (context, run_length.status ());
231
- const int32_t * const next_zero = p + *run_length;
242
+ const int32_t * const next_zero = p + *run_length + 1 ;
232
243
OP_REQUIRES (context, next_zero <= end,
233
244
errors::DataLoss (" Decoded past end of tensor." ));
234
245
while (p < next_zero) {
235
246
auto nonzero = ReadNonZero (context, dec);
236
247
OP_REQUIRES_OK_ABSL (context, nonzero.status ());
237
248
*p++ = *nonzero;
238
249
}
250
+ run_length_offset = 1 ;
239
251
} else {
240
252
auto nonzero = ReadNonZero (context, dec);
241
253
OP_REQUIRES_OK_ABSL (context, nonzero.status ());
0 commit comments