1717from __future__ import division
1818from __future__ import print_function
1919
20+ import logging
2021import math
2122
2223import numpy as np
@@ -295,9 +296,9 @@ def pack_into_int(value, input_bitrange, target_bitrange):
295296
296297 NOTE: This only uses basic math operations to implement the bit manipulation,
297298 not any bitwise operations, which is relevant in environments where only a
298- subset of TensorFlow ops/kernels are available. If values outside of the
299- expected range are provided at runtime, an error will *not* be raised,
300- possibly returning an incorrect value .
299+ subset of TensorFlow ops/kernels are available. Moreover, if values outside of
300+ the expected range are provided at runtime, an error will *not* be raised.
301+ The behavior of this method in such case is undefined .
301302
302303 Args:
303304 value: An integer Tensor to be packed.
@@ -308,6 +309,17 @@ def pack_into_int(value, input_bitrange, target_bitrange):
308309 Returns:
309310 An integer Tensor representing `value` of the same dtype as `value`.
310311 """
312+ # TODO(b/161433177): Provide a general solution without memory overhead.
313+ # Special cases implemented without extra memory overhead.
314+ if input_bitrange == 8 and target_bitrange == 28 :
315+ return _pack_into_int_8_28 (value )
316+ if input_bitrange == 12 and target_bitrange == 28 :
317+ return _pack_into_int_12_28 (value )
318+
319+ # General solution with possible extra memory overhead.
320+ logging .warning ('This code path can temporarily allocate extra memory. If '
321+ 'memory footprint is a problem, consider different bitpacking'
322+ ' method or turning this functionality off. See b/161433177' )
311323 if input_bitrange > 1 :
312324 value = tf .reshape (value , [- 1 , 1 ])
313325 value = _expand_to_binary_form (value , input_bitrange )
@@ -335,6 +347,17 @@ def unpack_from_int(value, original_bitrange, target_bitrange, shape):
335347 An integer Tensor representing the unpacked `value` of the same dtype as
336348 `value`.
337349 """
350+ # TODO(b/161433177): Provide a general solution without memory overhead.
351+ # Special cases implemented without extra memory overhead.
352+ if original_bitrange == 8 and target_bitrange == 28 :
353+ return _unpack_from_int_8_28 (value , shape )
354+ if original_bitrange == 12 and target_bitrange == 28 :
355+ return _unpack_from_int_12_28 (value , shape )
356+
357+ # General solution with extra memory overhead.
358+ logging .warning ('This code path can temporarily allocate extra memory. If '
359+ 'memory footprint is a problem, consider different bitpacking'
360+ ' method or turning this functionality off. See b/161433177' )
338361 value = _expand_to_binary_form (value , target_bitrange )
339362 value = tf .slice (value , [0 ], [tf .reduce_prod (shape ) * original_bitrange ])
340363 if original_bitrange > 1 :
@@ -361,3 +384,121 @@ def _expand_to_binary_form(value, input_bits):
361384 expand_vector = tf .constant ([2 ** i for i in range (input_bits )], value .dtype )
362385 bits = tf .math .mod (tf .math .floordiv (value , expand_vector ), 2 )
363386 return tf .reshape (bits , [- 1 ])
387+
388+
389+ def _pack_into_int_8_28 (value ):
390+ """Implementation of `pack_into_int` for specific bitranges.
391+
392+ This method corresponts to `(input_bitrange, target_bitrange)` form the
393+ `pack_into_int` method equal to `(8, 28)`. This method relies on the fact that
394+ 7 values in 8-bit bitrange can be packed into 2 values in 28-bitrange
395+ (7 = least_common_multiple(8, 28) / 8).
396+
397+ It reshapes the input into matrix of 7 columns and performs operations on the
398+ columns of the matrix, thus vectorizing the operations and avoiding memory
399+ overhead of an earlier general implementation.
400+
401+ Args:
402+ value: An integer Tensor to be packed with values in [0, 2**8 - 1].
403+
404+ Returns:
405+ An integer Tensor representing `value` of the same dtype as `value`.
406+ """
407+ value = tf .reshape (value , [- 1 ])
408+ extra_zeros = tf .zeros (tf .math .mod (- tf .shape (value ), 7 ), value .dtype )
409+ val = tf .reshape (tf .concat ([value , extra_zeros ], 0 ), [- 1 , 7 ])
410+
411+ a = (val [:, 0 ] +
412+ val [:, 1 ] * 2 ** 8 +
413+ val [:, 2 ] * 2 ** 16 +
414+ tf .math .mod (val [:, 3 ], 2 ** 4 ) * 2 ** 24 )
415+ b = (tf .math .floordiv (val [:, 3 ], 2 ** 4 ) +
416+ val [:, 4 ] * 2 ** 4 +
417+ val [:, 5 ] * 2 ** 12 +
418+ val [:, 6 ] * 2 ** 20 )
419+
420+ packed_val = tf .reshape (tf .stack ([a , b ], 1 ), [- 1 , 1 ])
421+ if extra_zeros .shape [0 ] in [4 , 5 , 6 ]:
422+ # We added unnecessary sum of zeros to the representation.
423+ packed_val = tf .slice (packed_val , [0 , 0 ], [packed_val .shape [0 ] - 1 , 1 ])
424+ return packed_val
425+
426+
427+ def _pack_into_int_12_28 (value ):
428+ """Implementation of `pack_into_int` for specific bitranges.
429+
430+ This method corresponts to `(input_bitrange, target_bitrange)` form the
431+ `pack_into_int` method equal to `(12, 28)`. This method relies on the fact
432+ that 7 values in 12-bit bitrange can be packed into 3 values in 28-bitrange
433+ (7 = least_common_multiple(12, 28) / 12).
434+
435+ It reshapes the input into matrix of 7 columns and performs operations on the
436+ columns of the matrix, thus vectorizing the operations and avoiding memory
437+ overhead of an earlier general implementation.
438+
439+ Args:
440+ value: An integer Tensor to be packed with values in [0, 2**8 - 1].
441+
442+ Returns:
443+ An integer Tensor representing `value` of the same dtype as `value`.
444+ """
445+ value = tf .reshape (value , [- 1 ])
446+ extra_zeros = tf .zeros (tf .math .mod (- tf .shape (value ), 7 ), value .dtype )
447+ val = tf .reshape (tf .concat ([value , extra_zeros ], 0 ), [- 1 , 7 ])
448+
449+ a = (val [:, 0 ] +
450+ val [:, 1 ] * 2 ** 12 +
451+ tf .math .mod (val [:, 2 ], 2 ** 4 ) * 2 ** 24 )
452+ b = (tf .math .floordiv (val [:, 2 ], 2 ** 4 ) +
453+ val [:, 3 ] * 2 ** 8 +
454+ tf .math .mod (val [:, 4 ], 2 ** 8 ) * 2 ** 20 )
455+ c = (tf .math .floordiv (val [:, 4 ], 2 ** 8 ) +
456+ val [:, 5 ] * 2 ** 4 +
457+ val [:, 6 ] * 2 ** 16 )
458+
459+ packed_val = tf .reshape (tf .stack ([a , b , c ], 1 ), [- 1 , 1 ])
460+ if extra_zeros .shape [0 ] in [3 , 4 ]:
461+ # We added unnecessary sum of zeros to the representation.
462+ packed_val = tf .slice (packed_val , [0 , 0 ], [packed_val .shape [0 ] - 1 , 1 ])
463+ if extra_zeros .shape [0 ] in [5 , 6 ]:
464+ # We added unnecessary two sums of zeros to the representation.
465+ packed_val = tf .slice (packed_val , [0 , 0 ], [packed_val .shape [0 ] - 2 , 1 ])
466+ return packed_val
467+
468+
469+ def _unpack_from_int_8_28 (value , shape ):
470+ """Inverse operation of `_pack_into_int_8_28`."""
471+ value = tf .reshape (value , [- 1 ])
472+ extra_zeros = tf .zeros (tf .math .mod (- tf .shape (value ), 2 ), value .dtype )
473+ val = tf .reshape (tf .concat ([value , extra_zeros ], 0 ), [- 1 , 2 ])
474+
475+ a = tf .math .mod (val [:, 0 ], 2 ** 8 )
476+ b = tf .math .mod (tf .math .floordiv (val [:, 0 ], 2 ** 8 ), 2 ** 8 )
477+ c = tf .math .mod (tf .math .floordiv (val [:, 0 ], 2 ** 16 ), 2 ** 8 )
478+ d = tf .math .floordiv (val [:, 0 ], 2 ** 24 ) + tf .math .mod (val [:, 1 ], 2 ** 4 ) * 2 ** 4
479+ e = tf .math .mod (tf .math .floordiv (val [:, 1 ], 2 ** 4 ), 2 ** 8 )
480+ f = tf .math .mod (tf .math .floordiv (val [:, 1 ], 2 ** 12 ), 2 ** 8 )
481+ g = tf .math .floordiv (val [:, 1 ], 2 ** 20 )
482+
483+ unpacked_val = tf .reshape (tf .stack ([a , b , c , d , e , f , g ], 1 ), [- 1 ,])
484+ unpacked_val = tf .slice (unpacked_val , [0 ], [tf .reduce_prod (shape )])
485+ return tf .reshape (unpacked_val , shape )
486+
487+
488+ def _unpack_from_int_12_28 (value , shape ):
489+ """Inverse operation of `_pack_into_int_12_28`."""
490+ value = tf .reshape (value , [- 1 ])
491+ extra_zeros = tf .zeros (tf .math .mod (- tf .shape (value ), 3 ), value .dtype )
492+ val = tf .reshape (tf .concat ([value , extra_zeros ], 0 ), [- 1 , 3 ])
493+
494+ a = tf .math .mod (val [:, 0 ], 2 ** 12 )
495+ b = tf .math .mod (tf .math .floordiv (val [:, 0 ], 2 ** 12 ), 2 ** 12 )
496+ c = tf .math .floordiv (val [:, 0 ], 2 ** 24 ) + tf .math .mod (val [:, 1 ], 2 ** 8 ) * 2 ** 4
497+ d = tf .math .mod (tf .math .floordiv (val [:, 1 ], 2 ** 8 ), 2 ** 12 )
498+ e = tf .math .floordiv (val [:, 1 ], 2 ** 20 ) + tf .math .mod (val [:, 2 ], 2 ** 4 ) * 2 ** 8
499+ f = tf .math .mod (tf .math .floordiv (val [:, 2 ], 2 ** 4 ), 2 ** 12 )
500+ g = tf .math .mod (tf .math .floordiv (val [:, 2 ], 2 ** 16 ), 2 ** 12 )
501+
502+ unpacked_val = tf .reshape (tf .stack ([a , b , c , d , e , f , g ], 1 ), [- 1 ,])
503+ unpacked_val = tf .slice (unpacked_val , [0 ], [tf .reduce_prod (shape )])
504+ return tf .reshape (unpacked_val , shape )
0 commit comments