Skip to content

Commit 37006f0

Browse files
PraChetittensorflower-gardener
authored andcommitted
Adds efficient special cases to the default bitpacking utils in tensor_encoding.
PiperOrigin-RevId: 352917915
1 parent 782fd70 commit 37006f0

File tree

2 files changed

+168
-2
lines changed

2 files changed

+168
-2
lines changed

tensorflow_model_optimization/python/core/internal/tensor_encoding/utils/tf_utils.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,10 @@ def pack_into_int(value, input_bitrange, target_bitrange):
311311
"""
312312
# TODO(b/161433177): Provide a general solution without memory overhead.
313313
# Special cases implemented without extra memory overhead.
314+
if input_bitrange == 6 and target_bitrange == 28:
315+
return _pack_into_int_6_28(value)
316+
if input_bitrange == 7 and target_bitrange == 28:
317+
return _pack_into_int_7_28(value)
314318
if input_bitrange == 8 and target_bitrange == 28:
315319
return _pack_into_int_8_28(value)
316320
if input_bitrange == 12 and target_bitrange == 28:
@@ -349,6 +353,10 @@ def unpack_from_int(value, original_bitrange, target_bitrange, shape):
349353
"""
350354
# TODO(b/161433177): Provide a general solution without memory overhead.
351355
# Special cases implemented without extra memory overhead.
356+
if original_bitrange == 6 and target_bitrange == 28:
357+
return _unpack_from_int_6_28(value, shape)
358+
if original_bitrange == 7 and target_bitrange == 28:
359+
return _unpack_from_int_7_28(value, shape)
352360
if original_bitrange == 8 and target_bitrange == 28:
353361
return _unpack_from_int_8_28(value, shape)
354362
if original_bitrange == 12 and target_bitrange == 28:
@@ -386,6 +394,85 @@ def _expand_to_binary_form(value, input_bits):
386394
return tf.reshape(bits, [-1])
387395

388396

