@@ -63,15 +63,13 @@ def _model_fn(features, labels, mode, config):
63
63
dnn_input = combined_dnn_input (sparse_embedding_list , dense_value_list )
64
64
65
65
if len (dnn_hidden_units ) > 0 and cross_num > 0 : # Deep & Cross
66
- deep_out = DNN (dnn_hidden_units , dnn_activation , l2_reg_dnn , dnn_dropout ,
67
- dnn_use_bn , seed )(dnn_input , training = train_flag )
66
+ deep_out = DNN (dnn_hidden_units , dnn_activation , l2_reg_dnn , dnn_dropout , dnn_use_bn , seed = seed )(dnn_input , training = train_flag )
68
67
cross_out = CrossNet (cross_num , l2_reg = l2_reg_cross )(dnn_input )
69
68
stack_out = tf .keras .layers .Concatenate ()([cross_out , deep_out ])
70
69
final_logit = tf .keras .layers .Dense (
71
70
1 , use_bias = False , kernel_initializer = tf .keras .initializers .glorot_normal (seed ))(stack_out )
72
71
elif len (dnn_hidden_units ) > 0 : # Only Deep
73
- deep_out = DNN (dnn_hidden_units , dnn_activation , l2_reg_dnn , dnn_dropout ,
74
- dnn_use_bn , seed )(dnn_input , training = train_flag )
72
+ deep_out = DNN (dnn_hidden_units , dnn_activation , l2_reg_dnn , dnn_dropout , dnn_use_bn , seed = seed )(dnn_input , training = train_flag )
75
73
final_logit = tf .keras .layers .Dense (
76
74
1 , use_bias = False , kernel_initializer = tf .keras .initializers .glorot_normal (seed ))(deep_out )
77
75
elif cross_num > 0 : # Only Cross
0 commit comments