Skip to content

Commit 3c8d945

Browse files
dengyinlintensorflower-gardener
authored andcommitted
Changes fast_walsh_hadamard_transform to use tf.while_loop.
PiperOrigin-RevId: 263081103
1 parent db5c05b commit 3c8d945

File tree

2 files changed

+98
-11
lines changed

2 files changed

+98
-11
lines changed

tensorflow_model_optimization/python/core/internal/tensor_encoding/utils/tf_utils.py

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,26 +31,52 @@ def fast_walsh_hadamard_transform(x):
3131
3232
Args:
3333
x: A `Tensor`. Must be of shape `[a, b]`, where `a` can be anything (not
34-
necessarily known), and `b` must be a power of two, statically known.
34+
necessarily known), and `b` must be a power of two, not required to be
35+
statically known.
3536
3637
Returns:
3738
A `Tensor` of shape `[a, b]`, where `[i, :]` is the product `x[i, :]*H`,
3839
where `H` is the Hadamard matrix.
3940
4041
Raises:
4142
ValueError: If the input is not rank 2 `Tensor`, and if the second dimension
42-
is not a power of two.
43+
is statically known and is not a power of two.
44+
OpError: If the second dimension is not statically known and is not a power
45+
of two. Note that in graph execution, this error is not raised during the
46+
execution of the Python function, but during execution of the resulting
47+
computation.
4348
"""
4449
with tf.name_scope(None, 'fast_walsh_hadamard_transform'):
4550
# Validate input.
4651
x = tf.convert_to_tensor(x)
4752
if x.shape.ndims != 2:
4853
raise ValueError(
4954
'Number of dimensions of x must be 2. Shape of x: %s' % x.shape)
50-
dim = x.shape.as_list()[1]
51-
if not (dim and ((dim & (dim - 1)) == 0)):
52-
raise ValueError('The dimension of x must be a power of two. '
53-
'Provided dimension is: %s' % dim)
55+
56+
original_x_shape = x.shape.as_list()
57+
dim = x.shape.as_list()[-1]
58+
59+
if dim is None: # dim is not statically known.
60+
dim = tf.shape(x)[-1]
61+
log2 = tf.cast(
62+
tf.math.round(
63+
tf.math.log(tf.cast(dim, tf.float32)) / tf.math.log(2.)),
64+
tf.int32)
65+
with tf.control_dependencies([
66+
tf.compat.v1.assert_equal(
67+
dim,
68+
tf.math.pow(2, log2),
69+
message='The dimension of x must be a power of two.'
70+
'Provided dimension is: %s' % dim)
71+
]):
72+
x = tf.identity(x)
73+
else: # dim is statically known.
74+
if not (dim and ((dim & (dim - 1)) == 0)):
75+
raise ValueError('The dimension of x must be a power of two. '
76+
'Provided dimension is: %s' % dim)
77+
log2 = int(np.ceil(np.log2(dim)))
78+
if dim == 1: # Equivalent to identity.
79+
return tf.identity(x)
5480

5581
h_core = tf.constant([[1., 1.], [1., -1.]],
5682
dtype=x.dtype,
@@ -60,17 +86,30 @@ def fast_walsh_hadamard_transform(x):
6086
# A step of the fast Walsh-Hadamard algorithm.
6187
def _hadamard_step(x, dim):
6288
"""A single step in the fast Walsh-Hadamard transform."""
89+
x_shape = x.shape.as_list()
6390
x = tf.reshape(x, [-1, 2]) # Reshape so that we have a matrix.
6491
x = tf.matmul(x, h_core) # Multiply.
6592
x = tf.reshape(x, [-1, dim // 2, 2]) # Reshape to rank-3.
6693
x = tf.transpose(x, perm=permutation) # Swap last two dimensions.
94+
x.set_shape(x_shape) # Failed shape inference in tf.while_loop.
6795
return x
6896

69-
# The fast Walsh-Hadamard transform.
70-
for _ in range(int(np.ceil(np.log2(dim)))):
71-
x = _hadamard_step(x, dim)
97+
def _fwht(x, dim, log2):
98+
x = tf.reshape(x, [-1, 2, dim // 2])
99+
# The fast Walsh-Hadamard transform.
100+
101+
i = tf.constant(0)
102+
c = lambda i, x: tf.less(i, log2)
103+
b = lambda i, x: [i + 1, _hadamard_step(x, dim)]
104+
i, x = tf.while_loop(c, b, [i, x])
105+
return x
106+
107+
x = tf.cond(
108+
tf.equal(dim, 1), lambda: tf.identity(x), lambda: _fwht(x, dim, log2))
109+
72110
x = tf.reshape(x, [-1, dim])
73111
x /= tf.sqrt(tf.cast(dim, x.dtype)) # Normalize.
112+
x.set_shape(original_x_shape) # Failed shape inference after tf.while_loop.
74113
return x
75114

76115

tensorflow_model_optimization/python/core/internal/tensor_encoding/utils/tf_utils_test.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,61 @@ def test_illegal_inputs_shape(self, *dims):
5656
tf_utils.fast_walsh_hadamard_transform(x)
5757

5858
@parameterized.parameters([[1, 3], [1, 7], [1, 9], [4, 3]])
59-
def test_illegal_inputs_power_of_two(self, *dims):
60-
"""Tests incorrect shape of the rank 2 input."""
59+
def test_illegal_inputs_static_power_of_two(self, *dims):
60+
"""Tests incorrect static shape of the rank 2 input."""
6161
x = tf.random.normal(dims)
6262
with self.assertRaisesRegexp(ValueError,
6363
'The dimension of x must be a power of two.'):
6464
tf_utils.fast_walsh_hadamard_transform(x)
6565

66+
def test_illegal_inputs_dynamic_power_of_two(self):
67+
"""Tests incorrect dynamic shape of the rank 2 input."""
68+
rand = tf.random.uniform((), maxval=3, dtype=tf.int32)
69+
x = tf.random.normal((3, 3**rand))
70+
hx = tf_utils.fast_walsh_hadamard_transform(x)
71+
with self.assertRaisesOpError('The dimension of x must be a power of two.'):
72+
hx = self.evaluate(hx)
73+
74+
@parameterized.parameters([[1, 1], [4, 1], [2, 2], [1, 8], [1, 4]])
75+
def test_static_input_shape(self, *dims):
76+
"""Tests static input shape."""
77+
x = tf.random.normal(dims)
78+
hx_tf = tf_utils.fast_walsh_hadamard_transform(x)
79+
hhx_tf = tf_utils.fast_walsh_hadamard_transform(hx_tf)
80+
81+
x, hx_tf, hhx_tf = self.evaluate([x, hx_tf, hhx_tf])
82+
self.assertAllEqual(x.shape, hhx_tf.shape)
83+
self.assertAllClose(x, hhx_tf)
84+
85+
@parameterized.parameters([[1, 1], [4, 1], [2, 2], [1, 8], [1, 4]])
86+
def test_static_input_output_shape(self, *dims):
87+
"""Tests static output shape is identical to static input shape."""
88+
x = tf.random.normal(dims)
89+
hx_tf = tf_utils.fast_walsh_hadamard_transform(x)
90+
hhx_tf = tf_utils.fast_walsh_hadamard_transform(hx_tf)
91+
self.assertEqual(list(dims), hx_tf.shape.as_list())
92+
self.assertEqual(list(dims), hhx_tf.shape.as_list())
93+
94+
def test_dynamic_input_shape(self):
95+
"""Tests dynamic input shape."""
96+
rand = tf.random.uniform((), maxval=4, dtype=tf.int32)
97+
x = tf.random.normal((3, 2**rand))
98+
hx_tf = tf_utils.fast_walsh_hadamard_transform(x)
99+
hhx_tf = tf_utils.fast_walsh_hadamard_transform(hx_tf)
100+
x, hx_tf, hhx_tf = self.evaluate([x, hx_tf, hhx_tf])
101+
self.assertAllEqual(x.shape, hhx_tf.shape)
102+
self.assertAllClose(x, hhx_tf)
103+
104+
def test_dynamic_input_shape_dim_one(self):
105+
"""Tests input shape where the second dimension is 1, dynamically known."""
106+
rand = tf.random.uniform((), maxval=1, dtype=tf.int32)
107+
x = tf.random.normal((3, 2**rand))
108+
hx_tf = tf_utils.fast_walsh_hadamard_transform(x)
109+
hhx_tf = tf_utils.fast_walsh_hadamard_transform(hx_tf)
110+
x, hx_tf, hhx_tf = self.evaluate([x, hx_tf, hhx_tf])
111+
self.assertAllEqual(x.shape, hhx_tf.shape)
112+
self.assertAllClose(x, hhx_tf)
113+
66114
@parameterized.parameters([2, 4, 8, 16])
67115
def test_output_same_as_simple_python_implementation(self, dim):
68116
"""Tests result is identical to inefficient implementation using scipy."""

0 commit comments

Comments
 (0)