1818from __future__ import division
1919from __future__ import print_function
2020
21+ import functools
22+ import inspect
23+
2124import neural_structured_learning .configs as nsl_configs
2225import neural_structured_learning .lib as nsl_lib
23-
2426import tensorflow as tf
2527
2628
@@ -55,6 +57,10 @@ def add_adversarial_regularization(estimator,
5557 adv_config = nsl_configs .AdvRegConfig ()
5658
5759 base_model_fn = estimator ._model_fn # pylint: disable=protected-access
60+ try :
61+ base_model_fn_args = inspect .signature (base_model_fn ).parameters .keys ()
62+ except AttributeError : # For Python 2 compatibility
63+ base_model_fn_args = inspect .getargspec (base_model_fn ).args # pylint: disable=deprecated-method
5864
5965 def adv_model_fn (features , labels , mode , params = None , config = None ):
6066 """The adversarial-regularized model_fn.
@@ -82,19 +88,22 @@ def adv_model_fn(features, labels, mode, params=None, config=None):
8288 Returns:
8389 A `tf.estimator.EstimatorSpec` with adversarial regularization.
8490 """
91+ # Parameters 'params' and 'config' are optional. If they are not passed,
92+ # then it is possible for base_model_fn not to accept these arguments.
93+ # See documentation for tf.estimator.Estimator for additional context.
94+ kwargs = {'mode' : mode }
95+ if 'params' in base_model_fn_args :
96+ kwargs ['params' ] = params
97+ if 'config' in base_model_fn_args :
98+ kwargs ['config' ] = config
99+ base_fn = functools .partial (base_model_fn , ** kwargs )
85100
86101 # Uses the same variable scope for calculating the original objective and
87102 # adversarial regularization.
88103 with tf .compat .v1 .variable_scope (tf .compat .v1 .get_variable_scope (),
89104 reuse = tf .compat .v1 .AUTO_REUSE ,
90105 auxiliary_name_scope = False ):
91- # If no 'params' is passed, then it is possible for base_model_fn not to
92- # accept a 'params' argument. See documentation for tf.estimator.Estimator
93- # for additional context.
94- base_args = [mode , params , config ] if params else [mode , config ]
95- spec_fn = lambda feature , label : base_model_fn (feature , label , * base_args )
96-
97- original_spec = spec_fn (features , labels )
106+ original_spec = base_fn (features , labels )
98107
99108 # Adversarial regularization only happens in training.
100109 if mode != tf .estimator .ModeKeys .TRAIN :
@@ -107,11 +116,11 @@ def adv_model_fn(features, labels, mode, params=None, config=None):
107116 # The pgd_model_fn is a dummy identity function since loss is
108117 # directly available from spec_fn.
109118 pgd_model_fn = lambda features : features ,
110- pgd_loss_fn = lambda labels , features : spec_fn (features , labels ).loss ,
119+ pgd_loss_fn = lambda labels , features : base_fn (features , labels ).loss ,
111120 pgd_labels = labels )
112121
113122 # Runs the base model again to compute loss on adv_neighbor.
114- adv_spec = spec_fn (adv_neighbor , labels )
123+ adv_spec = base_fn (adv_neighbor , labels )
115124
116125 final_loss = original_spec .loss + adv_config .multiplier * adv_spec .loss
117126
0 commit comments