17
17
from __future__ import division
18
18
from __future__ import print_function
19
19
20
+ import logging
20
21
import math
21
22
22
23
import numpy as np
@@ -295,9 +296,9 @@ def pack_into_int(value, input_bitrange, target_bitrange):
295
296
296
297
NOTE: This only uses basic math operations to implement the bit manipulation,
297
298
not any bitwise operations, which is relevant in environments where only a
298
- subset of TensorFlow ops/kernels are available. If values outside of the
299
- expected range are provided at runtime, an error will *not* be raised,
300
- possibly returning an incorrect value .
299
+ subset of TensorFlow ops/kernels are available. Moreover, if values outside of
300
+ the expected range are provided at runtime, an error will *not* be raised.
301
+ The behavior of this method in such case is undefined .
301
302
302
303
Args:
303
304
value: An integer Tensor to be packed.
@@ -308,6 +309,17 @@ def pack_into_int(value, input_bitrange, target_bitrange):
308
309
Returns:
309
310
An integer Tensor representing `value` of the same dtype as `value`.
310
311
"""
312
+ # TODO(b/161433177): Provide a general solution without memory overhead.
313
+ # Special cases implemented without extra memory overhead.
314
+ if input_bitrange == 8 and target_bitrange == 28 :
315
+ return _pack_into_int_8_28 (value )
316
+ if input_bitrange == 12 and target_bitrange == 28 :
317
+ return _pack_into_int_12_28 (value )
318
+
319
+ # General solution with possible extra memory overhead.
320
+ logging .warning ('This code path can temporarily allocate extra memory. If '
321
+ 'memory footprint is a problem, consider different bitpacking'
322
+ ' method or turning this functionality off. See b/161433177' )
311
323
if input_bitrange > 1 :
312
324
value = tf .reshape (value , [- 1 , 1 ])
313
325
value = _expand_to_binary_form (value , input_bitrange )
@@ -335,6 +347,17 @@ def unpack_from_int(value, original_bitrange, target_bitrange, shape):
335
347
An integer Tensor representing the unpacked `value` of the same dtype as
336
348
`value`.
337
349
"""
350
+ # TODO(b/161433177): Provide a general solution without memory overhead.
351
+ # Special cases implemented without extra memory overhead.
352
+ if original_bitrange == 8 and target_bitrange == 28 :
353
+ return _unpack_from_int_8_28 (value , shape )
354
+ if original_bitrange == 12 and target_bitrange == 28 :
355
+ return _unpack_from_int_12_28 (value , shape )
356
+
357
+ # General solution with extra memory overhead.
358
+ logging .warning ('This code path can temporarily allocate extra memory. If '
359
+ 'memory footprint is a problem, consider different bitpacking'
360
+ ' method or turning this functionality off. See b/161433177' )
338
361
value = _expand_to_binary_form (value , target_bitrange )
339
362
value = tf .slice (value , [0 ], [tf .reduce_prod (shape ) * original_bitrange ])
340
363
if original_bitrange > 1 :
@@ -361,3 +384,121 @@ def _expand_to_binary_form(value, input_bits):
361
384
expand_vector = tf .constant ([2 ** i for i in range (input_bits )], value .dtype )
362
385
bits = tf .math .mod (tf .math .floordiv (value , expand_vector ), 2 )
363
386
return tf .reshape (bits , [- 1 ])
387
+
388
+
389
+ def _pack_into_int_8_28 (value ):
390
+ """Implementation of `pack_into_int` for specific bitranges.
391
+
392
+ This method corresponts to `(input_bitrange, target_bitrange)` form the
393
+ `pack_into_int` method equal to `(8, 28)`. This method relies on the fact that
394
+ 7 values in 8-bit bitrange can be packed into 2 values in 28-bitrange
395
+ (7 = least_common_multiple(8, 28) / 8).
396
+
397
+ It reshapes the input into matrix of 7 columns and performs operations on the
398
+ columns of the matrix, thus vectorizing the operations and avoiding memory
399
+ overhead of an earlier general implementation.
400
+
401
+ Args:
402
+ value: An integer Tensor to be packed with values in [0, 2**8 - 1].
403
+
404
+ Returns:
405
+ An integer Tensor representing `value` of the same dtype as `value`.
406
+ """
407
+ value = tf .reshape (value , [- 1 ])
408
+ extra_zeros = tf .zeros (tf .math .mod (- tf .shape (value ), 7 ), value .dtype )
409
+ val = tf .reshape (tf .concat ([value , extra_zeros ], 0 ), [- 1 , 7 ])
410
+
411
+ a = (val [:, 0 ] +
412
+ val [:, 1 ] * 2 ** 8 +
413
+ val [:, 2 ] * 2 ** 16 +
414
+ tf .math .mod (val [:, 3 ], 2 ** 4 ) * 2 ** 24 )
415
+ b = (tf .math .floordiv (val [:, 3 ], 2 ** 4 ) +
416
+ val [:, 4 ] * 2 ** 4 +
417
+ val [:, 5 ] * 2 ** 12 +
418
+ val [:, 6 ] * 2 ** 20 )
419
+
420
+ packed_val = tf .reshape (tf .stack ([a , b ], 1 ), [- 1 , 1 ])
421
+ if extra_zeros .shape [0 ] in [4 , 5 , 6 ]:
422
+ # We added unnecessary sum of zeros to the representation.
423
+ packed_val = tf .slice (packed_val , [0 , 0 ], [packed_val .shape [0 ] - 1 , 1 ])
424
+ return packed_val
425
+
426
+
427
+ def _pack_into_int_12_28 (value ):
428
+ """Implementation of `pack_into_int` for specific bitranges.
429
+
430
+ This method corresponts to `(input_bitrange, target_bitrange)` form the
431
+ `pack_into_int` method equal to `(12, 28)`. This method relies on the fact
432
+ that 7 values in 12-bit bitrange can be packed into 3 values in 28-bitrange
433
+ (7 = least_common_multiple(12, 28) / 12).
434
+
435
+ It reshapes the input into matrix of 7 columns and performs operations on the
436
+ columns of the matrix, thus vectorizing the operations and avoiding memory
437
+ overhead of an earlier general implementation.
438
+
439
+ Args:
440
+ value: An integer Tensor to be packed with values in [0, 2**8 - 1].
441
+
442
+ Returns:
443
+ An integer Tensor representing `value` of the same dtype as `value`.
444
+ """
445
+ value = tf .reshape (value , [- 1 ])
446
+ extra_zeros = tf .zeros (tf .math .mod (- tf .shape (value ), 7 ), value .dtype )
447
+ val = tf .reshape (tf .concat ([value , extra_zeros ], 0 ), [- 1 , 7 ])
448
+
449
+ a = (val [:, 0 ] +
450
+ val [:, 1 ] * 2 ** 12 +
451
+ tf .math .mod (val [:, 2 ], 2 ** 4 ) * 2 ** 24 )
452
+ b = (tf .math .floordiv (val [:, 2 ], 2 ** 4 ) +
453
+ val [:, 3 ] * 2 ** 8 +
454
+ tf .math .mod (val [:, 4 ], 2 ** 8 ) * 2 ** 20 )
455
+ c = (tf .math .floordiv (val [:, 4 ], 2 ** 8 ) +
456
+ val [:, 5 ] * 2 ** 4 +
457
+ val [:, 6 ] * 2 ** 16 )
458
+
459
+ packed_val = tf .reshape (tf .stack ([a , b , c ], 1 ), [- 1 , 1 ])
460
+ if extra_zeros .shape [0 ] in [3 , 4 ]:
461
+ # We added unnecessary sum of zeros to the representation.
462
+ packed_val = tf .slice (packed_val , [0 , 0 ], [packed_val .shape [0 ] - 1 , 1 ])
463
+ if extra_zeros .shape [0 ] in [5 , 6 ]:
464
+ # We added unnecessary two sums of zeros to the representation.
465
+ packed_val = tf .slice (packed_val , [0 , 0 ], [packed_val .shape [0 ] - 2 , 1 ])
466
+ return packed_val
467
+
468
+
469
+ def _unpack_from_int_8_28 (value , shape ):
470
+ """Inverse operation of `_pack_into_int_8_28`."""
471
+ value = tf .reshape (value , [- 1 ])
472
+ extra_zeros = tf .zeros (tf .math .mod (- tf .shape (value ), 2 ), value .dtype )
473
+ val = tf .reshape (tf .concat ([value , extra_zeros ], 0 ), [- 1 , 2 ])
474
+
475
+ a = tf .math .mod (val [:, 0 ], 2 ** 8 )
476
+ b = tf .math .mod (tf .math .floordiv (val [:, 0 ], 2 ** 8 ), 2 ** 8 )
477
+ c = tf .math .mod (tf .math .floordiv (val [:, 0 ], 2 ** 16 ), 2 ** 8 )
478
+ d = tf .math .floordiv (val [:, 0 ], 2 ** 24 ) + tf .math .mod (val [:, 1 ], 2 ** 4 ) * 2 ** 4
479
+ e = tf .math .mod (tf .math .floordiv (val [:, 1 ], 2 ** 4 ), 2 ** 8 )
480
+ f = tf .math .mod (tf .math .floordiv (val [:, 1 ], 2 ** 12 ), 2 ** 8 )
481
+ g = tf .math .floordiv (val [:, 1 ], 2 ** 20 )
482
+
483
+ unpacked_val = tf .reshape (tf .stack ([a , b , c , d , e , f , g ], 1 ), [- 1 ,])
484
+ unpacked_val = tf .slice (unpacked_val , [0 ], [tf .reduce_prod (shape )])
485
+ return tf .reshape (unpacked_val , shape )
486
+
487
+
488
+ def _unpack_from_int_12_28 (value , shape ):
489
+ """Inverse operation of `_pack_into_int_12_28`."""
490
+ value = tf .reshape (value , [- 1 ])
491
+ extra_zeros = tf .zeros (tf .math .mod (- tf .shape (value ), 3 ), value .dtype )
492
+ val = tf .reshape (tf .concat ([value , extra_zeros ], 0 ), [- 1 , 3 ])
493
+
494
+ a = tf .math .mod (val [:, 0 ], 2 ** 12 )
495
+ b = tf .math .mod (tf .math .floordiv (val [:, 0 ], 2 ** 12 ), 2 ** 12 )
496
+ c = tf .math .floordiv (val [:, 0 ], 2 ** 24 ) + tf .math .mod (val [:, 1 ], 2 ** 8 ) * 2 ** 4
497
+ d = tf .math .mod (tf .math .floordiv (val [:, 1 ], 2 ** 8 ), 2 ** 12 )
498
+ e = tf .math .floordiv (val [:, 1 ], 2 ** 20 ) + tf .math .mod (val [:, 2 ], 2 ** 4 ) * 2 ** 8
499
+ f = tf .math .mod (tf .math .floordiv (val [:, 2 ], 2 ** 4 ), 2 ** 12 )
500
+ g = tf .math .mod (tf .math .floordiv (val [:, 2 ], 2 ** 16 ), 2 ** 12 )
501
+
502
+ unpacked_val = tf .reshape (tf .stack ([a , b , c , d , e , f , g ], 1 ), [- 1 ,])
503
+ unpacked_val = tf .slice (unpacked_val , [0 ], [tf .reduce_prod (shape )])
504
+ return tf .reshape (unpacked_val , shape )
0 commit comments