@@ -56,13 +56,61 @@ def test_illegal_inputs_shape(self, *dims):
       tf_utils.fast_walsh_hadamard_transform(x)
 
   @parameterized.parameters([[1, 3], [1, 7], [1, 9], [4, 3]])
-  def test_illegal_inputs_power_of_two(self, *dims):
-    """Tests incorrect shape of the rank 2 input."""
+  def test_illegal_inputs_static_power_of_two(self, *dims):
+    """Tests incorrect static shape of the rank 2 input."""
     x = tf.random.normal(dims)
     with self.assertRaisesRegexp(ValueError,
                                  'The dimension of x must be a power of two.'):
       tf_utils.fast_walsh_hadamard_transform(x)
 
+  def test_illegal_inputs_dynamic_power_of_two(self):
+    """Tests incorrect dynamic shape of the rank 2 input."""
+    rand = tf.random.uniform((), maxval=3, dtype=tf.int32)
+    x = tf.random.normal((3, 3**rand))
+    hx = tf_utils.fast_walsh_hadamard_transform(x)
+    with self.assertRaisesOpError('The dimension of x must be a power of two.'):
+      hx = self.evaluate(hx)
+
+  @parameterized.parameters([[1, 1], [4, 1], [2, 2], [1, 8], [1, 4]])
+  def test_static_input_shape(self, *dims):
+    """Tests static input shape."""
+    x = tf.random.normal(dims)
+    hx_tf = tf_utils.fast_walsh_hadamard_transform(x)
+    hhx_tf = tf_utils.fast_walsh_hadamard_transform(hx_tf)
+
+    x, hx_tf, hhx_tf = self.evaluate([x, hx_tf, hhx_tf])
+    self.assertAllEqual(x.shape, hhx_tf.shape)
+    self.assertAllClose(x, hhx_tf)
+
+  @parameterized.parameters([[1, 1], [4, 1], [2, 2], [1, 8], [1, 4]])
+  def test_static_input_output_shape(self, *dims):
+    """Tests static output shape is identical to static input shape."""
+    x = tf.random.normal(dims)
+    hx_tf = tf_utils.fast_walsh_hadamard_transform(x)
+    hhx_tf = tf_utils.fast_walsh_hadamard_transform(hx_tf)
+    self.assertEqual(list(dims), hx_tf.shape.as_list())
+    self.assertEqual(list(dims), hhx_tf.shape.as_list())
+
+  def test_dynamic_input_shape(self):
+    """Tests dynamic input shape."""
+    rand = tf.random.uniform((), maxval=4, dtype=tf.int32)
+    x = tf.random.normal((3, 2**rand))
+    hx_tf = tf_utils.fast_walsh_hadamard_transform(x)
+    hhx_tf = tf_utils.fast_walsh_hadamard_transform(hx_tf)
+    x, hx_tf, hhx_tf = self.evaluate([x, hx_tf, hhx_tf])
+    self.assertAllEqual(x.shape, hhx_tf.shape)
+    self.assertAllClose(x, hhx_tf)
+
+  def test_dynamic_input_shape_dim_one(self):
+    """Tests input shape where the second dimension is 1, dynamically known."""
+    rand = tf.random.uniform((), maxval=1, dtype=tf.int32)
+    x = tf.random.normal((3, 2**rand))
+    hx_tf = tf_utils.fast_walsh_hadamard_transform(x)
+    hhx_tf = tf_utils.fast_walsh_hadamard_transform(hx_tf)
+    x, hx_tf, hhx_tf = self.evaluate([x, hx_tf, hhx_tf])
+    self.assertAllEqual(x.shape, hhx_tf.shape)
+    self.assertAllClose(x, hhx_tf)
+
   @parameterized.parameters([2, 4, 8, 16])
   def test_output_same_as_simple_python_implementation(self, dim):
     """Tests result is identical to inefficient implementation using scipy."""