@@ -311,6 +311,10 @@ def pack_into_int(value, input_bitrange, target_bitrange):
311
311
"""
312
312
# TODO(b/161433177): Provide a general solution without memory overhead.
313
313
# Special cases implemented without extra memory overhead.
314
+ if input_bitrange == 6 and target_bitrange == 28 :
315
+ return _pack_into_int_6_28 (value )
316
+ if input_bitrange == 7 and target_bitrange == 28 :
317
+ return _pack_into_int_7_28 (value )
314
318
if input_bitrange == 8 and target_bitrange == 28 :
315
319
return _pack_into_int_8_28 (value )
316
320
if input_bitrange == 12 and target_bitrange == 28 :
@@ -349,6 +353,10 @@ def unpack_from_int(value, original_bitrange, target_bitrange, shape):
349
353
"""
350
354
# TODO(b/161433177): Provide a general solution without memory overhead.
351
355
# Special cases implemented without extra memory overhead.
356
+ if original_bitrange == 6 and target_bitrange == 28 :
357
+ return _unpack_from_int_6_28 (value , shape )
358
+ if original_bitrange == 7 and target_bitrange == 28 :
359
+ return _unpack_from_int_7_28 (value , shape )
352
360
if original_bitrange == 8 and target_bitrange == 28 :
353
361
return _unpack_from_int_8_28 (value , shape )
354
362
if original_bitrange == 12 and target_bitrange == 28 :
@@ -386,6 +394,85 @@ def _expand_to_binary_form(value, input_bits):
386
394
return tf .reshape (bits , [- 1 ])
387
395
388
396
397
+ def _pack_into_int_6_28 (value ):
398
+ """Implementation of `pack_into_int` for specific bitranges.
399
+
400
+ This method corresponts to `(input_bitrange, target_bitrange)` form the
401
+ `pack_into_int` method equal to `(6, 28)`. This method relies on the fact that
402
+ 14 values in 6-bit bitrange can be packed into 3 values in 28-bitrange
403
+ (14 = least_common_multiple(6, 28) / 6).
404
+
405
+ It reshapes the input into matrix of 14 columns and performs operations on the
406
+ columns of the matrix, thus vectorizing the operations and avoiding memory
407
+ overhead of an earlier general implementation.
408
+
409
+ Args:
410
+ value: An integer Tensor to be packed with values in [0, 2**6 - 1].
411
+
412
+ Returns:
413
+ An integer Tensor representing `value` of the same dtype as `value`.
414
+ """
415
+ value = tf .reshape (value , [- 1 ])
416
+ extra_zeros = tf .zeros (tf .math .mod (- tf .shape (value ), 14 ), value .dtype )
417
+ val = tf .reshape (tf .concat ([value , extra_zeros ], 0 ), [- 1 , 14 ])
418
+
419
+ a = (val [:, 0 ] +
420
+ val [:, 1 ] * 2 ** 6 +
421
+ val [:, 2 ] * 2 ** 12 +
422
+ val [:, 3 ] * 2 ** 18 +
423
+ tf .math .mod (val [:, 4 ], 2 ** 4 ) * 2 ** 24 )
424
+ b = (tf .math .floordiv (val [:, 4 ], 2 ** 4 ) +
425
+ val [:, 5 ] * 2 ** 2 +
426
+ val [:, 6 ] * 2 ** 8 +
427
+ val [:, 7 ] * 2 ** 14 +
428
+ val [:, 8 ] * 2 ** 20 +
429
+ tf .math .mod (val [:, 9 ], 2 ** 2 ) * 2 ** 26 )
430
+ c = (tf .math .floordiv (val [:, 9 ], 2 ** 2 ) +
431
+ val [:, 10 ] * 2 ** 4 +
432
+ val [:, 11 ] * 2 ** 10 +
433
+ val [:, 12 ] * 2 ** 16 +
434
+ val [:, 13 ] * 2 ** 22 )
435
+
436
+ packed_val = tf .reshape (tf .stack ([a , b , c ], 1 ), [- 1 , 1 ])
437
+ if extra_zeros .shape [0 ] in [5 , 6 , 7 , 8 , 9 ]:
438
+ # We added unnecessary product of zeros to the representation.
439
+ packed_val = tf .slice (packed_val , [0 , 0 ], [packed_val .shape [0 ] - 1 , 1 ])
440
+ if extra_zeros .shape [0 ] in [10 , 11 , 12 , 13 ]:
441
+ # We added unnecessary two products of zeros to the representation.
442
+ packed_val = tf .slice (packed_val , [0 , 0 ], [packed_val .shape [0 ] - 2 , 1 ])
443
+ return packed_val
444
+
445
+
446
+ def _pack_into_int_7_28 (value ):
447
+ """Implementation of `pack_into_int` for specific bitranges.
448
+
449
+ This method corresponts to `(input_bitrange, target_bitrange)` form the
450
+ `pack_into_int` method equal to `(7, 28)`. This method relies on the fact that
451
+ 4 values in 7-bit bitrange can be packed into 1 value in 28-bitrange
452
+ (4 = least_common_multiple(7, 28) / 7).
453
+
454
+ It reshapes the input into matrix of 4 columns and performs operations on the
455
+ columns of the matrix, thus vectorizing the operations and avoiding memory
456
+ overhead of an earlier general implementation.
457
+
458
+ Args:
459
+ value: An integer Tensor to be packed with values in [0, 2**7 - 1].
460
+
461
+ Returns:
462
+ An integer Tensor representing `value` of the same dtype as `value`.
463
+ """
464
+ value = tf .reshape (value , [- 1 ])
465
+ extra_zeros = tf .zeros (tf .math .mod (- tf .shape (value ), 4 ), value .dtype )
466
+ val = tf .reshape (tf .concat ([value , extra_zeros ], 0 ), [- 1 , 4 ])
467
+
468
+ packed_val = (val [:, 0 ] +
469
+ val [:, 1 ] * 2 ** 7 +
470
+ val [:, 2 ] * 2 ** 14 +
471
+ val [:, 3 ] * 2 ** 21 )
472
+
473
+ return tf .reshape (packed_val , [- 1 , 1 ])
474
+
475
+
389
476
def _pack_into_int_8_28 (value ):
390
477
"""Implementation of `pack_into_int` for specific bitranges.
391
478
@@ -466,6 +553,47 @@ def _pack_into_int_12_28(value):
466
553
return packed_val
467
554
468
555
556
+ def _unpack_from_int_6_28 (value , shape ):
557
+ """Inverse operation of `_pack_into_int_6_28`."""
558
+ value = tf .reshape (value , [- 1 ])
559
+ extra_zeros = tf .zeros (tf .math .mod (- tf .shape (value ), 3 ), value .dtype )
560
+ val = tf .reshape (tf .concat ([value , extra_zeros ], 0 ), [- 1 , 3 ])
561
+
562
+ a = tf .math .mod (val [:, 0 ], 2 ** 6 )
563
+ b = tf .math .mod (tf .math .floordiv (val [:, 0 ], 2 ** 6 ), 2 ** 6 )
564
+ c = tf .math .mod (tf .math .floordiv (val [:, 0 ], 2 ** 12 ), 2 ** 6 )
565
+ d = tf .math .mod (tf .math .floordiv (val [:, 0 ], 2 ** 18 ), 2 ** 6 )
566
+ e = tf .math .floordiv (val [:, 0 ], 2 ** 24 ) + tf .math .mod (val [:, 1 ], 2 ** 2 ) * 2 ** 4
567
+ f = tf .math .mod (tf .math .floordiv (val [:, 1 ], 2 ** 2 ), 2 ** 6 )
568
+ g = tf .math .mod (tf .math .floordiv (val [:, 1 ], 2 ** 8 ), 2 ** 6 )
569
+ h = tf .math .mod (tf .math .floordiv (val [:, 1 ], 2 ** 14 ), 2 ** 6 )
570
+ i = tf .math .mod (tf .math .floordiv (val [:, 1 ], 2 ** 20 ), 2 ** 6 )
571
+ j = tf .math .floordiv (val [:, 1 ], 2 ** 26 ) + tf .math .mod (val [:, 2 ], 2 ** 4 ) * 2 ** 2
572
+ k = tf .math .mod (tf .math .floordiv (val [:, 2 ], 2 ** 4 ), 2 ** 6 )
573
+ l = tf .math .mod (tf .math .floordiv (val [:, 2 ], 2 ** 10 ), 2 ** 6 )
574
+ m = tf .math .mod (tf .math .floordiv (val [:, 2 ], 2 ** 16 ), 2 ** 6 )
575
+ n = tf .math .mod (tf .math .floordiv (val [:, 2 ], 2 ** 22 ), 2 ** 6 )
576
+
577
+ unpacked_val = tf .reshape (
578
+ tf .stack ([a , b , c , d , e , f , g , h , i , j , k , l , m , n ], 1 ), [- 1 ,])
579
+ unpacked_val = tf .slice (unpacked_val , [0 ], [tf .reduce_prod (shape )])
580
+ return tf .reshape (unpacked_val , shape )
581
+
582
+
583
+ def _unpack_from_int_7_28 (value , shape ):
584
+ """Inverse operation of `_pack_into_int_7_28`."""
585
+ val = tf .reshape (value , [- 1 , 1 ])
586
+
587
+ a = tf .math .mod (val [:, 0 ], 2 ** 7 )
588
+ b = tf .math .mod (tf .math .floordiv (val [:, 0 ], 2 ** 7 ), 2 ** 7 )
589
+ c = tf .math .mod (tf .math .floordiv (val [:, 0 ], 2 ** 14 ), 2 ** 7 )
590
+ d = tf .math .mod (tf .math .floordiv (val [:, 0 ], 2 ** 21 ), 2 ** 7 )
591
+
592
+ unpacked_val = tf .reshape (tf .stack ([a , b , c , d ], 1 ), [- 1 ,])
593
+ unpacked_val = tf .slice (unpacked_val , [0 ], [tf .reduce_prod (shape )])
594
+ return tf .reshape (unpacked_val , shape )
595
+
596
+
469
597
def _unpack_from_int_8_28 (value , shape ):
470
598
"""Inverse operation of `_pack_into_int_8_28`."""
471
599
value = tf .reshape (value , [- 1 ])
0 commit comments