Skip to content

Commit 5887583

Browse files
PraChetittensorflower-gardener
authored andcommitted
Partially addresses memory overhead in bitpacking.
This change provides alternative implementation for bitpacking with specific bitranges. The current implementation can temporarily allocate unnecessary memory, which was confirmed to not be the case with the implementation provided here. A more general solution is desired, as this only applies to special cases. PiperOrigin-RevId: 338018083
1 parent f1beeb7 commit 5887583

File tree

2 files changed

+213
-3
lines changed

2 files changed

+213
-3
lines changed

tensorflow_model_optimization/python/core/internal/tensor_encoding/utils/tf_utils.py

Lines changed: 144 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20+
import logging
2021
import math
2122

2223
import numpy as np
@@ -295,9 +296,9 @@ def pack_into_int(value, input_bitrange, target_bitrange):
295296
296297
NOTE: This only uses basic math operations to implement the bit manipulation,
297298
not any bitwise operations, which is relevant in environments where only a
298-
subset of TensorFlow ops/kernels are available. If values outside of the
299-
expected range are provided at runtime, an error will *not* be raised,
300-
possibly returning an incorrect value.
299+
subset of TensorFlow ops/kernels are available. Moreover, if values outside of
300+
the expected range are provided at runtime, an error will *not* be raised.
301+
The behavior of this method in such case is undefined.
301302
302303
Args:
303304
value: An integer Tensor to be packed.
@@ -308,6 +309,17 @@ def pack_into_int(value, input_bitrange, target_bitrange):
308309
Returns:
309310
An integer Tensor representing `value` of the same dtype as `value`.
310311
"""
312+
# TODO(b/161433177): Provide a general solution without memory overhead.
313+
# Special cases implemented without extra memory overhead.
314+
if input_bitrange == 8 and target_bitrange == 28:
315+
return _pack_into_int_8_28(value)
316+
if input_bitrange == 12 and target_bitrange == 28:
317+
return _pack_into_int_12_28(value)
318+
319+
# General solution with possible extra memory overhead.
320+
logging.warning('This code path can temporarily allocate extra memory. If '
321+
'memory footprint is a problem, consider different bitpacking'
322+
' method or turning this functionality off. See b/161433177')
311323
if input_bitrange > 1:
312324
value = tf.reshape(value, [-1, 1])
313325
value = _expand_to_binary_form(value, input_bitrange)
@@ -335,6 +347,17 @@ def unpack_from_int(value, original_bitrange, target_bitrange, shape):
335347
An integer Tensor representing the unpacked `value` of the same dtype as
336348
`value`.
337349
"""
350+
# TODO(b/161433177): Provide a general solution without memory overhead.
351+
# Special cases implemented without extra memory overhead.
352+
if original_bitrange == 8 and target_bitrange == 28:
353+
return _unpack_from_int_8_28(value, shape)
354+
if original_bitrange == 12 and target_bitrange == 28:
355+
return _unpack_from_int_12_28(value, shape)
356+
357+
# General solution with extra memory overhead.
358+
logging.warning('This code path can temporarily allocate extra memory. If '
359+
'memory footprint is a problem, consider different bitpacking'
360+
' method or turning this functionality off. See b/161433177')
338361
value = _expand_to_binary_form(value, target_bitrange)
339362
value = tf.slice(value, [0], [tf.reduce_prod(shape) * original_bitrange])
340363
if original_bitrange > 1:
@@ -361,3 +384,121 @@ def _expand_to_binary_form(value, input_bits):
361384
expand_vector = tf.constant([2**i for i in range(input_bits)], value.dtype)
362385
bits = tf.math.mod(tf.math.floordiv(value, expand_vector), 2)
363386
return tf.reshape(bits, [-1])
387+
388+
389+
def _pack_into_int_8_28(value):
390+
"""Implementation of `pack_into_int` for specific bitranges.
391+
392+
This method corresponts to `(input_bitrange, target_bitrange)` form the
393+
`pack_into_int` method equal to `(8, 28)`. This method relies on the fact that
394+
7 values in 8-bit bitrange can be packed into 2 values in 28-bitrange
395+
(7 = least_common_multiple(8, 28) / 8).
396+
397+
It reshapes the input into matrix of 7 columns and performs operations on the
398+
columns of the matrix, thus vectorizing the operations and avoiding memory
399+
overhead of an earlier general implementation.
400+
401+
Args:
402+
value: An integer Tensor to be packed with values in [0, 2**8 - 1].
403+
404+
Returns:
405+
An integer Tensor representing `value` of the same dtype as `value`.
406+
"""
407+
value = tf.reshape(value, [-1])
408+
extra_zeros = tf.zeros(tf.math.mod(-tf.shape(value), 7), value.dtype)
409+
val = tf.reshape(tf.concat([value, extra_zeros], 0), [-1, 7])
410+
411+
a = (val[:, 0] +
412+
val[:, 1] * 2**8 +
413+
val[:, 2] * 2**16 +
414+
tf.math.mod(val[:, 3], 2**4) * 2**24)
415+
b = (tf.math.floordiv(val[:, 3], 2**4) +
416+
val[:, 4] * 2**4 +
417+
val[:, 5] * 2**12 +
418+
val[:, 6] * 2**20)
419+
420+
packed_val = tf.reshape(tf.stack([a, b], 1), [-1, 1])
421+
if extra_zeros.shape[0] in [4, 5, 6]:
422+
# We added unnecessary sum of zeros to the representation.
423+
packed_val = tf.slice(packed_val, [0, 0], [packed_val.shape[0] - 1, 1])
424+
return packed_val
425+
426+
427+
def _pack_into_int_12_28(value):
428+
"""Implementation of `pack_into_int` for specific bitranges.
429+
430+
This method corresponts to `(input_bitrange, target_bitrange)` form the
431+
`pack_into_int` method equal to `(12, 28)`. This method relies on the fact
432+
that 7 values in 12-bit bitrange can be packed into 3 values in 28-bitrange
433+
(7 = least_common_multiple(12, 28) / 12).
434+
435+
It reshapes the input into matrix of 7 columns and performs operations on the
436+
columns of the matrix, thus vectorizing the operations and avoiding memory
437+
overhead of an earlier general implementation.
438+
439+
Args:
440+
value: An integer Tensor to be packed with values in [0, 2**8 - 1].
441+
442+
Returns:
443+
An integer Tensor representing `value` of the same dtype as `value`.
444+
"""
445+
value = tf.reshape(value, [-1])
446+
extra_zeros = tf.zeros(tf.math.mod(-tf.shape(value), 7), value.dtype)
447+
val = tf.reshape(tf.concat([value, extra_zeros], 0), [-1, 7])
448+
449+
a = (val[:, 0] +
450+
val[:, 1] * 2**12 +
451+
tf.math.mod(val[:, 2], 2**4) * 2**24)
452+
b = (tf.math.floordiv(val[:, 2], 2**4) +
453+
val[:, 3] * 2**8 +
454+
tf.math.mod(val[:, 4], 2**8) * 2**20)
455+
c = (tf.math.floordiv(val[:, 4], 2**8) +
456+
val[:, 5] * 2**4 +
457+
val[:, 6] * 2**16)
458+
459+
packed_val = tf.reshape(tf.stack([a, b, c], 1), [-1, 1])
460+
if extra_zeros.shape[0] in [3, 4]:
461+
# We added unnecessary sum of zeros to the representation.
462+
packed_val = tf.slice(packed_val, [0, 0], [packed_val.shape[0] - 1, 1])
463+
if extra_zeros.shape[0] in [5, 6]:
464+
# We added unnecessary two sums of zeros to the representation.
465+
packed_val = tf.slice(packed_val, [0, 0], [packed_val.shape[0] - 2, 1])
466+
return packed_val
467+
468+
469+
def _unpack_from_int_8_28(value, shape):
470+
"""Inverse operation of `_pack_into_int_8_28`."""
471+
value = tf.reshape(value, [-1])
472+
extra_zeros = tf.zeros(tf.math.mod(-tf.shape(value), 2), value.dtype)
473+
val = tf.reshape(tf.concat([value, extra_zeros], 0), [-1, 2])
474+
475+
a = tf.math.mod(val[:, 0], 2**8)
476+
b = tf.math.mod(tf.math.floordiv(val[:, 0], 2**8), 2**8)
477+
c = tf.math.mod(tf.math.floordiv(val[:, 0], 2**16), 2**8)
478+
d = tf.math.floordiv(val[:, 0], 2**24) + tf.math.mod(val[:, 1], 2**4) * 2**4
479+
e = tf.math.mod(tf.math.floordiv(val[:, 1], 2**4), 2**8)
480+
f = tf.math.mod(tf.math.floordiv(val[:, 1], 2**12), 2**8)
481+
g = tf.math.floordiv(val[:, 1], 2**20)
482+
483+
unpacked_val = tf.reshape(tf.stack([a, b, c, d, e, f, g], 1), [-1,])
484+
unpacked_val = tf.slice(unpacked_val, [0], [tf.reduce_prod(shape)])
485+
return tf.reshape(unpacked_val, shape)
486+
487+
488+
def _unpack_from_int_12_28(value, shape):
489+
"""Inverse operation of `_pack_into_int_12_28`."""
490+
value = tf.reshape(value, [-1])
491+
extra_zeros = tf.zeros(tf.math.mod(-tf.shape(value), 3), value.dtype)
492+
val = tf.reshape(tf.concat([value, extra_zeros], 0), [-1, 3])
493+
494+
a = tf.math.mod(val[:, 0], 2**12)
495+
b = tf.math.mod(tf.math.floordiv(val[:, 0], 2**12), 2**12)
496+
c = tf.math.floordiv(val[:, 0], 2**24) + tf.math.mod(val[:, 1], 2**8) * 2**4
497+
d = tf.math.mod(tf.math.floordiv(val[:, 1], 2**8), 2**12)
498+
e = tf.math.floordiv(val[:, 1], 2**20) + tf.math.mod(val[:, 2], 2**4) * 2**8
499+
f = tf.math.mod(tf.math.floordiv(val[:, 2], 2**4), 2**12)
500+
g = tf.math.mod(tf.math.floordiv(val[:, 2], 2**16), 2**12)
501+
502+
unpacked_val = tf.reshape(tf.stack([a, b, c, d, e, f, g], 1), [-1,])
503+
unpacked_val = tf.slice(unpacked_val, [0], [tf.reduce_prod(shape)])
504+
return tf.reshape(unpacked_val, shape)

