Skip to content

Commit 1f535b0

Browse files
Neural-Link Teamtensorflow-copybara
authored andcommitted
Fixes problem when calling the provided model_fn.
While trying out nsl, I discovered a problem with arguments when calling the inner model_fn. I tried to modify my code and adding a unused 'config' argument, but the error just moved to the lines I changed. In fact, if your model accepts 'params', that would have always fail. PiperOrigin-RevId: 292536857
1 parent de59893 commit 1f535b0

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

neural_structured_learning/estimator/adversarial_regularization.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,14 @@ def adv_model_fn(features, labels, mode, params=None, config=None):
9191
# If no 'params' is passed, then it is possible for base_model_fn not to
9292
# accept a 'params' argument. See documentation for tf.estimator.Estimator
9393
# for additional context.
94-
if params:
95-
original_spec = base_model_fn(features, labels, mode, params, config)
96-
else:
97-
original_spec = base_model_fn(features, labels, mode, config)
94+
# pylint: disable=g-long-lambda
95+
spec_fn = ((lambda features: base_model_fn(
96+
features, labels, mode, params, config)) if params else (
97+
lambda features: base_model_fn(features, labels, mode, config)))
98+
99+
original_spec = spec_fn(features)
100+
101+
print("ORIGINAL_SPEC", original_spec)
98102

99103
# Adversarial regularization only happens in training.
100104
if mode != tf.estimator.ModeKeys.TRAIN:
@@ -104,7 +108,7 @@ def adv_model_fn(features, labels, mode, params=None, config=None):
104108
adv_config.adv_neighbor_config)
105109

106110
# Runs the base model again to compute loss on adv_neighbor.
107-
adv_spec = base_model_fn(adv_neighbor, labels, mode, config)
111+
adv_spec = spec_fn(adv_neighbor)
108112

109113
final_loss = original_spec.loss + adv_config.multiplier * adv_spec.loss
110114

0 commit comments

Comments
 (0)