@@ -128,13 +128,16 @@ def test_adversarial_wrapper_adds_regularization(self, adv_step_size,
 
   @test_util.run_v1_only('Requires tf.train.GradientDescentOptimizer')
   def test_adversarial_wrapper_saving_batch_statistics(self):
-    x0, y0 = np.array([[0.9, 0.1], [0.2, 0.8]]), np.array([1, 0])
+    x0 = np.array([[0.9, 0.1], [0.2, -0.8], [-0.7, -0.3], [-0.4, 0.6]])
+    y0 = np.array([1, 0, 1, 0])
     input_fn = single_batch_input_fn({FEATURE_NAME: x0}, y0)
     fc = tf.feature_column.numeric_column(FEATURE_NAME, shape=[2])
     base_est = tf.estimator.DNNClassifier(
         hidden_units=[4],
         feature_columns=[fc],
         model_dir=self.model_dir,
+        activation_fn=lambda x: tf.abs(x) + 0.1,
+        dropout=None,
         batch_norm=True)
     adv_est = nsl_estimator.add_adversarial_regularization(
         base_est,
@@ -145,6 +148,8 @@ def test_adversarial_wrapper_saving_batch_statistics(self):
         'dnn/hiddenlayer_0/batchnorm_0/moving_mean')
     moving_variance = adv_est.get_variable_value(
         'dnn/hiddenlayer_0/batchnorm_0/moving_variance')
+    # The activation function always returns a positive number, so the batch
+    # mean cannot be zero if updated successfully.
     self.assertNotAllClose(moving_mean, np.zeros(moving_mean.shape))
     self.assertNotAllClose(moving_variance, np.ones(moving_variance.shape))
 
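The comment added in the second hunk asserts that an always-positive activation forces the batch-norm moving mean away from zero. A minimal NumPy sketch of that reasoning (the 0.99 momentum and all-zeros initialization are assumptions matching common batch-norm defaults, not values taken from this change):

```python
import numpy as np

momentum = 0.99                # assumed batch-norm momentum
moving_mean = np.zeros(4)      # moving mean assumed to start at zeros
# Mimic the test's activation lambda x: tf.abs(x) + 0.1: all outputs >= 0.1.
activations = np.abs(np.random.randn(8, 4)) + 0.1
batch_mean = activations.mean(axis=0)   # strictly positive per hidden unit
# A single moving-average update pulls the mean off zero, so the test's
# assertNotAllClose against zeros can only pass if an update actually ran.
moving_mean = momentum * moving_mean + (1 - momentum) * batch_mean
assert not np.allclose(moving_mean, np.zeros_like(moving_mean))
```

With four examples of mixed signs in the new `x0`, the pre-activation inputs also vary across the batch, which is what lets the companion assertion pull `moving_variance` away from its all-ones initialization.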
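For context, `single_batch_input_fn` is a helper defined elsewhere in this test file. A hypothetical reconstruction, assuming it simply serves every provided example as one batch (the real helper may differ):

```python
import tensorflow as tf

def single_batch_input_fn(features, labels):
  # Hypothetical stand-in for the helper used in the test above: returns an
  # input_fn that yields all the provided examples in a single batch.
  def input_fn():
    return tf.data.Dataset.from_tensor_slices(
        (features, labels)).batch(len(labels))
  return input_fn
```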