@@ -436,6 +436,86 @@ def test_graph_reg_model_evaluate(self):
     self._train_and_check_eval_results(
         train_example, test_example, max_neighbors=2, weight=weight, bias=bias)
 
+  @test_util.run_v1_only('Requires tf.GraphKeys')
+  def test_graph_reg_wrapper_saving_batch_statistics(self):
+    """Verifies that batch statistics in batch-norm layers are saved."""
+
+    def optimizer_fn():
+      return tf.train.GradientDescentOptimizer(0.005)
+
+    def embedding_fn(features, mode):
+      input_layer = features[FEATURE_NAME]
+      with tf.compat.v1.variable_scope('hidden_layer', reuse=tf.AUTO_REUSE):
+        hidden_layer = tf.compat.v1.layers.dense(
+            input_layer, units=4, activation=tf.nn.relu)
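+        # training=True normalizes with batch statistics and queues
+        # moving-average updates in UPDATE_OPS; at eval/predict time the
+        # saved moving statistics are used instead.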
+        batch_norm_layer = tf.compat.v1.layers.batch_normalization(
+            hidden_layer, training=(mode == tf.estimator.ModeKeys.TRAIN))
+      return batch_norm_layer
+
+    def model_fn(features, labels, mode, params=None, config=None):
+      del params, config
+      embeddings = embedding_fn(features, mode)
+      with tf.compat.v1.variable_scope('logit', reuse=tf.AUTO_REUSE):
+        logits = tf.compat.v1.layers.dense(embeddings, units=1)
+      predictions = tf.argmax(logits, 1)
+      if mode == tf.estimator.ModeKeys.PREDICT:
+        return tf.estimator.EstimatorSpec(
+            mode=mode,
+            predictions={
+                'logits': logits,
+                'predictions': predictions
+            })
+
+      loss = tf.losses.sigmoid_cross_entropy(labels, logits)
+      if mode == tf.estimator.ModeKeys.EVAL:
+        return tf.estimator.EstimatorSpec(mode=mode, loss=loss)
+
+      optimizer = optimizer_fn()
+      train_op = optimizer.minimize(
+          loss, global_step=tf.compat.v1.train.get_global_step())
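+      # tf.compat.v1.layers.batch_normalization registers its moving-average
+      # update ops in the UPDATE_OPS collection; grouping them with train_op
+      # is what makes the moving statistics change and get checkpointed.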
+      update_ops = tf.compat.v1.get_collection(
+          tf.compat.v1.GraphKeys.UPDATE_OPS)
+      train_op = tf.group(train_op, *update_ops)
+      return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
+
+    def input_fn():
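+      # Each example carries the features of one neighbor and the weight of
+      # the connecting edge, named via the neighbor prefix/suffix convention
+      # that graph regularization expects.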
+      nbr_feature = '{}{}_{}'.format(NBR_FEATURE_PREFIX, 0, FEATURE_NAME)
+      nbr_weight = '{}{}{}'.format(NBR_FEATURE_PREFIX, 0, NBR_WEIGHT_SUFFIX)
+      features = {
+          FEATURE_NAME: tf.constant([[0.1, 0.9], [0.8, 0.2]]),
+          nbr_feature: tf.constant([[0.11, 0.89], [0.81, 0.21]]),
+          nbr_weight: tf.constant([[0.9], [0.8]]),
+      }
+      labels = tf.constant([[1], [0]])
+      return tf.data.Dataset.from_tensor_slices((features, labels)).batch(2)
+
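+    # Wrap the base estimator so that training also minimizes a graph
+    # regularization term computed from each example's neighbor embeddings.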
+    base_est = tf.estimator.Estimator(model_fn, model_dir=self.model_dir)
+    graph_reg_config = nsl_configs.make_graph_reg_config(
+        max_neighbors=1, multiplier=1)
+    graph_reg_est = nsl_estimator.add_graph_regularization(
+        base_est, embedding_fn, optimizer_fn, graph_reg_config=graph_reg_config)
+    graph_reg_est.train(input_fn, steps=1)
+
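+    # moving_mean initializes to zeros and moving_variance to ones; drifting
+    # away from those defaults proves the batch-norm update ops actually ran.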
+    moving_mean = graph_reg_est.get_variable_value(
+        'hidden_layer/batch_normalization/moving_mean')
+    moving_variance = graph_reg_est.get_variable_value(
+        'hidden_layer/batch_normalization/moving_variance')
+    self.assertNotAllClose(moving_mean, np.zeros(moving_mean.shape))
+    self.assertNotAllClose(moving_variance, np.ones(moving_variance.shape))
+
 
 if __name__ == '__main__':
   tf.test.main()
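
For context, a minimal standalone sketch (not part of the commit) of the same check done directly against the checkpoint on disk rather than through the estimator; it assumes `model_dir` points at the trained estimator's model directory and that the variable names match the scopes above:

    import numpy as np
    import tensorflow as tf

    def batch_stats_were_updated(model_dir):
      # tf.train.load_variable reads a single variable from the latest
      # checkpoint under model_dir and returns it as a NumPy array.
      mean = tf.train.load_variable(
          model_dir, 'hidden_layer/batch_normalization/moving_mean')
      variance = tf.train.load_variable(
          model_dir, 'hidden_layer/batch_normalization/moving_variance')
      # Untouched batch-norm statistics are all zeros (mean) and all ones
      # (variance), so any deviation means the update ops actually ran.
      return (not np.allclose(mean, np.zeros_like(mean)) and
              not np.allclose(variance, np.ones_like(variance)))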