397+
def _pack_into_int_6_28(value):
398+
"""Implementation of `pack_into_int` for specific bitranges.
399+
400+
This method corresponts to `(input_bitrange, target_bitrange)` form the
401+
`pack_into_int` method equal to `(6, 28)`. This method relies on the fact that
402+
14 values in 6-bit bitrange can be packed into 3 values in 28-bitrange
403+
(14 = least_common_multiple(6, 28) / 6).
404+
405+
It reshapes the input into matrix of 14 columns and performs operations on the
406+
columns of the matrix, thus vectorizing the operations and avoiding memory
407+
overhead of an earlier general implementation.
408+
409+
Args:
410+
value: An integer Tensor to be packed with values in [0, 2**6 - 1].
411+
412+
Returns:
413+
An integer Tensor representing `value` of the same dtype as `value`.
414+
"""
415+
value = tf.reshape(value, [-1])
416+
extra_zeros = tf.zeros(tf.math.mod(-tf.shape(value), 14), value.dtype)
417+
val = tf.reshape(tf.concat([value, extra_zeros], 0), [-1, 14])
418+
419+
a = (val[:, 0] +
420+
val[:, 1] * 2**6 +
421+
val[:, 2] * 2**12 +
422+
val[:, 3] * 2**18 +
423+
tf.math.mod(val[:, 4], 2**4) * 2**24)
424+
b = (tf.math.floordiv(val[:, 4], 2**4) +
425+
val[:, 5] * 2**2 +
426+
val[:, 6] * 2**8 +
427+
val[:, 7] * 2**14 +
428+
val[:, 8] * 2**20 +
429+
tf.math.mod(val[:, 9], 2**2) * 2**26)
430+
c = (tf.math.floordiv(val[:, 9], 2**2) +
431+
val[:, 10] * 2**4 +
432+
val[:, 11] * 2**10 +
433+
val[:, 12] * 2**16 +
434+
val[:, 13] * 2**22)
435+
436+
packed_val = tf.reshape(tf.stack([a, b, c], 1), [-1, 1])
437+
if extra_zeros.shape[0] in [5, 6, 7, 8, 9]:
438+
# We added unnecessary product of zeros to the representation.
439+
packed_val = tf.slice(packed_val, [0, 0], [packed_val.shape[0] - 1, 1])
440+
if extra_zeros.shape[0] in [10, 11, 12, 13]:
441+
# We added unnecessary two products of zeros to the representation.
442+
packed_val = tf.slice(packed_val, [0, 0], [packed_val.shape[0] - 2, 1])
443+
return packed_val
444+
445+
446+
def _pack_into_int_7_28(value):
447+
"""Implementation of `pack_into_int` for specific bitranges.
448+
449+
This method corresponts to `(input_bitrange, target_bitrange)` form the
450+
`pack_into_int` method equal to `(7, 28)`. This method relies on the fact that
451+
4 values in 7-bit bitrange can be packed into 1 value in 28-bitrange
452+
(4 = least_common_multiple(7, 28) / 7).
453+
454+
It reshapes the input into matrix of 4 columns and performs operations on the
455+
columns of the matrix, thus vectorizing the operations and avoiding memory
456+
overhead of an earlier general implementation.
457+
458+
Args:
459+
value: An integer Tensor to be packed with values in [0, 2**7 - 1].
460+
461+
Returns:
462+
An integer Tensor representing `value` of the same dtype as `value`.
463+
"""
464+
value = tf.reshape(value, [-1])
465+
extra_zeros = tf.zeros(tf.math.mod(-tf.shape(value), 4), value.dtype)
466+
val = tf.reshape(tf.concat([value, extra_zeros], 0), [-1, 4])
467+
468+
packed_val = (val[:, 0] +
469+
val[:, 1] * 2**7 +
470+
val[:, 2] * 2**14 +
471+
val[:, 3] * 2**21)
472+
473+
return tf.reshape(packed_val, [-1, 1])
474+
475+
389476
def _pack_into_int_8_28(value):
390477
"""Implementation of `pack_into_int` for specific bitranges.
391478
@@ -466,6 +553,47 @@ def _pack_into_int_12_28(value):
466553
return packed_val
467554

468555

556+
def _unpack_from_int_6_28(value, shape):
557+
"""Inverse operation of `_pack_into_int_6_28`."""
558+
value = tf.reshape(value, [-1])
559+
extra_zeros = tf.zeros(tf.math.mod(-tf.shape(value), 3), value.dtype)
560+
val = tf.reshape(tf.concat([value, extra_zeros], 0), [-1, 3])
561+
562+
a = tf.math.mod(val[:, 0], 2**6)
563+
b = tf.math.mod(tf.math.floordiv(val[:, 0], 2**6), 2**6)
564+
c = tf.math.mod(tf.math.floordiv(val[:, 0], 2**12), 2**6)
565+
d = tf.math.mod(tf.math.floordiv(val[:, 0], 2**18), 2**6)
566+
e = tf.math.floordiv(val[:, 0], 2**24) + tf.math.mod(val[:, 1], 2**2) * 2**4
567+
f = tf.math.mod(tf.math.floordiv(val[:, 1], 2**2), 2**6)
568+
g = tf.math.mod(tf.math.floordiv(val[:, 1], 2**8), 2**6)
569+
h = tf.math.mod(tf.math.floordiv(val[:, 1], 2**14), 2**6)
570+
i = tf.math.mod(tf.math.floordiv(val[:, 1], 2**20), 2**6)
571+
j = tf.math.floordiv(val[:, 1], 2**26) + tf.math.mod(val[:, 2], 2**4) * 2**2
572+
k = tf.math.mod(tf.math.floordiv(val[:, 2], 2**4), 2**6)
573+
l = tf.math.mod(tf.math.floordiv(val[:, 2], 2**10), 2**6)
574+
m = tf.math.mod(tf.math.floordiv(val[:, 2], 2**16), 2**6)
575+
n = tf.math.mod(tf.math.floordiv(val[:, 2], 2**22), 2**6)
576+
577+
unpacked_val = tf.reshape(
578+
tf.stack([a, b, c, d, e, f, g, h, i, j, k, l, m, n], 1), [-1,])
579+
unpacked_val = tf.slice(unpacked_val, [0], [tf.reduce_prod(shape)])
580+
return tf.reshape(unpacked_val, shape)
581+
582+
583+
def _unpack_from_int_7_28(value, shape):
584+
"""Inverse operation of `_pack_into_int_7_28`."""
585+
val = tf.reshape(value, [-1, 1])
586+
587+
a = tf.math.mod(val[:, 0], 2**7)
588+
b = tf.math.mod(tf.math.floordiv(val[:, 0], 2**7), 2**7)
589+
c = tf.math.mod(tf.math.floordiv(val[:, 0], 2**14), 2**7)
590+
d = tf.math.mod(tf.math.floordiv(val[:, 0], 2**21), 2**7)
591+
592+
unpacked_val = tf.reshape(tf.stack([a, b, c, d], 1), [-1,])
593+
unpacked_val = tf.slice(unpacked_val, [0], [tf.reduce_prod(shape)])
594+
return tf.reshape(unpacked_val, shape)
595+
596+
469597
def _unpack_from_int_8_28(value, shape):
470598
"""Inverse operation of `_pack_into_int_8_28`."""
471599
value = tf.reshape(value, [-1])

tensorflow_model_optimization/python/core/internal/tensor_encoding/utils/tf_utils_test.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,8 @@ def test_unpack_from_int_different_outputs(self):
375375
packed_value, original_bitrange=4, target_bitrange=28, shape=(2,))
376376
self.assertAllEqual([9, 0], self.evaluate(unpacked_value))
377377

378-
@parameterized.parameters([(1, 28), (2, 28), (8, 28), (12, 28)])
378+
@parameterized.parameters([(1, 28), (2, 28), (6, 28), (7, 28), (8, 28),
379+
(12, 28)])
379380
def test_boundary_conditions(self, input_bitrange, target_bitrange):
380381
max_v = 2**input_bitrange - 1
381382
input_value = tf.constant(
@@ -390,7 +391,8 @@ def test_boundary_conditions(self, input_bitrange, target_bitrange):
390391

391392
self.assertAllEqual(self.evaluate(value), self.evaluate(unpacked_value))
392393

393-
@parameterized.parameters([(1, 28), (2, 28), (8, 28), (12, 28)])
394+
@parameterized.parameters([(1, 28), (2, 28), (6, 28), (7, 28), (8, 28),
395+
(12, 28)])
394396
def test_random_input(self, input_bitrange, target_bitrange):
395397
# Tests that packing/unpacking amounts to identity, regardless of the input.
396398
num_elements = np.random.randint(low=1, high=50)
@@ -407,6 +409,42 @@ def test_random_input(self, input_bitrange, target_bitrange):
407409
except: # pylint: disable=bare-except
408410
self.fail(f'Random input test failed with input value: {value}')
409411

412+
def test_pack_into_int_special_case_6_28(self):
413+
value = tf.constant(
414+
[50, 19, 51, 59, 10, 53, 36, 44, 31, 44, 31, 10, 31, 56, 49, 48, 35])
415+
packed_value = tf_utils.pack_into_int(
416+
value, input_bitrange=6, target_bitrange=28)
417+
expected_packed_value = tf.constant([[183448818], [33236180], [236923387],
418+
[146481]])
419+
self.assertAllEqual(self.evaluate(expected_packed_value),
420+
self.evaluate(packed_value))
421+
422+
def test_unpack_from_int_special_case_6_28(self):
423+
packed_value = tf.constant([[183448818], [33236180], [236923387], [146481]])
424+
unpacked_value = tf_utils.unpack_from_int(
425+
packed_value, original_bitrange=6, target_bitrange=28, shape=(17,))
426+
expected_unpacked_value = tf.constant(
427+
[50, 19, 51, 59, 10, 53, 36, 44, 31, 44, 31, 10, 31, 56, 49, 48, 35])
428+
self.assertAllEqual(self.evaluate(expected_unpacked_value),
429+
self.evaluate(unpacked_value))
430+
431+
def test_pack_into_int_special_case_7_28(self):
432+
value = tf.constant([117, 86, 42, 69, 9, 70, 66, 8, 112, 116])
433+
packed_value = tf_utils.pack_into_int(
434+
value, input_bitrange=7, target_bitrange=28)
435+
expected_packed_value = tf.constant([[145402741], [17867529], [14960]])
436+
self.assertAllEqual(self.evaluate(expected_packed_value),
437+
self.evaluate(packed_value))
438+
439+
def test_unpack_from_int_special_case_7_28(self):
440+
packed_value = tf.constant([[145402741], [17867529], [14960]])
441+
unpacked_value = tf_utils.unpack_from_int(
442+
packed_value, original_bitrange=7, target_bitrange=28, shape=(10,))
443+
expected_unpacked_value = tf.constant(
444+
[117, 86, 42, 69, 9, 70, 66, 8, 112, 116])
445+
self.assertAllEqual(self.evaluate(expected_unpacked_value),
446+
self.evaluate(unpacked_value))
447+
410448
def test_pack_into_int_special_case_8_28(self):
411449
value = tf.constant([38, 147, 1, 201, 205, 36, 155, 78, 163, 98])
412450
packed_value = tf_utils.pack_into_int(

0 commit comments

Comments
 (0)