@@ -31,18 +31,15 @@ def testBijector(self):
31
31
for dim in range (2 ):
32
32
if dim == 0 :
33
33
# Test generic case with scalar input
34
- x = tf .ones ((width ,)) * samplers .uniform ((width ,), minval = - 1. ,
35
- maxval = 1. ,
36
- seed = test_util .test_seed (
37
- sampler_type = 'stateless' ))
34
+ x = samplers .uniform ((width ,), minval = - 1. ,
35
+ maxval = 1. ,
36
+ seed = test_util .test_seed (sampler_type = 'stateless' ))
38
37
elif dim == 1 :
39
38
# Test with 2D tensor + batch
40
- x = tf .ones ((5 , width ,
41
- width )) * samplers .uniform ((5 , width , width ),
42
- minval = - 1. ,
43
- maxval = 1. ,
44
- seed = test_util .test_seed (
45
- sampler_type = 'stateless' ))
39
+ x = samplers .uniform ((5 , width , width ),
40
+ minval = - 1. ,
41
+ maxval = 1. ,
42
+ seed = test_util .test_seed (sampler_type = 'stateless' ))
46
43
47
44
bijector = tfp .experimental .bijectors .build_highway_flow_layer (
48
45
width , activation_fn = True )
@@ -58,12 +55,10 @@ def testBijector(self):
58
55
59
56
def testBijectorWithoutActivation (self ):
60
57
width = 4
61
- x = tf .ones (2 , width ,
62
- width ) * samplers .uniform ((2 , width , width ),
63
- minval = - 1. ,
64
- maxval = 1. ,
65
- seed = test_util .test_seed (
66
- sampler_type = 'stateless' ))
58
+ x = samplers .uniform ((2 , width , width ),
59
+ minval = - 1. ,
60
+ maxval = 1. ,
61
+ seed = test_util .test_seed (sampler_type = 'stateless' ))
67
62
68
63
bijector = tfp .experimental .bijectors .build_highway_flow_layer (
69
64
width , activation_fn = False )
@@ -79,11 +74,10 @@ def testBijectorWithoutActivation(self):
79
74
80
75
def testGating (self ):
81
76
width = 4
82
- x = tf .ones ((2 , width ,
83
- width )) * samplers .uniform ((2 , width , width ),
84
- minval = - 1. ,
85
- maxval = 1. , seed = test_util .test_seed (
86
- sampler_type = 'stateless' ))
77
+ x = samplers .uniform ((2 , width , width ),
78
+ minval = - 1. ,
79
+ maxval = 1. ,
80
+ seed = test_util .test_seed (sampler_type = 'stateless' ))
87
81
88
82
# Test with gating half of the inputs
89
83
bijector = tfp .experimental .bijectors .build_highway_flow_layer (
@@ -135,9 +129,9 @@ def testResidualFractionGradientsWithCenteredDifference(self):
135
129
136
130
# pylint: disable=protected-access
137
131
bijector ._residual_fraction = residual_fraction + h
138
- y1 = tf .reduce_mean (target .log_prob (bijector .forward (x )))
132
+ y1 = tf .reduce_mean (target .log_prob (bijector .forward (tf . identity ( x ) )))
139
133
bijector ._residual_fraction = residual_fraction - h
140
- y2 = tf .reduce_mean (target .log_prob (bijector .forward (x )))
134
+ y2 = tf .reduce_mean (target .log_prob (bijector .forward (tf . identity ( x ) )))
141
135
# pylint: enable=protected-access
142
136
143
137
manual_grad = (y1 - y2 ) / (2 * h )
0 commit comments