@@ -311,6 +311,10 @@ def pack_into_int(value, input_bitrange, target_bitrange):
311311 """
312312 # TODO(b/161433177): Provide a general solution without memory overhead.
313313 # Special cases implemented without extra memory overhead.
314+ if input_bitrange == 6 and target_bitrange == 28 :
315+ return _pack_into_int_6_28 (value )
316+ if input_bitrange == 7 and target_bitrange == 28 :
317+ return _pack_into_int_7_28 (value )
314318 if input_bitrange == 8 and target_bitrange == 28 :
315319 return _pack_into_int_8_28 (value )
316320 if input_bitrange == 12 and target_bitrange == 28 :
@@ -349,6 +353,10 @@ def unpack_from_int(value, original_bitrange, target_bitrange, shape):
349353 """
350354 # TODO(b/161433177): Provide a general solution without memory overhead.
351355 # Special cases implemented without extra memory overhead.
356+ if original_bitrange == 6 and target_bitrange == 28 :
357+ return _unpack_from_int_6_28 (value , shape )
358+ if original_bitrange == 7 and target_bitrange == 28 :
359+ return _unpack_from_int_7_28 (value , shape )
352360 if original_bitrange == 8 and target_bitrange == 28 :
353361 return _unpack_from_int_8_28 (value , shape )
354362 if original_bitrange == 12 and target_bitrange == 28 :
@@ -386,6 +394,85 @@ def _expand_to_binary_form(value, input_bits):
386394 return tf .reshape (bits , [- 1 ])
387395
388396
397+ def _pack_into_int_6_28 (value ):
398+ """Implementation of `pack_into_int` for specific bitranges.
399+
400+ This method corresponts to `(input_bitrange, target_bitrange)` form the
401+ `pack_into_int` method equal to `(6, 28)`. This method relies on the fact that
402+ 14 values in 6-bit bitrange can be packed into 3 values in 28-bitrange
403+ (14 = least_common_multiple(6, 28) / 6).
404+
405+ It reshapes the input into matrix of 14 columns and performs operations on the
406+ columns of the matrix, thus vectorizing the operations and avoiding memory
407+ overhead of an earlier general implementation.
408+
409+ Args:
410+ value: An integer Tensor to be packed with values in [0, 2**6 - 1].
411+
412+ Returns:
413+ An integer Tensor representing `value` of the same dtype as `value`.
414+ """
415+ value = tf .reshape (value , [- 1 ])
416+ extra_zeros = tf .zeros (tf .math .mod (- tf .shape (value ), 14 ), value .dtype )
417+ val = tf .reshape (tf .concat ([value , extra_zeros ], 0 ), [- 1 , 14 ])
418+
419+ a = (val [:, 0 ] +
420+ val [:, 1 ] * 2 ** 6 +
421+ val [:, 2 ] * 2 ** 12 +
422+ val [:, 3 ] * 2 ** 18 +
423+ tf .math .mod (val [:, 4 ], 2 ** 4 ) * 2 ** 24 )
424+ b = (tf .math .floordiv (val [:, 4 ], 2 ** 4 ) +
425+ val [:, 5 ] * 2 ** 2 +
426+ val [:, 6 ] * 2 ** 8 +
427+ val [:, 7 ] * 2 ** 14 +
428+ val [:, 8 ] * 2 ** 20 +
429+ tf .math .mod (val [:, 9 ], 2 ** 2 ) * 2 ** 26 )
430+ c = (tf .math .floordiv (val [:, 9 ], 2 ** 2 ) +
431+ val [:, 10 ] * 2 ** 4 +
432+ val [:, 11 ] * 2 ** 10 +
433+ val [:, 12 ] * 2 ** 16 +
434+ val [:, 13 ] * 2 ** 22 )
435+
436+ packed_val = tf .reshape (tf .stack ([a , b , c ], 1 ), [- 1 , 1 ])
437+ if extra_zeros .shape [0 ] in [5 , 6 , 7 , 8 , 9 ]:
438+ # We added unnecessary product of zeros to the representation.
439+ packed_val = tf .slice (packed_val , [0 , 0 ], [packed_val .shape [0 ] - 1 , 1 ])
440+ if extra_zeros .shape [0 ] in [10 , 11 , 12 , 13 ]:
441+ # We added unnecessary two products of zeros to the representation.
442+ packed_val = tf .slice (packed_val , [0 , 0 ], [packed_val .shape [0 ] - 2 , 1 ])
443+ return packed_val
444+
445+
446+ def _pack_into_int_7_28 (value ):
447+ """Implementation of `pack_into_int` for specific bitranges.
448+
449+ This method corresponts to `(input_bitrange, target_bitrange)` form the
450+ `pack_into_int` method equal to `(7, 28)`. This method relies on the fact that
451+ 4 values in 7-bit bitrange can be packed into 1 value in 28-bitrange
452+ (4 = least_common_multiple(7, 28) / 7).
453+
454+ It reshapes the input into matrix of 4 columns and performs operations on the
455+ columns of the matrix, thus vectorizing the operations and avoiding memory
456+ overhead of an earlier general implementation.
457+
458+ Args:
459+ value: An integer Tensor to be packed with values in [0, 2**7 - 1].
460+
461+ Returns:
462+ An integer Tensor representing `value` of the same dtype as `value`.
463+ """
464+ value = tf .reshape (value , [- 1 ])
465+ extra_zeros = tf .zeros (tf .math .mod (- tf .shape (value ), 4 ), value .dtype )
466+ val = tf .reshape (tf .concat ([value , extra_zeros ], 0 ), [- 1 , 4 ])
467+
468+ packed_val = (val [:, 0 ] +
469+ val [:, 1 ] * 2 ** 7 +
470+ val [:, 2 ] * 2 ** 14 +
471+ val [:, 3 ] * 2 ** 21 )
472+
473+ return tf .reshape (packed_val , [- 1 , 1 ])
474+
475+
389476def _pack_into_int_8_28 (value ):
390477 """Implementation of `pack_into_int` for specific bitranges.
391478
@@ -466,6 +553,47 @@ def _pack_into_int_12_28(value):
466553 return packed_val
467554
468555
556+ def _unpack_from_int_6_28 (value , shape ):
557+ """Inverse operation of `_pack_into_int_6_28`."""
558+ value = tf .reshape (value , [- 1 ])
559+ extra_zeros = tf .zeros (tf .math .mod (- tf .shape (value ), 3 ), value .dtype )
560+ val = tf .reshape (tf .concat ([value , extra_zeros ], 0 ), [- 1 , 3 ])
561+
562+ a = tf .math .mod (val [:, 0 ], 2 ** 6 )
563+ b = tf .math .mod (tf .math .floordiv (val [:, 0 ], 2 ** 6 ), 2 ** 6 )
564+ c = tf .math .mod (tf .math .floordiv (val [:, 0 ], 2 ** 12 ), 2 ** 6 )
565+ d = tf .math .mod (tf .math .floordiv (val [:, 0 ], 2 ** 18 ), 2 ** 6 )
566+ e = tf .math .floordiv (val [:, 0 ], 2 ** 24 ) + tf .math .mod (val [:, 1 ], 2 ** 2 ) * 2 ** 4
567+ f = tf .math .mod (tf .math .floordiv (val [:, 1 ], 2 ** 2 ), 2 ** 6 )
568+ g = tf .math .mod (tf .math .floordiv (val [:, 1 ], 2 ** 8 ), 2 ** 6 )
569+ h = tf .math .mod (tf .math .floordiv (val [:, 1 ], 2 ** 14 ), 2 ** 6 )
570+ i = tf .math .mod (tf .math .floordiv (val [:, 1 ], 2 ** 20 ), 2 ** 6 )
571+ j = tf .math .floordiv (val [:, 1 ], 2 ** 26 ) + tf .math .mod (val [:, 2 ], 2 ** 4 ) * 2 ** 2
572+ k = tf .math .mod (tf .math .floordiv (val [:, 2 ], 2 ** 4 ), 2 ** 6 )
573+ l = tf .math .mod (tf .math .floordiv (val [:, 2 ], 2 ** 10 ), 2 ** 6 )
574+ m = tf .math .mod (tf .math .floordiv (val [:, 2 ], 2 ** 16 ), 2 ** 6 )
575+ n = tf .math .mod (tf .math .floordiv (val [:, 2 ], 2 ** 22 ), 2 ** 6 )
576+
577+ unpacked_val = tf .reshape (
578+ tf .stack ([a , b , c , d , e , f , g , h , i , j , k , l , m , n ], 1 ), [- 1 ,])
579+ unpacked_val = tf .slice (unpacked_val , [0 ], [tf .reduce_prod (shape )])
580+ return tf .reshape (unpacked_val , shape )
581+
582+
583+ def _unpack_from_int_7_28 (value , shape ):
584+ """Inverse operation of `_pack_into_int_7_28`."""
585+ val = tf .reshape (value , [- 1 , 1 ])
586+
587+ a = tf .math .mod (val [:, 0 ], 2 ** 7 )
588+ b = tf .math .mod (tf .math .floordiv (val [:, 0 ], 2 ** 7 ), 2 ** 7 )
589+ c = tf .math .mod (tf .math .floordiv (val [:, 0 ], 2 ** 14 ), 2 ** 7 )
590+ d = tf .math .mod (tf .math .floordiv (val [:, 0 ], 2 ** 21 ), 2 ** 7 )
591+
592+ unpacked_val = tf .reshape (tf .stack ([a , b , c , d ], 1 ), [- 1 ,])
593+ unpacked_val = tf .slice (unpacked_val , [0 ], [tf .reduce_prod (shape )])
594+ return tf .reshape (unpacked_val , shape )
595+
596+
469597def _unpack_from_int_8_28 (value , shape ):
470598 """Inverse operation of `_pack_into_int_8_28`."""
471599 value = tf .reshape (value , [- 1 ])
0 commit comments