Skip to content

Commit 7058e04

Browse files
csferng authored and tensorflow-copybara committed
Fix batch stat updates for graph-regularized estimator in TensorFlow 1.x
PiperOrigin-RevId: 319880602
1 parent 0bf5398 commit 7058e04

File tree

2 files changed

+73
-2
lines changed

2 files changed

+73
-2
lines changed

neural_structured_learning/estimator/graph_regularization.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,10 +151,14 @@ def graph_reg_model_fn(features, labels, mode, params=None, config=None):
151151
optimizer = tf.train.AdagradOptimizer(learning_rate=0.05)
152152
else:
153153
optimizer = optimizer_fn()
154-
final_train_op = optimizer.minimize(
154+
train_op = optimizer.minimize(
155155
loss=total_loss, global_step=tf.compat.v1.train.get_global_step())
156+
update_ops = tf.compat.v1.get_collection(
157+
tf.compat.v1.GraphKeys.UPDATE_OPS)
158+
if update_ops:
159+
train_op = tf.group(train_op, *update_ops)
156160

157-
return base_spec._replace(loss=total_loss, train_op=final_train_op)
161+
return base_spec._replace(loss=total_loss, train_op=train_op)
158162

159163
# Replaces the model_fn while keeping other fields/methods in the estimator.
160164
estimator._model_fn = graph_reg_model_fn # pylint: disable=protected-access

neural_structured_learning/estimator/graph_regularization_test.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,73 @@ def test_graph_reg_model_evaluate(self):
436436
self._train_and_check_eval_results(
437437
train_example, test_example, max_neighbors=2, weight=weight, bias=bias)
438438

439+
@test_util.run_v1_only('Requires tf.GraphKeys')
def test_graph_reg_wrapper_saving_batch_statistics(self):
  """Checks that one training step changes the batch-norm moving statistics.

  A small base estimator with a batch-normalization layer is wrapped with
  graph regularization and trained for a single step. The test then reads the
  layer's `moving_mean` and `moving_variance` from the trained estimator and
  asserts that they are no longer all zeros / all ones respectively.
  """

  def make_optimizer():
    # A fresh optimizer is built on every call; it is handed both to the base
    # model and to the graph-regularization wrapper.
    return tf.train.GradientDescentOptimizer(0.005)

  def embed(features, mode):
    """Maps the input feature through dense -> batch-norm layers."""
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    with tf.compat.v1.variable_scope('hidden_layer', reuse=tf.AUTO_REUSE):
      dense_out = tf.compat.v1.layers.dense(
          features[FEATURE_NAME], units=4, activation=tf.nn.relu)
      return tf.compat.v1.layers.batch_normalization(
          dense_out, training=is_training)

  def base_model_fn(features, labels, mode, params=None, config=None):
    """Base model: embedding followed by a single-unit logit layer."""
    del params, config
    hidden = embed(features, mode)
    with tf.compat.v1.variable_scope('logit', reuse=tf.AUTO_REUSE):
      logits = tf.compat.v1.layers.dense(hidden, units=1)
    predicted = tf.argmax(logits, 1)

    if mode == tf.estimator.ModeKeys.PREDICT:
      outputs = {
          'logits': logits,
          'predictions': predicted
      }
      return tf.estimator.EstimatorSpec(mode=mode, predictions=outputs)

    # The loss is only built past this point, after the PREDICT early return.
    loss = tf.losses.sigmoid_cross_entropy(labels, logits)
    if mode == tf.estimator.ModeKeys.EVAL:
      return tf.estimator.EstimatorSpec(mode=mode, loss=loss)

    minimize_op = make_optimizer().minimize(
        loss, global_step=tf.compat.v1.train.get_global_step())
    # Run whatever ops were registered under UPDATE_OPS together with the
    # optimizer step.
    pending_updates = tf.compat.v1.get_collection(
        tf.compat.v1.GraphKeys.UPDATE_OPS)
    grouped_op = tf.group(minimize_op, *pending_updates)
    return tf.estimator.EstimatorSpec(
        mode=mode, loss=loss, train_op=grouped_op)

  def make_dataset():
    """Returns a single batch of two examples, each with one neighbor."""
    nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, 0, FEATURE_NAME)
    nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, 0, NBR_WEIGHT_SUFFIX)
    feature_dict = {
        FEATURE_NAME: tf.constant([[0.1, 0.9], [0.8, 0.2]]),
        nbr_feature_key: tf.constant([[0.11, 0.89], [0.81, 0.21]]),
        nbr_weight_key: tf.constant([[0.9], [0.8]]),
    }
    label_tensor = tf.constant([[1], [0]])
    examples = tf.data.Dataset.from_tensor_slices((feature_dict, label_tensor))
    return examples.batch(2)

  reg_config = nsl_configs.make_graph_reg_config(max_neighbors=1, multiplier=1)
  wrapped_estimator = nsl_estimator.add_graph_regularization(
      tf.estimator.Estimator(base_model_fn, model_dir=self.model_dir),
      embed,
      make_optimizer,
      graph_reg_config=reg_config)
  wrapped_estimator.train(make_dataset, steps=1)

  mean_after_step = wrapped_estimator.get_variable_value(
      'hidden_layer/batch_normalization/moving_mean')
  variance_after_step = wrapped_estimator.get_variable_value(
      'hidden_layer/batch_normalization/moving_variance')
  # All-zeros mean / all-ones variance would indicate the statistics were
  # never updated during the training step.
  self.assertNotAllClose(mean_after_step, np.zeros(mean_after_step.shape))
  self.assertNotAllClose(variance_after_step,
                         np.ones(variance_after_step.shape))
505+
439506

440507
if __name__ == '__main__':
441508
tf.test.main()

0 commit comments

Comments
 (0)