tensorflow_model_optimization/python/core/internal/tensor_encoding/utils/tf_utils_test.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,75 @@ def test_unpack_from_int_different_outputs(self):
375375
packed_value, original_bitrange=4, target_bitrange=28, shape=(2,))
376376
self.assertAllEqual([9, 0], self.evaluate(unpacked_value))
377377

378+
@parameterized.parameters([(1, 28), (2, 28), (8, 28), (12, 28)])
379+
def test_boundary_conditions(self, input_bitrange, target_bitrange):
380+
max_v = 2**input_bitrange - 1
381+
input_value = tf.constant(
382+
[0, 0, 0, max_v, max_v, max_v, 0, 0, max_v, max_v, 0, max_v] * 20)
383+
384+
for length in [1, 6, 7, 8, 20*6, 20*7, 20*8, 20*12]:
385+
value = input_value[:length]
386+
packed_value = tf_utils.pack_into_int(value, input_bitrange,
387+
target_bitrange)
388+
unpacked_value = tf_utils.unpack_from_int(packed_value, input_bitrange,
389+
target_bitrange, value.shape)
390+
391+
self.assertAllEqual(self.evaluate(value), self.evaluate(unpacked_value))
392+
393+
@parameterized.parameters([(1, 28), (2, 28), (8, 28), (12, 28)])
394+
def test_random_input(self, input_bitrange, target_bitrange):
395+
# Tests that packing/unpacking amounts to identity, regardless of the input.
396+
num_elements = np.random.randint(low=1, high=50)
397+
value = tf.constant(
398+
np.random.randint(low=0, high=2**input_bitrange, size=num_elements))
399+
packed_value = tf_utils.pack_into_int(value, input_bitrange,
400+
target_bitrange)
401+
unpacked_value = tf_utils.unpack_from_int(packed_value, input_bitrange,
402+
target_bitrange, value.shape)
403+
404+
value, unpacked_value = self.evaluate((value, unpacked_value))
405+
try:
406+
self.assertAllEqual(value, unpacked_value)
407+
except: # pylint: disable=bare-except
408+
self.fail(f'Random input test failed with input value: {value}')
409+
410+
def test_pack_into_int_special_case_8_28(self):
411+
value = tf.constant([38, 147, 1, 201, 205, 36, 155, 78, 163, 98])
412+
packed_value = tf_utils.pack_into_int(
413+
value, input_bitrange=8, target_bitrange=28)
414+
expected_packed_value = tf.constant([[151098150], [162680028], [6464334]])
415+
self.assertAllEqual(self.evaluate(expected_packed_value),
416+
self.evaluate(packed_value))
417+
418+
def test_unpack_from_int_special_case_8_28(self):
419+
packed_value = tf.constant([[151098150], [162680028], [6464334]])
420+
unpacked_value = tf_utils.unpack_from_int(
421+
packed_value, original_bitrange=8, target_bitrange=28, shape=(10,))
422+
expected_unpacked_value = tf.constant(
423+
[38, 147, 1, 201, 205, 36, 155, 78, 163, 98])
424+
self.assertAllEqual(self.evaluate(expected_unpacked_value),
425+
self.evaluate(unpacked_value))
426+
427+
def test_pack_into_int_special_case_12_28(self):
428+
value = tf.constant(
429+
[2805, 3264, 2344, 3472, 2962, 768, 2867, 3703, 2883, 2406])
430+
packed_value = tf_utils.pack_into_int(
431+
value, input_bitrange=12, target_bitrange=28)
432+
expected_packed_value = tf.constant(
433+
[[147589877], [153981074], [187904011], [112475767], [150]])
434+
self.assertAllEqual(self.evaluate(expected_packed_value),
435+
self.evaluate(packed_value))
436+
437+
def test_unpack_from_int_special_case_12_28(self):
438+
packed_value = tf.constant(
439+
[[147589877], [153981074], [187904011], [112475767], [150]])
440+
unpacked_value = tf_utils.unpack_from_int(
441+
packed_value, original_bitrange=12, target_bitrange=28, shape=(10,))
442+
expected_unpacked_value = tf.constant(
443+
[2805, 3264, 2344, 3472, 2962, 768, 2867, 3703, 2883, 2406])
444+
self.assertAllEqual(self.evaluate(expected_unpacked_value),
445+
self.evaluate(unpacked_value))
446+
378447

379448
if __name__ == '__main__':
380449
tf.test.main()

0 commit comments

Comments
 (0)